package com.edvorg.trade.common.client

import com.edvorg.trade.common.model.ConnectorStatus
import io.ktor.client.HttpClient
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession
import io.ktor.client.plugins.websocket.WebSockets
import io.ktor.client.plugins.websocket.webSocket
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.parameter
import io.ktor.http.Url
import io.ktor.websocket.Frame
import io.ktor.websocket.readBytes
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.atomicfu.updateAndGet
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import mu.KotlinLogging
import kotlin.time.Duration.Companion.seconds

/**
 * Platform-specific factory for the Ktor [HttpClientEngine] used by [WebSocketClient].
 *
 * This is an `expect` declaration: every target source set must provide a matching `actual`
 * implementation. [WebSocketClient] calls it once per connection attempt, building a fresh
 * [HttpClient] each time.
 */
expect fun defaultHttpEngine(): HttpClientEngine

/**
 * Base class for a self-reconnecting websocket connection built on a Ktor [HttpClient].
 *
 * [startConnection] launches a background loop that connects to [url], hands every non-ping frame
 * to [processFrame], and reconnects after a 3-second pause whenever the session ends.
 * [stopConnection] cancels that loop.
 *
 * @param id label used in every log message.
 * @param url websocket endpoint to connect to.
 * @param disablePinger when `true`, Ktor's built-in ping interval is disabled (`pingInterval = -1`).
 *   The session's own 1-second ping loop runs regardless of this flag.
 * @param authCode if non-null, sent as the `auth_code` query parameter of the handshake request.
 * @param retryConnectCount maximum number of consecutive connection attempts. The counter resets
 *   after each successful handshake; `null` retries forever.
 * @param requestSetup optional handshake customization, applied before the overridable
 *   [requestSetup] member hook.
 */
