diff --git a/buildSrc/src/main/kotlin/test/server/tests/MultiPartFormData.kt b/buildSrc/src/main/kotlin/test/server/tests/MultiPartFormData.kt index 9b53033cf8c..8cd297df907 100644 --- a/buildSrc/src/main/kotlin/test/server/tests/MultiPartFormData.kt +++ b/buildSrc/src/main/kotlin/test/server/tests/MultiPartFormData.kt @@ -4,6 +4,7 @@ package test.server.tests +import io.ktor.client.request.forms.* import io.ktor.http.* import io.ktor.http.content.* import io.ktor.server.application.* @@ -34,6 +35,18 @@ internal fun Application.multiPartFormDataTest() { call.receiveMultipart().readPart() call.respond(HttpStatusCode.OK) } + post("receive") { + val multipart = MultiPartFormDataContent( + formData { + append("text", "Hello, World!") + append("file", ByteArray(1024) { it.toByte() }, Headers.build { + append(HttpHeaders.ContentDisposition, """form-data; name="file"; filename="test.bin"""") + append(HttpHeaders.ContentType, ContentType.Application.OctetStream.toString()) + }) + } + ) + call.respond(multipart) + } } } } diff --git a/ktor-client/ktor-client-core/build.gradle.kts b/ktor-client/ktor-client-core/build.gradle.kts index 19fdad7c0a8..f11621825cf 100644 --- a/ktor-client/ktor-client-core/build.gradle.kts +++ b/ktor-client/ktor-client-core/build.gradle.kts @@ -8,6 +8,7 @@ kotlin.sourceSets { commonMain { dependencies { api(project(":ktor-http")) + api(project(":ktor-http:ktor-http-cio")) api(project(":ktor-shared:ktor-events")) api(project(":ktor-shared:ktor-websocket-serialization")) api(project(":ktor-shared:ktor-sse")) diff --git a/ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/DefaultTransform.kt b/ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/DefaultTransform.kt index 6d35d2a0f6e..3273ab2860b 100644 --- a/ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/DefaultTransform.kt +++ b/ktor-client/ktor-client-core/common/src/io/ktor/client/plugins/DefaultTransform.kt @@ -8,6 +8,7 @@ import io.ktor.client.* import io.ktor.client.request.* import io.ktor.client.statement.* import io.ktor.http.* +import io.ktor.http.cio.* import io.ktor.http.content.* import io.ktor.util.logging.* import io.ktor.utils.io.* @@ -111,6 +112,22 @@ public fun HttpClient.defaultTransformers() { proceedWith(HttpResponseContainer(info, response.status)) } + MultiPartData::class -> { + val rawContentType = checkNotNull(context.response.headers[HttpHeaders.ContentType]) { + "No content type provided for multipart" + } + val contentType = ContentType.parse(rawContentType) + check(contentType.match(ContentType.MultiPart.FormData)) { + "Expected multipart/form-data, got $contentType" + } + + val contentLength = context.response.headers[HttpHeaders.ContentLength]?.toLong() + val body = CIOMultipartDataBase(coroutineContext, body, rawContentType, contentLength) + val parsedResponse = HttpResponseContainer(info, body) + + proceedWith(parsedResponse) + } + else -> null } if (result != null) { diff --git a/ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/MultiPartFormDataTest.kt b/ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/MultiPartFormDataTest.kt index 462c2b552c7..7d947df4211 100644 --- a/ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/MultiPartFormDataTest.kt +++ b/ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/MultiPartFormDataTest.kt @@ -4,13 +4,15 @@ package io.ktor.client.tests +import io.ktor.client.call.* import io.ktor.client.request.* import io.ktor.client.request.forms.* import io.ktor.client.tests.utils.* import io.ktor.http.* +import io.ktor.http.content.* +import io.ktor.utils.io.* import kotlinx.io.* import kotlin.test.* -import kotlin.time.* /** * Tests client request with multi-part form data. @@ -48,4 +50,41 @@ class MultiPartFormDataTest : ClientLoader() { assertTrue(response.status.isSuccess()) } } + + @Test + fun testReceiveMultiPartFormData() = clientTests { + test { client -> + val response = client.post("$TEST_SERVER/multipart/receive") + + val multipart = response.body() + var textFound = false + var fileFound = false + + multipart.forEachPart { part -> + when (part) { + is PartData.FormItem -> { + assertEquals("text", part.name) + assertEquals("Hello, World!", part.value) + textFound = true + } + is PartData.FileItem -> { + assertEquals("file", part.name) + assertEquals("test.bin", part.originalFileName) + + val bytes = part.provider().readRemaining().readByteArray() + assertEquals(1024, bytes.size) + for (i in bytes.indices) { + assertEquals(i.toByte(), bytes[i]) + } + fileFound = true + } + else -> fail("Unexpected part type: ${part::class.simpleName}") + } + part.dispose() + } + + assertTrue(textFound, "Text part not found") + assertTrue(fileFound, "File part not found") + } + } } diff --git a/ktor-http/ktor-http-cio/api/ktor-http-cio.klib.api b/ktor-http/ktor-http-cio/api/ktor-http-cio.klib.api index b9a6c9df8b6..2418cb5200d 100644 --- a/ktor-http/ktor-http-cio/api/ktor-http-cio.klib.api +++ b/ktor-http/ktor-http-cio/api/ktor-http-cio.klib.api @@ -40,6 +40,15 @@ final class io.ktor.http.cio/CIOHeaders : io.ktor.http/Headers { // io.ktor.http final fun names(): kotlin.collections/Set // io.ktor.http.cio/CIOHeaders.names|names(){}[0] } +final class io.ktor.http.cio/CIOMultipartDataBase : io.ktor.http.content/MultiPartData, kotlinx.coroutines/CoroutineScope { // io.ktor.http.cio/CIOMultipartDataBase|null[0] + constructor (kotlin.coroutines/CoroutineContext, io.ktor.utils.io/ByteReadChannel, kotlin/CharSequence, kotlin/Long?, kotlin/Long = ...) // io.ktor.http.cio/CIOMultipartDataBase.|(kotlin.coroutines.CoroutineContext;io.ktor.utils.io.ByteReadChannel;kotlin.CharSequence;kotlin.Long?;kotlin.Long){}[0] + + final val coroutineContext // io.ktor.http.cio/CIOMultipartDataBase.coroutineContext|{}coroutineContext[0] + final fun (): kotlin.coroutines/CoroutineContext // io.ktor.http.cio/CIOMultipartDataBase.coroutineContext.|(){}[0] + + final suspend fun readPart(): io.ktor.http.content/PartData? // io.ktor.http.cio/CIOMultipartDataBase.readPart|readPart(){}[0] +} + final class io.ktor.http.cio/ConnectionOptions { // io.ktor.http.cio/ConnectionOptions|null[0] constructor (kotlin/Boolean = ..., kotlin/Boolean = ..., kotlin/Boolean = ..., kotlin.collections/List = ...) // io.ktor.http.cio/ConnectionOptions.|(kotlin.Boolean;kotlin.Boolean;kotlin.Boolean;kotlin.collections.List){}[0] @@ -117,9 +126,44 @@ final class io.ktor.http.cio/Response : io.ktor.http.cio/HttpMessage { // io.kto final fun (): kotlin/CharSequence // io.ktor.http.cio/Response.version.|(){}[0] } +sealed class io.ktor.http.cio/MultipartEvent { // io.ktor.http.cio/MultipartEvent|null[0] + abstract fun release() // io.ktor.http.cio/MultipartEvent.release|release(){}[0] + + final class Epilogue : io.ktor.http.cio/MultipartEvent { // io.ktor.http.cio/MultipartEvent.Epilogue|null[0] + constructor (kotlinx.io/Source) // io.ktor.http.cio/MultipartEvent.Epilogue.|(kotlinx.io.Source){}[0] + + final val body // io.ktor.http.cio/MultipartEvent.Epilogue.body|{}body[0] + final fun (): kotlinx.io/Source // io.ktor.http.cio/MultipartEvent.Epilogue.body.|(){}[0] + + final fun release() // io.ktor.http.cio/MultipartEvent.Epilogue.release|release(){}[0] + } + + final class MultipartPart : io.ktor.http.cio/MultipartEvent { // io.ktor.http.cio/MultipartEvent.MultipartPart|null[0] + constructor (kotlinx.coroutines/Deferred, io.ktor.utils.io/ByteReadChannel) // io.ktor.http.cio/MultipartEvent.MultipartPart.|(kotlinx.coroutines.Deferred;io.ktor.utils.io.ByteReadChannel){}[0] + + final val body // io.ktor.http.cio/MultipartEvent.MultipartPart.body|{}body[0] + final fun (): io.ktor.utils.io/ByteReadChannel // io.ktor.http.cio/MultipartEvent.MultipartPart.body.|(){}[0] + final val headers // io.ktor.http.cio/MultipartEvent.MultipartPart.headers|{}headers[0] + final fun (): kotlinx.coroutines/Deferred // io.ktor.http.cio/MultipartEvent.MultipartPart.headers.|(){}[0] + + final fun release() // io.ktor.http.cio/MultipartEvent.MultipartPart.release|release(){}[0] + } + + final class Preamble : io.ktor.http.cio/MultipartEvent { // io.ktor.http.cio/MultipartEvent.Preamble|null[0] + constructor (kotlinx.io/Source) // io.ktor.http.cio/MultipartEvent.Preamble.|(kotlinx.io.Source){}[0] + + final val body // io.ktor.http.cio/MultipartEvent.Preamble.body|{}body[0] + final fun (): kotlinx.io/Source // io.ktor.http.cio/MultipartEvent.Preamble.body.|(){}[0] + + final fun release() // io.ktor.http.cio/MultipartEvent.Preamble.release|release(){}[0] + } +} + final fun (kotlin/CharSequence).io.ktor.http.cio.internals/parseDecLong(): kotlin/Long // io.ktor.http.cio.internals/parseDecLong|parseDecLong@kotlin.CharSequence(){}[0] final fun (kotlinx.coroutines/CoroutineScope).io.ktor.http.cio/decodeChunked(io.ktor.utils.io/ByteReadChannel): io.ktor.utils.io/WriterJob // io.ktor.http.cio/decodeChunked|decodeChunked@kotlinx.coroutines.CoroutineScope(io.ktor.utils.io.ByteReadChannel){}[0] final fun (kotlinx.coroutines/CoroutineScope).io.ktor.http.cio/decodeChunked(io.ktor.utils.io/ByteReadChannel, kotlin/Long): io.ktor.utils.io/WriterJob // io.ktor.http.cio/decodeChunked|decodeChunked@kotlinx.coroutines.CoroutineScope(io.ktor.utils.io.ByteReadChannel;kotlin.Long){}[0] +final fun (kotlinx.coroutines/CoroutineScope).io.ktor.http.cio/parseMultipart(io.ktor.utils.io/ByteReadChannel, io.ktor.http.cio/HttpHeadersMap, kotlin/Long = ...): kotlinx.coroutines.channels/ReceiveChannel // io.ktor.http.cio/parseMultipart|parseMultipart@kotlinx.coroutines.CoroutineScope(io.ktor.utils.io.ByteReadChannel;io.ktor.http.cio.HttpHeadersMap;kotlin.Long){}[0] +final fun (kotlinx.coroutines/CoroutineScope).io.ktor.http.cio/parseMultipart(io.ktor.utils.io/ByteReadChannel, kotlin/CharSequence, kotlin/Long?, kotlin/Long = ...): kotlinx.coroutines.channels/ReceiveChannel // io.ktor.http.cio/parseMultipart|parseMultipart@kotlinx.coroutines.CoroutineScope(io.ktor.utils.io.ByteReadChannel;kotlin.CharSequence;kotlin.Long?;kotlin.Long){}[0] final fun io.ktor.http.cio/encodeChunked(io.ktor.utils.io/ByteWriteChannel, kotlin.coroutines/CoroutineContext): io.ktor.utils.io/ReaderJob // io.ktor.http.cio/encodeChunked|encodeChunked(io.ktor.utils.io.ByteWriteChannel;kotlin.coroutines.CoroutineContext){}[0] final fun io.ktor.http.cio/expectHttpBody(io.ktor.http.cio/Request): kotlin/Boolean // io.ktor.http.cio/expectHttpBody|expectHttpBody(io.ktor.http.cio.Request){}[0] final fun io.ktor.http.cio/expectHttpBody(io.ktor.http/HttpMethod, kotlin/Long, kotlin/CharSequence?, io.ktor.http.cio/ConnectionOptions?, kotlin/CharSequence?): kotlin/Boolean // io.ktor.http.cio/expectHttpBody|expectHttpBody(io.ktor.http.HttpMethod;kotlin.Long;kotlin.CharSequence?;io.ktor.http.cio.ConnectionOptions?;kotlin.CharSequence?){}[0] diff --git a/ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/CIOMultipartDataBase.kt b/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/CIOMultipartDataBase.kt similarity index 88% rename from ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/CIOMultipartDataBase.kt rename to ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/CIOMultipartDataBase.kt index 0003a26548a..65c45647029 100644 --- a/ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/CIOMultipartDataBase.kt +++ b/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/CIOMultipartDataBase.kt @@ -49,7 +49,7 @@ public class CIOMultipartDataBase( val event = events.receive() eventToData(event)?.let { return it } } - } catch (t: ClosedReceiveChannelException) { + } catch (_: ClosedReceiveChannelException) { return null } } @@ -77,13 +77,7 @@ public class CIOMultipartDataBase( val body = part.body if (filename == null) { - val packet = body.readRemaining() // formFieldLimit.toLong()) -// if (!body.exhausted()) { -// val cause = IllegalStateException("Form field size limit exceeded: $formFieldLimit") -// body.cancel(cause) -// throw cause -// } - + val packet = body.readRemaining() packet.use { return PartData.FormItem(it.readText(), { part.release() }, CIOHeaders(headers)) } diff --git a/ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/Multipart.kt b/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/Multipart.kt similarity index 57% rename from ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/Multipart.kt rename to ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/Multipart.kt index 3eeab2f5afa..5c2b5c0530d 100644 --- a/ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/Multipart.kt +++ b/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/Multipart.kt @@ -6,14 +6,11 @@ package io.ktor.http.cio import io.ktor.http.cio.internals.* import io.ktor.utils.io.* -import io.ktor.utils.io.ByteString import io.ktor.utils.io.core.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.io.* import kotlinx.io.bytestring.* -import java.io.EOFException -import java.nio.* /** * Represents a multipart content starting event. Every part need to be completely consumed or released via [release] @@ -56,9 +53,8 @@ public sealed class MultipartEvent { headers.getCompleted().release() } } - runBlocking { - body.discard() - } + + body.discardBlocking() } } @@ -75,11 +71,12 @@ public sealed class MultipartEvent { } } +internal expect fun ByteReadChannel.discardBlocking() + /** * Parse a multipart preamble * @return number of bytes copied */ - private suspend fun parsePreambleImpl( boundary: ByteString, input: ByteReadChannel, @@ -165,74 +162,73 @@ public fun CoroutineScope.parseMultipart( private val CrLf = ByteString("\r\n".toByteArray()) -@OptIn(ExperimentalCoroutinesApi::class, InternalAPI::class) +@OptIn(ExperimentalCoroutinesApi::class) private fun CoroutineScope.parseMultipart( boundaryPrefixed: ByteString, input: ByteReadChannel, totalLength: Long?, maxPartSize: Long -): ReceiveChannel = - produce { - val countedInput = input.counted() - val readBeforeParse = countedInput.totalBytesRead - val firstBoundary = boundaryPrefixed.substring(PrefixString.size) - - val preambleData = writer { - parsePreambleImpl(firstBoundary, countedInput, channel, 8192) - channel.flushAndClose() - }.channel.readRemaining() - - if (preambleData.remaining > 0L) { - send(MultipartEvent.Preamble(preambleData)) - } - - while (!countedInput.isClosedForRead && !countedInput.skipIfFound(PrefixString)) { - countedInput.skipIfFound(CrLf) - - val body = ByteChannel() - val headers = CompletableDeferred() - val part = MultipartEvent.MultipartPart(headers, body) - send(part) - - var headersMap: HttpHeadersMap? = null - try { - headersMap = parsePartHeadersImpl(countedInput) - if (!headers.complete(headersMap)) { - headersMap.release() - throw kotlin.coroutines.cancellation.CancellationException( - "Multipart processing has been cancelled" - ) - } - parsePartBodyImpl(boundaryPrefixed, countedInput, body, headersMap, maxPartSize) - body.close() - } catch (cause: Throwable) { - if (headers.completeExceptionally(cause)) { - headersMap?.release() - } - body.close(cause) - throw cause - } - } +): ReceiveChannel = produce { + val countedInput = input.counted() + val readBeforeParse = countedInput.totalBytesRead + val firstBoundary = boundaryPrefixed.substring(PrefixString.size) + + val preambleData = writer { + parsePreambleImpl(firstBoundary, countedInput, channel, 8193) + channel.flushAndClose() + }.channel.readRemaining() + + if (preambleData.remaining > 0L) { + send(MultipartEvent.Preamble(preambleData)) + } - // Can be followed by two carriage returns - countedInput.skipIfFound(CrLf) + while (!countedInput.isClosedForRead && !countedInput.skipIfFound(PrefixString)) { countedInput.skipIfFound(CrLf) - if (totalLength != null) { - val consumedExceptEpilogue = countedInput.totalBytesRead - readBeforeParse - val size = totalLength - consumedExceptEpilogue - if (size > Int.MAX_VALUE) throw IOException("Failed to parse multipart: prologue is too long") - if (size > 0) { - send(MultipartEvent.Epilogue(countedInput.readPacket(size.toInt()))) + val body = ByteChannel() + val headers = CompletableDeferred() + val part = MultipartEvent.MultipartPart(headers, body) + send(part) + + var headersMap: HttpHeadersMap? = null + try { + headersMap = parsePartHeadersImpl(countedInput) + if (!headers.complete(headersMap)) { + headersMap.release() + throw kotlin.coroutines.cancellation.CancellationException( + "Multipart processing has been cancelled" + ) } - } else { - val epilogueContent = countedInput.readRemaining() - if (!epilogueContent.exhausted()) { - send(MultipartEvent.Epilogue(epilogueContent)) + parsePartBodyImpl(boundaryPrefixed, countedInput, body, headersMap, maxPartSize) + body.close() + } catch (cause: Throwable) { + if (headers.completeExceptionally(cause)) { + headersMap?.release() } + body.close(cause) + throw cause } } + // Can be followed by two carriage returns + countedInput.skipIfFound(CrLf) + countedInput.skipIfFound(CrLf) + + if (totalLength != null) { + val consumedExceptEpilogue = countedInput.totalBytesRead - readBeforeParse + val size = totalLength - consumedExceptEpilogue + if (size > Int.MAX_VALUE) throw IOException("Failed to parse multipart: prologue is too long") + if (size > 0) { + send(MultipartEvent.Epilogue(countedInput.readPacket(size.toInt()))) + } + } else { + val epilogueContent = countedInput.readRemaining() + if (!epilogueContent.exhausted()) { + send(MultipartEvent.Epilogue(epilogueContent)) + } + } +} + private const val PrefixChar = '-'.code.toByte() private val PrefixString = ByteString(PrefixChar, PrefixChar) @@ -297,7 +293,7 @@ private fun findBoundary(contentType: CharSequence): Int { * Parse multipart boundary encoded in [contentType] header value * @return a buffer containing CRLF, prefix '--' and boundary bytes */ -internal fun parseBoundaryInternal(contentType: CharSequence): ByteBuffer { +internal fun parseBoundaryInternal(contentType: CharSequence): ByteArray { val boundaryParameter = findBoundary(contentType) if (boundaryParameter == -1) { @@ -305,11 +301,20 @@ internal fun parseBoundaryInternal(contentType: CharSequence): ByteBuffer { } val boundaryStart = boundaryParameter + 9 - val boundaryBytes: ByteBuffer = ByteBuffer.allocate(74) - boundaryBytes.put(0x0d) - boundaryBytes.put(0x0a) - boundaryBytes.put(PrefixChar) - boundaryBytes.put(PrefixChar) + val boundaryBytes = ByteArray(74) + var position = 0 + + fun put(value: Byte) { + if (position >= boundaryBytes.size) throw IOException( + "Failed to parse multipart: boundary shouldn't be longer than 70 characters" + ) + boundaryBytes[position++] = value + } + + put(0x0d) + put(0x0a) + put(PrefixChar) + put(PrefixChar) var state = 0 // 0 - skipping spaces, 1 - unquoted characters, 2 - quoted no escape, 3 - quoted after escape @@ -336,154 +341,39 @@ internal fun parseBoundaryInternal(contentType: CharSequence): ByteBuffer { } else -> { state = 1 - boundaryBytes.put(v.toByte()) + put(v.toByte()) } } } 1 -> { // non-quoted string if (ch == ' ' || ch == ',' || ch == ';') { // space, comma or semicolon (;) break@loop - } else if (boundaryBytes.hasRemaining()) { - boundaryBytes.put(v.toByte()) } else { - // RFC 2046, sec 5.1.1 - throw IOException("Failed to parse multipart: boundary shouldn't be longer than 70 characters") + put(v.toByte()) } } + 2 -> { if (ch == '\\') { state = 3 } else if (ch == '"') { break@loop - } else if (boundaryBytes.hasRemaining()) { - boundaryBytes.put(v.toByte()) } else { - // RFC 2046, sec 5.1.1 - throw IOException("Failed to parse multipart: boundary shouldn't be longer than 70 characters") + put(v.toByte()) } } 3 -> { - if (boundaryBytes.hasRemaining()) { - boundaryBytes.put(v.toByte()) - state = 2 - } else { - // RFC 2046, sec 5.1.1 - throw IOException("Failed to parse multipart: boundary shouldn't be longer than 70 characters") - } + put(v.toByte()) + state = 2 } } } - boundaryBytes.flip() - - if (boundaryBytes.remaining() == 4) { + if (position == 4) { throw IOException("Empty multipart boundary is not allowed") } - return boundaryBytes -} - -/** - * Tries to skip the specified [delimiter] or fails if encounters bytes differs from the required. - * @return `true` if the delimiter was found and skipped or `false` when EOF. - */ -internal suspend fun ByteReadChannel.skipDelimiterOrEof(delimiter: ByteBuffer): Boolean { - require(delimiter.hasRemaining()) - require(delimiter.remaining() <= DEFAULT_BUFFER_SIZE) { - "Delimiter of ${delimiter.remaining()} bytes is too long: at most $DEFAULT_BUFFER_SIZE bytes could be checked" - } - - var found = false - - lookAhead { - found = tryEnsureDelimiter(delimiter) == delimiter.remaining() - } - - if (found) { - return true - } - - return trySkipDelimiterSuspend(delimiter) -} - -private suspend fun ByteReadChannel.trySkipDelimiterSuspend(delimiter: ByteBuffer): Boolean { - var result = true - - lookAheadSuspend { - if (!awaitAtLeast(delimiter.remaining()) && !awaitAtLeast(1)) { - result = false - return@lookAheadSuspend - } - if (tryEnsureDelimiter(delimiter) != delimiter.remaining()) throw IOException("Broken delimiter occurred") - } - - return result -} - -private fun LookAheadSession.tryEnsureDelimiter(delimiter: ByteBuffer): Int { - val found = startsWithDelimiter(delimiter) - if (found == -1) throw IOException("Failed to skip delimiter: actual bytes differ from delimiter bytes") - if (found < delimiter.remaining()) return found - - consumed(delimiter.remaining()) - return delimiter.remaining() -} - -@Suppress("LoopToCallChain") -private fun ByteBuffer.startsWith( - prefix: ByteBuffer, - prefixSkip: Int = 0 -): Boolean { - val size = minOf(remaining(), prefix.remaining() - prefixSkip) - if (size <= 0) return false - - val position = position() - val prefixPosition = prefix.position() + prefixSkip - - for (i in 0 until size) { - if (get(position + i) != prefix.get(prefixPosition + i)) return false - } - - return true -} - -/** - * @return Number of bytes of the delimiter found (possibly 0 if no bytes available yet) or -1 if it doesn't start - */ -private fun LookAheadSession.startsWithDelimiter(delimiter: ByteBuffer): Int { - val buffer = request(0, 1) ?: return 0 - val index = buffer.indexOfPartial(delimiter) - if (index != 0) return -1 - - val found = minOf(buffer.remaining() - index, delimiter.remaining()) - val notKnown = delimiter.remaining() - found - - if (notKnown > 0) { - val next = request(index + found, notKnown) ?: return found - if (!next.startsWith(delimiter, found)) return -1 - } - - return delimiter.remaining() -} - -@Suppress("LoopToCallChain") -private fun ByteBuffer.indexOfPartial(sub: ByteBuffer): Int { - val subPosition = sub.position() - val subSize = sub.remaining() - val first = sub[subPosition] - val limit = limit() - - outer@ for (idx in position() until limit) { - if (get(idx) == first) { - for (j in 1 until subSize) { - if (idx + j == limit) break - if (get(idx + j) != sub.get(subPosition + j)) continue@outer - } - return idx - position() - } - } - - return -1 + return boundaryBytes.copyOfRange(0, position) } private fun throwLimitExceeded(actual: Long, limit: Long): Nothing = diff --git a/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/RequestResponseBuilderCommon.kt b/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/RequestResponseBuilderCommon.kt index 98d7c72e043..082354377aa 100644 --- a/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/RequestResponseBuilderCommon.kt +++ b/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/RequestResponseBuilderCommon.kt @@ -5,7 +5,6 @@ package io.ktor.http.cio import io.ktor.http.* -import io.ktor.utils.io.core.* import kotlinx.io.* /** @@ -53,7 +52,3 @@ public expect class RequestResponseBuilder() { */ public fun release() } - -private const val SP: Byte = 0x20 -private const val CR: Byte = 0x0d -private const val LF: Byte = 0x0a diff --git a/ktor-http/ktor-http-cio/jsAndWasmShared/src/io/ktor/http/cio/MultipartJsAndWasm.kt b/ktor-http/ktor-http-cio/jsAndWasmShared/src/io/ktor/http/cio/MultipartJsAndWasm.kt new file mode 100644 index 00000000000..2fb66098744 --- /dev/null +++ b/ktor-http/ktor-http-cio/jsAndWasmShared/src/io/ktor/http/cio/MultipartJsAndWasm.kt @@ -0,0 +1,11 @@ +package io.ktor.http.cio + +import io.ktor.utils.io.* + +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +internal actual fun ByteReadChannel.discardBlocking() { + cancel() +} diff --git a/ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/RequestResponseBuilder.kt b/ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/RequestResponseBuilder.kt index c38581fd8dd..8511dc41dbd 100644 --- a/ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/RequestResponseBuilder.kt +++ b/ktor-http/ktor-http-cio/jvm/src/io/ktor/http/cio/RequestResponseBuilder.kt @@ -5,7 +5,6 @@ package io.ktor.http.cio import io.ktor.http.* -import io.ktor.utils.io.* import io.ktor.utils.io.core.* import kotlinx.io.* import java.nio.* diff --git a/ktor-http/ktor-http-cio/jvm/test/io/ktor/tests/http/cio/MultipartTest.kt b/ktor-http/ktor-http-cio/jvm/test/io/ktor/tests/http/cio/MultipartTest.kt index ca9372fbbfc..cc539c2cfac 100644 --- a/ktor-http/ktor-http-cio/jvm/test/io/ktor/tests/http/cio/MultipartTest.kt +++ b/ktor-http/ktor-http-cio/jvm/test/io/ktor/tests/http/cio/MultipartTest.kt @@ -432,11 +432,7 @@ class MultipartTest { private fun testBoundary(expectedBoundary: String, headerValue: String) { val boundary = parseBoundaryInternal(headerValue) - val actualBoundary = String( - boundary.array(), - boundary.arrayOffset() + boundary.position(), - boundary.remaining() - ) + val actualBoundary = String(boundary) assertEquals(expectedBoundary, actualBoundary) } diff --git a/ktor-http/ktor-http-cio/jvm/test/io/ktor/tests/http/cio/TrySkipDelimiterTest.kt b/ktor-http/ktor-http-cio/jvm/test/io/ktor/tests/http/cio/TrySkipDelimiterTest.kt deleted file mode 100644 index a2d7ff24930..00000000000 --- a/ktor-http/ktor-http-cio/jvm/test/io/ktor/tests/http/cio/TrySkipDelimiterTest.kt +++ /dev/null @@ -1,153 +0,0 @@ -/* -* Copyright 2014-2021 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. -*/ - -package io.ktor.tests.http.cio - -import io.ktor.http.cio.* -import io.ktor.utils.io.* -import kotlinx.coroutines.* -import kotlinx.coroutines.test.* -import java.nio.* -import kotlin.test.* - -class TrySkipDelimiterTest { - private val ch = ByteChannel() - - @Test - fun testSmoke(): Unit = runTest { - ch.writeFully(byteArrayOf(1, 2, 3)) - ch.close() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2)) - assertTrue(ch.skipDelimiterOrEof(delimiter)) - assertEquals(3, ch.readByte()) - assertTrue(ch.isClosedForRead) - } - - @OptIn(InternalAPI::class) - @Test - fun testSmokeWithOffsetShift(): Unit = runTest { - ch.writeFully(byteArrayOf(9, 1, 2, 3)) - ch.close() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2)) - ch.discard(1) - assertTrue(ch.skipDelimiterOrEof(delimiter)) - assertEquals(3, ch.readByte()) - assertTrue(ch.isClosedForRead) - } - - @OptIn(InternalAPI::class) - @Test - fun testEmpty(): Unit = runTest { - ch.close() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2)) - assertFalse(ch.skipDelimiterOrEof(delimiter)) - } - - @OptIn(InternalAPI::class) - @Test - fun testFull(): Unit = runTest { - ch.writeFully(byteArrayOf(1, 2)) - ch.close() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2)) - assertTrue(ch.skipDelimiterOrEof(delimiter)) - assertTrue(ch.isClosedForRead) - } - - @OptIn(InternalAPI::class) - @Test - fun testIncomplete(): Unit = runTest { - ch.writeFully(byteArrayOf(1, 2)) - ch.close() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2, 3)) - assertFails { - ch.skipDelimiterOrEof(delimiter) - } - } - - @OptIn(InternalAPI::class) - @Test - fun testOtherBytes(): Unit = runTest { - ch.writeFully(byteArrayOf(7, 8)) - ch.close() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2)) - - assertFails { - ch.skipDelimiterOrEof(delimiter) - } - - // content shouldn't be consumed - assertEquals(7, ch.readByte()) - assertEquals(8, ch.readByte()) - assertTrue(ch.isClosedForRead) - } - - @Test - fun testTimeSplit(): Unit = runTest { - val writer = launch(CoroutineName("writer"), start = CoroutineStart.LAZY) { - ch.writeByte(2) - ch.close() - } - - ch.writeByte(1) - ch.flush() - writer.start() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2)) - - assertTrue(ch.skipDelimiterOrEof(delimiter)) - - assertTrue(ch.isClosedForRead) - } - - @Test - fun testTimeSplitNonClosed(): Unit = runTest { - val writer = launch(CoroutineName("writer"), start = CoroutineStart.LAZY) { - ch.writeByte(2) - ch.flush() - } - - ch.writeByte(1) - ch.flush() - writer.start() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2)) - - assertTrue(ch.skipDelimiterOrEof(delimiter)) - assertFalse(ch.isClosedForRead) - ch.cancel() - } - - @Test - fun testTimeSplitWrongBytes(): Unit = runTest { - val writer = launch(CoroutineName("writer"), start = CoroutineStart.LAZY) { - ch.writeByte(33) - ch.flush() - } - - ch.writeByte(1) - ch.flush() - writer.start() - - val delimiter = ByteBuffer.wrap(byteArrayOf(1, 2)) - - assertFails { - ch.skipDelimiterOrEof(delimiter) - } - - assertEquals(2, ch.availableForRead) - } - - @Test - fun testSkipTooLongDelimiter(): Unit = runTest { - assertFails { - ch.skipDelimiterOrEof(ByteBuffer.allocate(DEFAULT_BUFFER_SIZE * 2)) - } - } -} diff --git a/ktor-http/ktor-http-cio/jvmAndPosix/src/MultipartJvmAndPosix.kt b/ktor-http/ktor-http-cio/jvmAndPosix/src/MultipartJvmAndPosix.kt new file mode 100644 index 00000000000..b3f276a0f98 --- /dev/null +++ b/ktor-http/ktor-http-cio/jvmAndPosix/src/MultipartJvmAndPosix.kt @@ -0,0 +1,14 @@ +package io.ktor.http.cio + +import io.ktor.utils.io.* +import kotlinx.coroutines.* + +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +internal actual fun ByteReadChannel.discardBlocking() { + runBlocking { + discard() + } +}