Skip to content

Commit

Permalink
Merge pull request #3227 from jnatten/jdk-multipart-tmp
Browse files Browse the repository at this point in the history
jdkhttp-server: Write multipart parts bigger than threshold to files
  • Loading branch information
adamw authored Nov 9, 2023
2 parents f458cb6 + f20c1ea commit 185b893
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ trait JdkHttpServerInterpreter {

def toHandler(ses: List[ServerEndpoint[Any, Id]]): HttpHandler = {
val filteredEndpoints = FilterServerEndpoints[Any, Id](ses)
val requestBody = new JdkHttpRequestBody(jdkHttpServerOptions.createFile)
val requestBody = new JdkHttpRequestBody(jdkHttpServerOptions.createFile, jdkHttpServerOptions.multipartFileThresholdBytes)
val responseBody = new JdkHttpToResponseBody
val interceptors = RejectInterceptor.disableWhenSingleEndpoint(jdkHttpServerOptions.interceptors, ses)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ import java.util.logging.{Level, Logger}
* Sets the size of server's tcp connection backlog. This is the maximum number of queued incoming connections to allow on the listening
* socket. Queued TCP connections exceeding this limit may be rejected by the TCP implementation. If set to 0 or less the system default
* for backlog size will be used. Default is 0.
*
* @param multipartFileThresholdBytes
* Sets the threshold of bytes of a multipart upload to trigger writing the multipart contents to a temporary file rather than keeping it
* entirely in memory. Default is 50MB.
*/
case class JdkHttpServerOptions(
interceptors: List[Interceptor[Id]],
Expand All @@ -50,7 +54,8 @@ case class JdkHttpServerOptions(
host: String = "0.0.0.0",
executor: Option[Executor] = None,
httpsConfigurator: Option[HttpsConfigurator] = None,
backlogSize: Int = 0
backlogSize: Int = 0,
multipartFileThresholdBytes: Long = 52_428_800
) {
require(0 <= port && port <= 65535, "Port has to be in 1-65535 range or 0 if random!")
def prependInterceptor(i: Interceptor[Id]): JdkHttpServerOptions = copy(interceptors = i :: interceptors)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import java.io._
import java.nio.ByteBuffer
import java.nio.file.{Files, StandardCopyOption}

private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile) extends RequestBody[Id, NoStreams] {
private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile, multipartFileThresholdBytes: Long) extends RequestBody[Id, NoStreams] {
override val streams: capabilities.Streams[NoStreams] = NoStreams

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): RawValue[RAW] = {
Expand Down Expand Up @@ -46,7 +46,12 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile
.flatMap(
_.split(";")
.find(_.trim().startsWith(boundaryPrefix))
.map(line => s"--${line.trim().substring(boundaryPrefix.length)}")
.map(line => {
val boundary = line.trim().substring(boundaryPrefix.length)
if (boundary.length > 70)
throw new IllegalArgumentException("Multipart boundary must be no longer than 70 characters.")
s"--$boundary"
})
)
.getOrElse(throw new IllegalArgumentException("Unable to extract multipart boundary from multipart request"))
}
Expand All @@ -55,12 +60,11 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile
val httpExchange = jdkHttpRequest(request)
val boundary = extractBoundary(httpExchange)

parseMultipartBody(httpExchange, boundary).flatMap(parsedPart =>
parseMultipartBody(httpExchange.getRequestBody, boundary, multipartFileThresholdBytes).flatMap(parsedPart =>
parsedPart.getName.flatMap(name =>
m.partType(name)
.map(partType => {
val bodyInputStream = new ByteArrayInputStream(parsedPart.body)
val bodyRawValue = toRaw(request, partType, bodyInputStream)
val bodyRawValue = toRaw(request, partType, parsedPart.getBody)
Part(
name,
bodyRawValue.value,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package sttp.tapir.server.jdkhttp.internal
import scala.collection.mutable

class KMPMatcher(delimiter: Array[Byte]) {
private val table = KMPMatcher.buildLongestPrefixSuffixTable(delimiter)
private var matches: Int = 0

def noMatches = this.matches == 0
def getMatches: Int = this.matches
def getDelimiter: Array[Byte] = this.delimiter

def matchByte(b: Byte): KMPMatcher.MatchResult = {
val numMatchesBeforeReset = getMatches
while (getMatches > 0 && b != delimiter(getMatches)) {
this.matches = this.table(getMatches - 1)
}

val matchesBeforeCurrentByte = getMatches

if (b == delimiter(matches)) {
matches += 1
if (this.matches == delimiter.length) {
this.matches = 0
KMPMatcher.Match
} else {
KMPMatcher.NotMatched(numMatchesBeforeReset - matchesBeforeCurrentByte)
}
} else {
KMPMatcher.NotMatched(numMatchesBeforeReset - matchesBeforeCurrentByte)
}
}
}

object KMPMatcher {
sealed trait MatchResult
case object Match extends MatchResult
case class NotMatched(numNoLongerMatchedBytes: Int) extends MatchResult

private def buildLongestPrefixSuffixTable(s: Array[Byte]): mutable.ArrayBuffer[Int] = {
val lookupTable = mutable.ArrayBuffer.fill(s.length)(-1)
lookupTable(0) = 0
var len = 0
var i = 1
while (i < s.length) {
if (s(i) == s(len)) {
len += 1
lookupTable(i) = len
i += 1
} else {
if (len == 0) {
lookupTable(i) = 0
i = i + 1
} else {
len = lookupTable(len - 1)
}
}
}
lookupTable
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
package sttp.tapir.server.jdkhttp.internal

import com.sun.net.httpserver.HttpExchange
import sttp.model.Header
import java.io.{BufferedReader, InputStreamReader}
import sttp.tapir.Defaults.createTempFile
import sttp.tapir.TapirFile
import sttp.tapir.server.jdkhttp.internal.KMPMatcher.{Match, NotMatched}

case class ParsedMultiPart(headers: Map[String, Seq[String]], body: Array[Byte]) {
import java.io.{
BufferedOutputStream,
ByteArrayInputStream,
ByteArrayOutputStream,
FileInputStream,
FileOutputStream,
InputStream,
OutputStream
}
import scala.collection.mutable

case class ParsedMultiPart() {
private var body: InputStream = new ByteArrayInputStream(Array.empty)
private val headers: mutable.Map[String, Seq[String]] = mutable.Map.empty

def getBody: InputStream = body
def getHeader(headerName: String): Option[String] = headers.get(headerName).flatMap(_.headOption)
def fileItemHeaders: Seq[Header] = headers.toSeq.flatMap { case (name, values) => values.map(value => Header(name, value)) }
def fileItemHeaders: Seq[Header] = headers.toSeq.flatMap { case (name, values) =>
values.map(value => Header(name, value))
}

def getDispositionParams: Map[String, String] = {
val headerValue = getHeader("content-disposition")
Expand All @@ -33,59 +51,153 @@ case class ParsedMultiPart(headers: Map[String, Seq[String]], body: Array[Byte])
.map(_.replaceAll("^\"|\"$", ""))
)

def addHeader(l: String): ParsedMultiPart = {
def addHeader(l: String): Unit = {
val (name, value) = l.splitAt(l.indexOf(":"))
val headerName = name.trim.toLowerCase
val headerValue = value.stripPrefix(":").trim
val newHeaderEntry = (headerName -> (this.headers.getOrElse(headerName, Seq.empty) :+ headerValue))
this.copy(headers = headers + newHeaderEntry)
val newHeaderEntry = headerName -> (this.headers.getOrElse(headerName, Seq.empty) :+ headerValue)
this.headers.addOne(newHeaderEntry)
}

def withBody(body: InputStream): ParsedMultiPart = {
this.body = body
this
}
}

object ParsedMultiPart {
def empty: ParsedMultiPart = new ParsedMultiPart(Map.empty, Array.empty)

sealed trait ParseState
case object Default extends ParseState
case object AfterBoundary extends ParseState
case object AfterHeaderSpace extends ParseState

private case class ParseData(
currentPart: ParsedMultiPart,
completedParts: List[ParsedMultiPart],
parseState: ParseState
) {
def changeState(state: ParseState): ParseData = this.copy(parseState = state)
def addHeader(header: String): ParseData = this.copy(currentPart = currentPart.addHeader(header))
def addBody(body: Array[Byte]): ParseData = this.copy(currentPart = currentPart.copy(body = currentPart.body ++ body))
def completePart(): ParseData = this.currentPart.getName match {
case Some(_) =>
this.copy(
completedParts = completedParts :+ currentPart,
currentPart = empty,
parseState = AfterBoundary
)
case None => changeState(AfterBoundary)

sealed trait ParseStatus {}
private case object LookForInitialBoundary extends ParseStatus
private case object ParsePartHeaders extends ParseStatus
private case object ParsePartBody extends ParseStatus

private val CRLF = Array[Byte]('\r', '\n')
private val END_DELIMITER = Array[Byte]('-', '-')
private val bufferSize = 8192

private class ParseState(boundary: Array[Byte], multipartFileThresholdBytes: Long) {
var completedParts: List[ParsedMultiPart] = List.empty
private val buffer = new mutable.ArrayBuffer[Byte](bufferSize)
private var done = false
private var currentPart: ParsedMultiPart = new ParsedMultiPart()
private var bodySize: Int = 0
private var stream: PartStream = ByteStream()
private var parseState: ParseStatus = LookForInitialBoundary

private val initialBoundaryMatcher = new KMPMatcher(boundary ++ CRLF)
private val boundaryMatcher = new KMPMatcher(CRLF ++ boundary ++ CRLF)
private val endMatcher = new KMPMatcher(CRLF ++ boundary ++ END_DELIMITER)

def isDone: Boolean = done

def updateStateWith(currentByte: Int): Unit = {
buffer.addOne(currentByte.toByte)
this.parseState match {
case LookForInitialBoundary => parseInitialBoundary(currentByte)
case ParsePartHeaders => parsePartHeaders()
case ParsePartBody => parsePartBody(currentByte)
}
}

private def parseInitialBoundary(currentByte: Int): Unit = {
val foundBoundary = initialBoundaryMatcher.matchByte(currentByte.toByte)
if (foundBoundary == Match) {
changeState(ParsePartHeaders)
}
}
}

def parseMultipartBody(httpExchange: HttpExchange, boundary: String): Seq[ParsedMultiPart] = {
val reader = new BufferedReader(new InputStreamReader(httpExchange.getRequestBody))
val initialParseState: ParseData = ParseData(empty, List.empty, Default)
Iterator
.continually(reader.readLine())
.takeWhile(_ != null)
.foldLeft(initialParseState) { case (state, line) =>
state.parseState match {
case Default if line.startsWith(boundary) => state.changeState(AfterBoundary)
case Default => state
case AfterBoundary if line.trim.isEmpty => state.changeState(AfterHeaderSpace)
case AfterBoundary => state.addHeader(line)
case AfterHeaderSpace if !line.startsWith(boundary) => state.addBody(line.getBytes())
case AfterHeaderSpace => state.completePart()
private def parsePartHeaders(): Unit = {
if (buffer.endsWith(CRLF)) {
buffer.view.map(_.toChar).mkString match {
case headerLine if !headerLine.isBlank => addHeader(headerLine)
case _ => changeState(ParsePartBody)
}
buffer.clear()
} else if (buffer.length >= bufferSize) {
throw new RuntimeException(
s"Reached max size of $bufferSize bytes before reaching newline when parsing header."
)
}
.completedParts
}

private def parsePartBody(currentByte: Int): Unit = {
bodySize += 1
val foundFinalBoundary = endMatcher.matchByte(currentByte.toByte)
val foundBoundary = boundaryMatcher.matchByte(currentByte.toByte)
convertStreamToFileIfThresholdMet()

(foundBoundary, foundFinalBoundary) match {
case (Match, _) => handleEndOfBody(boundaryMatcher.getDelimiter.length)
case (_, Match) => done = true; handleEndOfBody(endMatcher.getDelimiter.length)
case _ if endMatcher.noMatches && boundaryMatcher.noMatches => writeAllBytes()
case (NotMatched(x), NotMatched(y)) => writeUnmatchedBytes(Math.min(x, y))
}
}

def writeAllBytes(): Unit = {
buffer.foreach(byte => stream.underlying.write(byte.toInt))
buffer.clear()
}

def handleEndOfBody(delimiterLength: Int): Unit = {
val bytesToWrite = buffer.view.slice(0, buffer.length - delimiterLength)
bytesToWrite.foreach(byte => stream.underlying.write(byte.toInt))
val bodyInputStream = stream match {
case FileStream(tempFile, stream) => stream.close(); new FileInputStream(tempFile)
case ByteStream(stream) => new ByteArrayInputStream(stream.toByteArray)
}
completePart(bodyInputStream)
stream = ByteStream()
bodySize = 0
}

def writeUnmatchedBytes(unmatchedBytes: Int): Unit = {
val bytesToWrite = buffer.view.slice(0, unmatchedBytes)
bytesToWrite.foreach(byte => stream.underlying.write(byte.toInt))
buffer.dropInPlace(unmatchedBytes)
}

private def changeState(state: ParseStatus): Unit = {
buffer.clear()
this.parseState = state
}
private def addHeader(header: String): Unit = this.currentPart.addHeader(header)
private def completePart(body: InputStream): Unit = {
this.currentPart.getName match {
case Some(_) =>
this.completedParts = completedParts :+ currentPart.withBody(body)
this.currentPart = new ParsedMultiPart()
case None =>
}
this.changeState(ParsePartHeaders)
}

private sealed trait PartStream { val underlying: OutputStream }
private case class FileStream(tempFile: TapirFile, underlying: BufferedOutputStream) extends PartStream
private case class ByteStream(underlying: ByteArrayOutputStream = new ByteArrayOutputStream()) extends PartStream
private def convertStreamToFileIfThresholdMet(): Unit = stream match {
case ByteStream(os) if bodySize >= multipartFileThresholdBytes =>
val newFile = createTempFile()
val fileOutputStream = new FileOutputStream(newFile)
os.writeTo(fileOutputStream)
os.close()
stream = FileStream(newFile, new BufferedOutputStream(fileOutputStream))
case _ =>
}
}

def parseMultipartBody(inputStream: InputStream, boundary: String, multipartFileThresholdBytes: Long): Seq[ParsedMultiPart] = {
val boundaryBytes = boundary.getBytes
val state = new ParseState(boundaryBytes, multipartFileThresholdBytes)

while (!state.isDone) {
val currentByte = inputStream.read()
if (currentByte == -1)
throw new RuntimeException("Parsing multipart failed, ran out of bytes before finding boundary")
state.updateStateWith(currentByte)
}

state.completedParts
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package sttp.tapir.server.jdkhttp.internal

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import sttp.tapir.server.jdkhttp.internal.KMPMatcher.{Match, NotMatched}

class KMPMatcherTest extends AnyFlatSpec with Matchers {

it should "match over a set of bytes and not allow writing of any bytes if only matching" in {
val matchBytes = "--abc\r\n".getBytes
val matcher = new KMPMatcher(matchBytes)
matcher.matchByte('-'.toByte) shouldBe NotMatched(0)
matcher.matchByte('-'.toByte) shouldBe NotMatched(0)
matcher.matchByte('a'.toByte) shouldBe NotMatched(0)
matcher.matchByte('b'.toByte) shouldBe NotMatched(0)
matcher.matchByte('c'.toByte) shouldBe NotMatched(0)
matcher.matchByte('\r'.toByte) shouldBe NotMatched(0)
matcher.matchByte('\n'.toByte) shouldBe Match
}

it should "match over a set of bytes and allow writing of any non-matched bytes" in {
val matchBytes = "--abc\r\n".getBytes
val matcher = new KMPMatcher(matchBytes)
matcher.matchByte('-'.toByte) shouldBe NotMatched(0)
matcher.matchByte('-'.toByte) shouldBe NotMatched(0)
matcher.matchByte('-'.toByte) shouldBe NotMatched(1)
matcher.matchByte('a'.toByte) shouldBe NotMatched(0)
matcher.matchByte('a'.toByte) shouldBe NotMatched(3)
matcher.matchByte('-'.toByte) shouldBe NotMatched(0)
matcher.matchByte('-'.toByte) shouldBe NotMatched(0)
matcher.matchByte('a'.toByte) shouldBe NotMatched(0)
matcher.matchByte('b'.toByte) shouldBe NotMatched(0)
matcher.matchByte('c'.toByte) shouldBe NotMatched(0)
matcher.matchByte('\r'.toByte) shouldBe NotMatched(0)
matcher.matchByte('\n'.toByte) shouldBe Match
}
}
Loading

0 comments on commit 185b893

Please sign in to comment.