diff --git a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/JdkHttpServerInterpreter.scala b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/JdkHttpServerInterpreter.scala index 4cadfc2ff6..2bae0b1233 100644 --- a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/JdkHttpServerInterpreter.scala +++ b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/JdkHttpServerInterpreter.scala @@ -16,7 +16,7 @@ trait JdkHttpServerInterpreter { def toHandler(ses: List[ServerEndpoint[Any, Id]]): HttpHandler = { val filteredEndpoints = FilterServerEndpoints[Any, Id](ses) - val requestBody = new JdkHttpRequestBody(jdkHttpServerOptions.createFile) + val requestBody = new JdkHttpRequestBody(jdkHttpServerOptions.createFile, jdkHttpServerOptions.multipartFileThresholdBytes) val responseBody = new JdkHttpToResponseBody val interceptors = RejectInterceptor.disableWhenSingleEndpoint(jdkHttpServerOptions.interceptors, ses) diff --git a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/JdkHttpServerOptions.scala b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/JdkHttpServerOptions.scala index b69d84ba51..fd163f218e 100644 --- a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/JdkHttpServerOptions.scala +++ b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/JdkHttpServerOptions.scala @@ -39,6 +39,10 @@ import java.util.logging.{Level, Logger} * Sets the size of server's tcp connection backlog. This is the maximum number of queued incoming connections to allow on the listening * socket. Queued TCP connections exceeding this limit may be rejected by the TCP implementation. If set to 0 or less the system default * for backlog size will be used. Default is 0. + * + * @param multipartFileThresholdBytes + * Sets the threshold of bytes of a multipart upload to trigger writing the multipart contents to a temporary file rather than keeping it + * entirely in memory. Default is 50MB. */ case class JdkHttpServerOptions( interceptors: List[Interceptor[Id]], @@ -50,7 +54,8 @@ case class JdkHttpServerOptions( host: String = "0.0.0.0", executor: Option[Executor] = None, httpsConfigurator: Option[HttpsConfigurator] = None, - backlogSize: Int = 0 + backlogSize: Int = 0, + multipartFileThresholdBytes: Long = 52_428_800 ) { require(0 <= port && port <= 65535, "Port has to be in 1-65535 range or 0 if random!") def prependInterceptor(i: Interceptor[Id]): JdkHttpServerOptions = copy(interceptors = i :: interceptors) diff --git a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala index d67b772258..4ebbaf4ed9 100644 --- a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala +++ b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala @@ -14,7 +14,7 @@ import java.io._ import java.nio.ByteBuffer import java.nio.file.{Files, StandardCopyOption} -private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile) extends RequestBody[Id, NoStreams] { +private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile, multipartFileThresholdBytes: Long) extends RequestBody[Id, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): RawValue[RAW] = { @@ -46,7 +46,12 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile .flatMap( _.split(";") .find(_.trim().startsWith(boundaryPrefix)) - .map(line => s"--${line.trim().substring(boundaryPrefix.length)}") + .map(line => { + val boundary = line.trim().substring(boundaryPrefix.length) + if (boundary.length > 70) + throw new IllegalArgumentException("Multipart boundary must be no longer than 70 characters.") + s"--$boundary" + }) ) .getOrElse(throw new IllegalArgumentException("Unable to extract multipart boundary from multipart request")) } @@ -55,12 +60,11 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile val httpExchange = jdkHttpRequest(request) val boundary = extractBoundary(httpExchange) - parseMultipartBody(httpExchange, boundary).flatMap(parsedPart => + parseMultipartBody(httpExchange.getRequestBody, boundary, multipartFileThresholdBytes).flatMap(parsedPart => parsedPart.getName.flatMap(name => m.partType(name) .map(partType => { - val bodyInputStream = new ByteArrayInputStream(parsedPart.body) - val bodyRawValue = toRaw(request, partType, bodyInputStream) + val bodyRawValue = toRaw(request, partType, parsedPart.getBody) Part( name, bodyRawValue.value, diff --git a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/KMPMatcher.scala b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/KMPMatcher.scala new file mode 100644 index 0000000000..150bb364dc --- /dev/null +++ b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/KMPMatcher.scala @@ -0,0 +1,60 @@ +package sttp.tapir.server.jdkhttp.internal +import scala.collection.mutable + +class KMPMatcher(delimiter: Array[Byte]) { + private val table = KMPMatcher.buildLongestPrefixSuffixTable(delimiter) + private var matches: Int = 0 + + def noMatches = this.matches == 0 + def getMatches: Int = this.matches + def getDelimiter: Array[Byte] = this.delimiter + + def matchByte(b: Byte): KMPMatcher.MatchResult = { + val numMatchesBeforeReset = getMatches + while (getMatches > 0 && b != delimiter(getMatches)) { + this.matches = this.table(getMatches - 1) + } + + val matchesBeforeCurrentByte = getMatches + + if (b == delimiter(matches)) { + matches += 1 + if (this.matches == delimiter.length) { + this.matches = 0 + KMPMatcher.Match + } else { + KMPMatcher.NotMatched(numMatchesBeforeReset - matchesBeforeCurrentByte) + } + } else { + KMPMatcher.NotMatched(numMatchesBeforeReset - matchesBeforeCurrentByte) + } + } +} + +object KMPMatcher { + sealed trait MatchResult + case object Match extends MatchResult + case class NotMatched(numNoLongerMatchedBytes: Int) extends MatchResult + + private def buildLongestPrefixSuffixTable(s: Array[Byte]): mutable.ArrayBuffer[Int] = { + val lookupTable = mutable.ArrayBuffer.fill(s.length)(-1) + lookupTable(0) = 0 + var len = 0 + var i = 1 + while (i < s.length) { + if (s(i) == s(len)) { + len += 1 + lookupTable(i) = len + i += 1 + } else { + if (len == 0) { + lookupTable(i) = 0 + i = i + 1 + } else { + len = lookupTable(len - 1) + } + } + } + lookupTable + } +} diff --git a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/ParsedMultiPart.scala b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/ParsedMultiPart.scala index 0db086bcf0..311ec1499f 100644 --- a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/ParsedMultiPart.scala +++ b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/ParsedMultiPart.scala @@ -1,12 +1,30 @@ package sttp.tapir.server.jdkhttp.internal -import com.sun.net.httpserver.HttpExchange import sttp.model.Header -import java.io.{BufferedReader, InputStreamReader} +import sttp.tapir.Defaults.createTempFile +import sttp.tapir.TapirFile +import sttp.tapir.server.jdkhttp.internal.KMPMatcher.{Match, NotMatched} -case class ParsedMultiPart(headers: Map[String, Seq[String]], body: Array[Byte]) { +import java.io.{ + BufferedOutputStream, + ByteArrayInputStream, + ByteArrayOutputStream, + FileInputStream, + FileOutputStream, + InputStream, + OutputStream +} +import scala.collection.mutable + +case class ParsedMultiPart() { + private var body: InputStream = new ByteArrayInputStream(Array.empty) + private val headers: mutable.Map[String, Seq[String]] = mutable.Map.empty + + def getBody: InputStream = body def getHeader(headerName: String): Option[String] = headers.get(headerName).flatMap(_.headOption) - def fileItemHeaders: Seq[Header] = headers.toSeq.flatMap { case (name, values) => values.map(value => Header(name, value)) } + def fileItemHeaders: Seq[Header] = headers.toSeq.flatMap { case (name, values) => + values.map(value => Header(name, value)) + } def getDispositionParams: Map[String, String] = { val headerValue = getHeader("content-disposition") @@ -33,59 +51,153 @@ case class ParsedMultiPart(headers: Map[String, Seq[String]], body: Array[Byte]) .map(_.replaceAll("^\"|\"$", "")) ) - def addHeader(l: String): ParsedMultiPart = { + def addHeader(l: String): Unit = { val (name, value) = l.splitAt(l.indexOf(":")) val headerName = name.trim.toLowerCase val headerValue = value.stripPrefix(":").trim - val newHeaderEntry = (headerName -> (this.headers.getOrElse(headerName, Seq.empty) :+ headerValue)) - this.copy(headers = headers + newHeaderEntry) + val newHeaderEntry = headerName -> (this.headers.getOrElse(headerName, Seq.empty) :+ headerValue) + this.headers.addOne(newHeaderEntry) } + def withBody(body: InputStream): ParsedMultiPart = { + this.body = body + this + } } object ParsedMultiPart { - def empty: ParsedMultiPart = new ParsedMultiPart(Map.empty, Array.empty) - - sealed trait ParseState - case object Default extends ParseState - case object AfterBoundary extends ParseState - case object AfterHeaderSpace extends ParseState - - private case class ParseData( - currentPart: ParsedMultiPart, - completedParts: List[ParsedMultiPart], - parseState: ParseState - ) { - def changeState(state: ParseState): ParseData = this.copy(parseState = state) - def addHeader(header: String): ParseData = this.copy(currentPart = currentPart.addHeader(header)) - def addBody(body: Array[Byte]): ParseData = this.copy(currentPart = currentPart.copy(body = currentPart.body ++ body)) - def completePart(): ParseData = this.currentPart.getName match { - case Some(_) => - this.copy( - completedParts = completedParts :+ currentPart, - currentPart = empty, - parseState = AfterBoundary - ) - case None => changeState(AfterBoundary) + + sealed trait ParseStatus {} + private case object LookForInitialBoundary extends ParseStatus + private case object ParsePartHeaders extends ParseStatus + private case object ParsePartBody extends ParseStatus + + private val CRLF = Array[Byte]('\r', '\n') + private val END_DELIMITER = Array[Byte]('-', '-') + private val bufferSize = 8192 + + private class ParseState(boundary: Array[Byte], multipartFileThresholdBytes: Long) { + var completedParts: List[ParsedMultiPart] = List.empty + private val buffer = new mutable.ArrayBuffer[Byte](bufferSize) + private var done = false + private var currentPart: ParsedMultiPart = new ParsedMultiPart() + private var bodySize: Int = 0 + private var stream: PartStream = ByteStream() + private var parseState: ParseStatus = LookForInitialBoundary + + private val initialBoundaryMatcher = new KMPMatcher(boundary ++ CRLF) + private val boundaryMatcher = new KMPMatcher(CRLF ++ boundary ++ CRLF) + private val endMatcher = new KMPMatcher(CRLF ++ boundary ++ END_DELIMITER) + + def isDone: Boolean = done + + def updateStateWith(currentByte: Int): Unit = { + buffer.addOne(currentByte.toByte) + this.parseState match { + case LookForInitialBoundary => parseInitialBoundary(currentByte) + case ParsePartHeaders => parsePartHeaders() + case ParsePartBody => parsePartBody(currentByte) + } + } + + private def parseInitialBoundary(currentByte: Int): Unit = { + val foundBoundary = initialBoundaryMatcher.matchByte(currentByte.toByte) + if (foundBoundary == Match) { + changeState(ParsePartHeaders) + } } - } - def parseMultipartBody(httpExchange: HttpExchange, boundary: String): Seq[ParsedMultiPart] = { - val reader = new BufferedReader(new InputStreamReader(httpExchange.getRequestBody)) - val initialParseState: ParseData = ParseData(empty, List.empty, Default) - Iterator - .continually(reader.readLine()) - .takeWhile(_ != null) - .foldLeft(initialParseState) { case (state, line) => - state.parseState match { - case Default if line.startsWith(boundary) => state.changeState(AfterBoundary) - case Default => state - case AfterBoundary if line.trim.isEmpty => state.changeState(AfterHeaderSpace) - case AfterBoundary => state.addHeader(line) - case AfterHeaderSpace if !line.startsWith(boundary) => state.addBody(line.getBytes()) - case AfterHeaderSpace => state.completePart() + private def parsePartHeaders(): Unit = { + if (buffer.endsWith(CRLF)) { + buffer.view.map(_.toChar).mkString match { + case headerLine if !headerLine.isBlank => addHeader(headerLine) + case _ => changeState(ParsePartBody) } + buffer.clear() + } else if (buffer.length >= bufferSize) { + throw new RuntimeException( + s"Reached max size of $bufferSize bytes before reaching newline when parsing header." + ) } - .completedParts + } + + private def parsePartBody(currentByte: Int): Unit = { + bodySize += 1 + val foundFinalBoundary = endMatcher.matchByte(currentByte.toByte) + val foundBoundary = boundaryMatcher.matchByte(currentByte.toByte) + convertStreamToFileIfThresholdMet() + + (foundBoundary, foundFinalBoundary) match { + case (Match, _) => handleEndOfBody(boundaryMatcher.getDelimiter.length) + case (_, Match) => done = true; handleEndOfBody(endMatcher.getDelimiter.length) + case _ if endMatcher.noMatches && boundaryMatcher.noMatches => writeAllBytes() + case (NotMatched(x), NotMatched(y)) => writeUnmatchedBytes(Math.min(x, y)) + } + } + + def writeAllBytes(): Unit = { + buffer.foreach(byte => stream.underlying.write(byte.toInt)) + buffer.clear() + } + + def handleEndOfBody(delimiterLength: Int): Unit = { + val bytesToWrite = buffer.view.slice(0, buffer.length - delimiterLength) + bytesToWrite.foreach(byte => stream.underlying.write(byte.toInt)) + val bodyInputStream = stream match { + case FileStream(tempFile, stream) => stream.close(); new FileInputStream(tempFile) + case ByteStream(stream) => new ByteArrayInputStream(stream.toByteArray) + } + completePart(bodyInputStream) + stream = ByteStream() + bodySize = 0 + } + + def writeUnmatchedBytes(unmatchedBytes: Int): Unit = { + val bytesToWrite = buffer.view.slice(0, unmatchedBytes) + bytesToWrite.foreach(byte => stream.underlying.write(byte.toInt)) + buffer.dropInPlace(unmatchedBytes) + } + + private def changeState(state: ParseStatus): Unit = { + buffer.clear() + this.parseState = state + } + private def addHeader(header: String): Unit = this.currentPart.addHeader(header) + private def completePart(body: InputStream): Unit = { + this.currentPart.getName match { + case Some(_) => + this.completedParts = completedParts :+ currentPart.withBody(body) + this.currentPart = new ParsedMultiPart() + case None => + } + this.changeState(ParsePartHeaders) + } + + private sealed trait PartStream { val underlying: OutputStream } + private case class FileStream(tempFile: TapirFile, underlying: BufferedOutputStream) extends PartStream + private case class ByteStream(underlying: ByteArrayOutputStream = new ByteArrayOutputStream()) extends PartStream + private def convertStreamToFileIfThresholdMet(): Unit = stream match { + case ByteStream(os) if bodySize >= multipartFileThresholdBytes => + val newFile = createTempFile() + val fileOutputStream = new FileOutputStream(newFile) + os.writeTo(fileOutputStream) + os.close() + stream = FileStream(newFile, new BufferedOutputStream(fileOutputStream)) + case _ => + } + } + + def parseMultipartBody(inputStream: InputStream, boundary: String, multipartFileThresholdBytes: Long): Seq[ParsedMultiPart] = { + val boundaryBytes = boundary.getBytes + val state = new ParseState(boundaryBytes, multipartFileThresholdBytes) + + while (!state.isDone) { + val currentByte = inputStream.read() + if (currentByte == -1) + throw new RuntimeException("Parsing multipart failed, ran out of bytes before finding boundary") + state.updateStateWith(currentByte) + } + + state.completedParts } } diff --git a/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/internal/KMPMatcherTest.scala b/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/internal/KMPMatcherTest.scala new file mode 100644 index 0000000000..1a489f77d0 --- /dev/null +++ b/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/internal/KMPMatcherTest.scala @@ -0,0 +1,37 @@ +package sttp.tapir.server.jdkhttp.internal + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import sttp.tapir.server.jdkhttp.internal.KMPMatcher.{Match, NotMatched} + +class KMPMatcherTest extends AnyFlatSpec with Matchers { + + it should "match over a set of bytes and not allow writing of any bytes if only matching" in { + val matchBytes = "--abc\r\n".getBytes + val matcher = new KMPMatcher(matchBytes) + matcher.matchByte('-'.toByte) shouldBe NotMatched(0) + matcher.matchByte('-'.toByte) shouldBe NotMatched(0) + matcher.matchByte('a'.toByte) shouldBe NotMatched(0) + matcher.matchByte('b'.toByte) shouldBe NotMatched(0) + matcher.matchByte('c'.toByte) shouldBe NotMatched(0) + matcher.matchByte('\r'.toByte) shouldBe NotMatched(0) + matcher.matchByte('\n'.toByte) shouldBe Match + } + + it should "match over a set of bytes and allow writing of any non-matched bytes" in { + val matchBytes = "--abc\r\n".getBytes + val matcher = new KMPMatcher(matchBytes) + matcher.matchByte('-'.toByte) shouldBe NotMatched(0) + matcher.matchByte('-'.toByte) shouldBe NotMatched(0) + matcher.matchByte('-'.toByte) shouldBe NotMatched(1) + matcher.matchByte('a'.toByte) shouldBe NotMatched(0) + matcher.matchByte('a'.toByte) shouldBe NotMatched(3) + matcher.matchByte('-'.toByte) shouldBe NotMatched(0) + matcher.matchByte('-'.toByte) shouldBe NotMatched(0) + matcher.matchByte('a'.toByte) shouldBe NotMatched(0) + matcher.matchByte('b'.toByte) shouldBe NotMatched(0) + matcher.matchByte('c'.toByte) shouldBe NotMatched(0) + matcher.matchByte('\r'.toByte) shouldBe NotMatched(0) + matcher.matchByte('\n'.toByte) shouldBe Match + } +} diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala index 3775fd816f..d83d9232bc 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala @@ -127,6 +127,33 @@ class ServerMultipartTests[F[_], OPTIONS, ROUTE]( r.body should include("file1:peach mario") r.body should include("file2:daisy luigi") } + }, + testServer(in_raw_multipart_out_string, "boundary substring in body")((parts: Seq[Part[Array[Byte]]]) => + pureResult( + parts.map(part => s"${part.name}:${new String(part.body)}").mkString("\n__\n").asRight[Unit] + ) + ) { (backend, baseUri) => + val testBody = "--AAB\r\n" + + "Content-Disposition: form-data; name=\"firstPart\"\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "BODYONE\r\n" + + "--AA\r\n" + + "--AAB\r\n" + + "Content-Disposition: form-data; name=\"secondPart\"\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "BODYTWO\r\n" + + "--AAB--\r\n" + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .header("Content-Type", "multipart/form-data; boundary=AAB") + .body(testBody) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + r.body should be("firstPart:BODYONE\r\n--AA\n__\nsecondPart:BODYTWO") + } } ) }