abstract class WebSocketClient(
    open val id: String,
    open val url: Url,
    private val disablePinger: Boolean,
    open val authCode: String? = null,
    private val retryConnectCount: Int? = null,
    private val requestSetup: (HttpRequestBuilder.() -> Unit)? = null,
) {
    companion object {
        private val logger = KotlinLogging.logger {}

        // Pause between connection attempts.
        private val timeout = 3.seconds
    }

    // Consecutive pings that could not be written to the outgoing channel. Atomic because the pinger,
    // the stall watchdog and the connect loop touch it concurrently on Dispatchers.Default.
    private val pingsSkipped = atomic(0)

    /** Builds a fresh client with the WebSockets plugin installed; called once per connection attempt. */
    private fun makeClient() = HttpClient(defaultHttpEngine()) {
        install(WebSockets) {
            maxFrameSize = Long.MAX_VALUE
            // -1 disables Ktor's built-in pinger; otherwise ping every 1000 ms.
            pingInterval = if (disablePinger) {
                -1
            } else {
                1000
            }
        }
    }

    /** The running reconnect loop's scope, its status callback, and the live session (if any). */
    private data class Connection(
        val scope: CoroutineScope,
        val onStatusUpdated: (ConnectorStatus) -> Unit,
        val session: DefaultClientWebSocketSession?,
    )

    private val connection = atomic<Connection?>(null)

    private fun isConnectionAlive(): Boolean {
        return pingsSkipped.value <= 5
    }

    /**
     * Attaches [session] (or detaches it when `null`) to the current connection, but only if that
     * connection still belongs to [scope].
     *
     * This prevents a loop that has been stopped or replaced from resurrecting a cleared
     * connection, or from clobbering a newer one.
     */
    private fun setSession(scope: CoroutineScope, session: DefaultClientWebSocketSession?) {
        connection.updateAndGet { current ->
            if (current != null && current.scope === scope) current.copy(session = session) else current
        }
    }

    /**
     * Called for every incoming frame except Ping frames, which are answered with a Pong automatically.
     * Throwing from here cancels the current session, which triggers a reconnect.
     */
    abstract fun processFrame(frame: Frame)

    /**
     * Called after the handshake succeeds and `Connected` has been reported, before incoming frames
     * are consumed. An exception ends the current attempt and triggers a reconnect.
     */
    abstract suspend fun onConnect()

    /** Subclass hook to customize the handshake request; applied after the constructor's `requestSetup`. */
    open fun HttpRequestBuilder.requestSetup() {
    }

    /**
     * Starts the background connect/reconnect loop and returns immediately.
     *
     * [onStatusUpdated] receives:
     * - `Connecting` before each attempt;
     * - `Connected` after each handshake;
     * - `Failed` after each attempt ends;
     * - `Disconnected` when the loop finally exits.
     *
     * If a connection is already running, it is cancelled and replaced.
     *
     * @param preStart invoked on the freshly created scope before the loop is launched.
     */
    suspend fun startConnection(preStart: CoroutineScope.() -> Unit, onStatusUpdated: (ConnectorStatus) -> Unit) {
        logger.info {
            "websocket $id: connecting to websocket at $url"
        }

        // Side effects stay outside the atomic update: an atomicfu update lambda may be re-run under
        // contention, which would create extra scopes and call preStart more than once.
        onStatusUpdated(ConnectorStatus.Connecting)
        val scope = CoroutineScope(Dispatchers.Default)
        scope.preStart()

        connection.getAndSet(Connection(scope, onStatusUpdated, null))?.let { previous ->
            // Without this, a second start would leave the previous reconnect loop running forever.
            logger.warn { "websocket $id: replacing previous connection" }
            previous.scope.cancel("websocket $id: connection replaced")
        }

        scope.launch {
            try {
                var retries = 0
                while (isActive && (retryConnectCount == null || retries < retryConnectCount)) {
                    onStatusUpdated(ConnectorStatus.Connecting)
                    val client = makeClient()
                    try {
                        client.webSocket(
                            url.toString(),
                            {
                                authCode?.let {
                                    parameter("auth_code", it)
                                }
                                requestSetup?.invoke(this)
                                requestSetup()
                            },
                        ) {
                            val session = this

                            logger.info { "websocket $id: connection established" }

                            pingsSkipped.value = 0
                            retries = 0

                            setSession(scope, session)
                            onStatusUpdated(ConnectorStatus.Connected)
                            onConnect()

                            // Pinger: one ping per second after an initial 1-second delay.
                            session.launch {
                                delay(1.seconds)
                                while (isActive) {
                                    ping(outgoing)
                                    delay(1.seconds)
                                }
                            }

                            // Watchdog: cancel the session (forcing a reconnect) once too many pings failed.
                            session.launch {
                                while (isActive) {
                                    if (!isConnectionAlive()) {
                                        session.cancel()
                                        logger.error { "websocket $id: connection stalled" }
                                        break
                                    } else {
                                        delay(2000)
                                    }
                                }
                            }

                            incoming.consumeEach {
                                if (it is Frame.Ping) {
                                    outgoing.send(Frame.Pong(it.readBytes()))
                                    return@consumeEach
                                }

                                try {
                                    processFrame(it)
                                } catch (e: Throwable) {
                                    logger.error(e) {
                                        "websocket $id: unable to process frame, reconnecting"
                                    }
                                    session.cancel()
                                }
                            }
                        }
                    } catch (e: Throwable) {
                        // Once this loop's own scope is cancelled (stopConnection or replacement), the
                        // failure is just that cancellation: propagate it instead of reporting a connect error.
                        if (!isActive) throw e
                        logger.error(e) { "websocket $id: connect failed" }
                    } finally {
                        // Close first so the client is released even if a callback below throws.
                        client.close()
                        setSession(scope, null)
                        onStatusUpdated(ConnectorStatus.Failed)
                        onDisconnect()
                        retries += 1
                        logger.info { "websocket $id: $retries retrying in $timeout" }
                        delay(timeout)
                    }
                }
            } finally {
                onStatusUpdated(ConnectorStatus.Disconnected)
            }
        }
    }

    /**
     * Clears the current connection, reports `Disconnecting`, and cancels its reconnect loop.
     * Does nothing if no connection is running.
     */
    fun stopConnection() {
        connection.getAndUpdate { null }?.let { (scope, onStatusUpdated) ->
            onStatusUpdated(ConnectorStatus.Disconnecting)

            scope.cancel("websocket $id: connection stopped")
        }
    }

    /** Called after every connection attempt ends, whether or not the handshake succeeded. */
    abstract fun onDisconnect()

    /** Sends one empty Ping frame; each failure increments the skipped-ping counter the watchdog checks. */
    private suspend fun ping(outgoing: SendChannel<Frame>) {
        try {
            val frame = Frame.Ping(byteArrayOf())
            outgoing.send(frame)
            pingsSkipped.value = 0
        } catch (e: Throwable) {
            val skipped = pingsSkipped.incrementAndGet()
            logger.info { "websocket $id: unable to send ping, skipped pings $skipped" }
        }
    }

    /**
     * Queues [frame] on the current session without suspending the caller.
     *
     * If there is no live session, the frame is dropped with a warning. Send failures are logged,
     * not thrown.
     */
    fun send(frame: Frame) {
        val session = connection.value?.session

        if (session == null) {
            logger.warn { "websocket $id: unable to send frame, session is not active" }
            return
        }

        session.launch {
            try {
                session.outgoing.send(frame)
            } catch (e: Throwable) {
                logger.error(e) { "websocket $id: unable to send frame" }
            }
        }
    }
}
