Skip to content

Commit

Permalink
Make use of Ktor websocket extensions and serialization
Browse files Browse the repository at this point in the history
Add Compression.
  • Loading branch information
DRSchlaubi committed Apr 9, 2023
1 parent d3b1312 commit 3a2a53c
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 302 deletions.
29 changes: 17 additions & 12 deletions core/src/commonMain/kotlin/builder/kord/KordBuilderUtil.kt
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
package dev.kord.core.builder.kord

import dev.kord.common.annotation.KordInternal
import dev.kord.common.annotation.KordUnsafe
import dev.kord.common.entity.Snowflake
import dev.kord.common.http.HttpEngine
import dev.kord.gateway.WebSocketCompression
import io.ktor.client.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.websocket.*
import io.ktor.serialization.kotlinx.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.util.*
import kotlinx.serialization.json.Json

@OptIn(KordUnsafe::class)
internal fun HttpClientConfig<*>.defaultConfig() {
expectSuccess = false

val json = Json {
encodeDefaults = false
allowStructuredMapKeys = true
ignoreUnknownKeys = true
isLenient = true
}
install(ContentNegotiation) {
json()
json(json)
}
install(WebSockets) {
contentConverter = KotlinxWebsocketSerializationConverter(json)
extensions {
install(WebSocketCompression)
}
}
install(WebSockets)
}

/** @suppress */
Expand All @@ -26,18 +41,8 @@ public fun HttpClient?.configure(): HttpClient {
defaultConfig()
}

val json = Json {
encodeDefaults = false
allowStructuredMapKeys = true
ignoreUnknownKeys = true
isLenient = true
}

return HttpClient(HttpEngine) {
defaultConfig()
install(ContentNegotiation) {
json(json)
}
}
}

Expand Down
9 changes: 7 additions & 2 deletions gateway/src/commonMain/kotlin/Command.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,31 @@ import dev.kord.common.serialization.InstantInEpochMillisecondsSerializer
import kotlinx.atomicfu.atomic
import kotlinx.datetime.Instant
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.SerializationStrategy as KSerializationStrategy

@Serializable(with = Command.SerializationStrategy::class)
public sealed class Command {

public data class Heartbeat(val sequenceNumber: Int?) : Command()

public object SerializationStrategy : KSerializationStrategy<Command> {
public object SerializationStrategy : KSerializer<Command> {

override val descriptor: SerialDescriptor = buildClassSerialDescriptor("Command") {
element("op", OpCode.serializer().descriptor)
element("d", JsonElement.serializer().descriptor)
}

override fun deserialize(decoder: Decoder): Command =
TODO("Deserializing gateway commands is not supported yet")

@OptIn(PrivilegedIntent::class)
override fun serialize(encoder: Encoder, value: Command) {
val composite = encoder.beginStructure(descriptor)
Expand Down
70 changes: 70 additions & 0 deletions gateway/src/commonMain/kotlin/Compression.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package dev.kord.gateway

import dev.kord.common.annotation.KordUnsafe
import io.ktor.util.*
import io.ktor.websocket.*
import java.io.ByteArrayOutputStream
import java.util.zip.Inflater
import java.util.zip.InflaterOutputStream

/**
* [WebSocketExtension] inflating incoming websocket requests using `zlib`.
*
* *Note:** Normally you don't need this and this is configured by Kord automatically, however, if you want to use
* a custom HTTP client, you might need to add this, don't use it if you don't use what you're doing
*/
@KordUnsafe
public class WebSocketCompression : WebSocketExtension<Unit> {
/**
* https://discord.com/developers/docs/topics/gateway#transport-compression
*
* > Every connection to the gateway should use its own unique zlib context.
*
* https://api.ktor.io/ktor-shared/ktor-websockets/io.ktor.websocket/-web-socket-extension/index.html
* > A WebSocket extension instance. This instance is created for each WebSocket request,
* for every installed extension by WebSocketExtensionFactory.
*/
private val inflater = Inflater()

override val factory: WebSocketExtensionFactory<Unit, out WebSocketExtension<Unit>>
get() = Companion
override val protocols: List<WebSocketExtensionHeader>
get() = emptyList()

override fun clientNegotiation(negotiatedProtocols: List<WebSocketExtensionHeader>): Boolean = true

override fun processIncomingFrame(frame: Frame): Frame {
return if (frame is Frame.Binary) {
frame.deflateData()
} else {
frame
}
}

// Discord doesn't support deflating of gateway commands
override fun processOutgoingFrame(frame: Frame): Frame = frame

override fun serverNegotiation(requestedProtocols: List<WebSocketExtensionHeader>): List<WebSocketExtensionHeader> =
requestedProtocols

private fun Frame.deflateData(): Frame {
val outputStream = ByteArrayOutputStream()
InflaterOutputStream(outputStream, inflater).use {
it.write(data)
}

return outputStream.use {
val raw = String(outputStream.toByteArray(), 0, outputStream.size(), Charsets.UTF_8)
Frame.Text(raw)
}
}

public companion object : WebSocketExtensionFactory<Unit, WebSocketCompression> {
override val key: AttributeKey<WebSocketCompression> = AttributeKey("WebSocketCompression")
override val rsv1: Boolean = false
override val rsv2: Boolean = false
override val rsv3: Boolean = false

override fun install(config: Unit.() -> Unit): WebSocketCompression = WebSocketCompression()
}
}
69 changes: 12 additions & 57 deletions gateway/src/commonMain/kotlin/DefaultGateway.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.update
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.json.Json
import mu.KotlinLogging
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
Expand Down Expand Up @@ -76,13 +76,6 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {

private val handshakeHandler: HandshakeHandler

private lateinit var inflater: Inflater

private val jsonParser = Json {
ignoreUnknownKeys = true
isLenient = true
}

private val stateMutex = Mutex()

init {
Expand Down Expand Up @@ -110,14 +103,9 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
}

defaultGatewayLogger.trace { "opening gateway connection to $gatewayUrl" }
socket = data.client.webSocketSession { url(gatewayUrl) }

/**
* https://discord.com/developers/docs/topics/gateway#transport-compression
*
* > Every connection to the gateway should use its own unique zlib context.
*/
inflater = Inflater()
socket = data.client.webSocketSession {
url(gatewayUrl)
}
} catch (exception: Exception) {
defaultGatewayLogger.error(exception)
if (exception.isTimeout()) {
Expand Down Expand Up @@ -167,31 +155,12 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
}


@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun readSocket() {
socket.incoming.asFlow().buffer(Channel.UNLIMITED).collect {
when (it) {
is Frame.Binary, is Frame.Text -> read(it)
else -> { /*ignore*/
}
}
}
}

private suspend fun read(frame: Frame) {
defaultGatewayLogger.trace { "Received raw frame: $frame" }
val json = when {
compression -> with(inflater) { frame.inflateData() }
else -> frame.data.decodeToString()
}

try {
defaultGatewayLogger.trace { "Gateway <<< $json" }
val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json) ?: return
while (!socket.incoming.isClosedForReceive) {
val event = socket.receiveDeserialized<Event>()
data.eventFlow.emit(event)
} catch (exception: Exception) {
defaultGatewayLogger.error(exception)
}

}

private suspend fun handleClose() {
Expand All @@ -209,6 +178,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
state.update { State.Stopped }
throw IllegalStateException("Gateway closed: ${reason.code} ${reason.message}")
}

discordReason.resetSession -> {
setStopped()
}
Expand All @@ -220,14 +190,6 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
state.update { State.Running(true) }
}

private fun <T> ReceiveChannel<T>.asFlow() = flow {
try {
for (value in this@asFlow) emit(value)
} catch (ignore: CancellationException) {
//reading was stopped from somewhere else, ignore
}
}

override suspend fun stop() {
check(state.value !is State.Detached) { "The resources of this gateway are detached, create another one" }
data.eventFlow.emit(Close.UserClose)
Expand Down Expand Up @@ -268,14 +230,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {

private suspend fun sendUnsafe(command: Command) {
data.sendRateLimiter.consume()
val json = Json.encodeToString(Command.SerializationStrategy, command)
if (command is Identify) {
defaultGatewayLogger.trace {
val copy = command.copy(token = "token")
"Gateway >>> ${Json.encodeToString(Command.SerializationStrategy, copy)}"
}
} else defaultGatewayLogger.trace { "Gateway >>> $json" }
socket.send(Frame.Text(json))
socket.sendSerialized(command)
}

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down
16 changes: 11 additions & 5 deletions gateway/src/commonMain/kotlin/DefaultGatewayBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dev.kord.gateway

import dev.kord.common.KordConfiguration
import dev.kord.common.http.HttpEngine
import dev.kord.common.annotation.KordUnsafe
import dev.kord.common.ratelimit.IntervalRateLimiter
import dev.kord.common.ratelimit.RateLimiter
import dev.kord.gateway.ratelimit.IdentifyRateLimiter
Expand All @@ -12,10 +13,11 @@ import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.serialization.kotlinx.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.serialization.json.Json
import kotlin.time.Duration.Companion.seconds

public class DefaultGatewayBuilder {
Expand All @@ -28,11 +30,15 @@ public class DefaultGatewayBuilder {
public var dispatcher: CoroutineDispatcher = Dispatchers.Default
public var eventFlow: MutableSharedFlow<Event> = MutableSharedFlow(extraBufferCapacity = Int.MAX_VALUE)

@OptIn(KordUnsafe::class)
public fun build(): DefaultGateway {
val client = client ?: HttpClient(HttpEngine) {
install(WebSockets)
install(ContentNegotiation) {
json()
val client = client ?: HttpClient(CIO) {
install(WebSockets) {
contentConverter = KotlinxWebsocketSerializationConverter(Json)

extensions {
install(WebSocketCompression)
}
}
}
val retry = reconnectRetry ?: LinearRetry(2.seconds, 20.seconds, 10)
Expand Down
Loading

0 comments on commit 3a2a53c

Please sign in to comment.