diff --git a/jetty/src/main/scala/com/avsystem/commons/jetty/rpc/JettyRPCFramework.scala b/jetty/src/main/scala/com/avsystem/commons/jetty/rpc/JettyRPCFramework.scala index 3a938797e..19c58bd69 100644 --- a/jetty/src/main/scala/com/avsystem/commons/jetty/rpc/JettyRPCFramework.scala +++ b/jetty/src/main/scala/com/avsystem/commons/jetty/rpc/JettyRPCFramework.scala @@ -1,21 +1,20 @@ package com.avsystem.commons package jetty.rpc -import java.nio.charset.StandardCharsets import com.avsystem.commons.rpc.StandardRPCFramework import com.avsystem.commons.serialization.json.{JsonStringInput, JsonStringOutput, RawJson} import com.avsystem.commons.serialization.{GenCodec, HasGenCodec} import com.typesafe.scalalogging.LazyLogging - -import javax.servlet.http.{HttpServletRequest, HttpServletResponse} -import org.eclipse.jetty.client.HttpClient -import org.eclipse.jetty.client.api.Result -import org.eclipse.jetty.client.util.{BufferingResponseListener, StringContentProvider, StringRequestContent} +import jakarta.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} +import org.eclipse.jetty.client.{BufferingResponseListener, HttpClient, Result, StringRequestContent} +import org.eclipse.jetty.ee10.servlet.ServletContextHandler import org.eclipse.jetty.http.{HttpMethod, HttpStatus, MimeTypes} -import org.eclipse.jetty.server.handler.AbstractHandler -import org.eclipse.jetty.server.{Handler, Request} +import org.eclipse.jetty.server.Handler -import scala.concurrent.duration._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.duration.* +import scala.util.Using object JettyRPCFramework extends StandardRPCFramework with LazyLogging { class RawValue(val s: String) extends AnyVal @@ -89,20 +88,28 @@ object JettyRPCFramework extends StandardRPCFramework with LazyLogging { request(HttpMethod.PUT, call) } - class RPCHandler(rootRpc: RawRPC, contextTimeout: FiniteDuration) extends AbstractHandler { - override def handle(target: String, baseRequest: Request, request: HttpServletRequest, response: HttpServletResponse): Unit = { - baseRequest.setHandled(true) - - val content = Iterator.continually(request.getReader.readLine()) - .takeWhile(_ != null) - .mkString("\n") - - val call = read[Call](new RawValue(content)) + class RPCHandler(rootRpc: RawRPC, contextTimeout: FiniteDuration) extends HttpServlet { + override def service(request: HttpServletRequest, response: HttpServletResponse): Unit = { + // readRequest must execute in request thread but we want exceptions to be handled uniformly, hence the Try + val content = + Using(request.getReader)(reader => + Iterator.continually(reader.readLine()).takeWhile(_ != null).mkString("\n") + ) + val call = content.map(content => read[Call](new RawValue(content))) HttpMethod.fromString(request.getMethod) match { case HttpMethod.POST => - val async = request.startAsync().setup(_.setTimeout(contextTimeout.toMillis)) - handlePost(call).andThenNow { + val asyncContext = request.startAsync().setup(_.setTimeout(contextTimeout.toMillis)) + val completed = new AtomicBoolean(false) + // Need to protect asyncContext from being completed twice because after a timeout the + // servlet may recycle the same context instance between subsequent requests (not cool) + // https://stackoverflow.com/a/27744537 + def completeWith(code: => Unit): Unit = + if (!completed.getAndSet(true)) { + code + asyncContext.complete() + } + completeWith(Future.fromTry(call).flatMapNow(handlePost).andThenNow { case Success(responseContent) => response.setContentType(MimeTypes.Type.APPLICATION_JSON.asString()) response.setCharacterEncoding(StandardCharsets.UTF_8.name()) @@ -110,9 +117,9 @@ object JettyRPCFramework extends StandardRPCFramework with LazyLogging { case Failure(t) => response.sendError(HttpStatus.INTERNAL_SERVER_ERROR_500, t.getMessage) logger.error("Failed to handle RPC call", t) - }.andThenNow { case _ => async.complete() } + }) case HttpMethod.PUT => - handlePut(call) + call.map(handlePut).get case _ => throw new IllegalArgumentException(s"Request HTTP method is ${request.getMethod}, only POST or PUT are supported") } @@ -132,11 +139,12 @@ object JettyRPCFramework extends StandardRPCFramework with LazyLogging { invoke(call)(_.fire) } - def newHandler[T](impl: T, contextTimeout: FiniteDuration = 30.seconds)( - implicit asRawRPC: AsRawRPC[T]): Handler = - new RPCHandler(asRawRPC.asRaw(impl), contextTimeout) + def newServlet[T: AsRawRPC](impl: T, contextTimeout: FiniteDuration = 30.seconds): HttpServlet = + new RPCHandler(AsRawRPC[T].asRaw(impl), contextTimeout) + + def newHandler[T: AsRawRPC](impl: T, contextTimeout: FiniteDuration = 30.seconds): Handler = + new ServletContextHandler().setup(_.addServlet(newServlet(impl, contextTimeout), "/*")) - def newClient[T](httpClient: HttpClient, uri: String, maxResponseLength: Int = 2 * 1024 * 1024)( - implicit asRealRPC: AsRealRPC[T]): T = - asRealRPC.asReal(new RPCClient(httpClient, uri, maxResponseLength).rawRPC) + def newClient[T: AsRealRPC](httpClient: HttpClient, uri: String, maxResponseLength: Int = 2 * 1024 * 1024): T = + AsRealRPC[T].asReal(new RPCClient(httpClient, uri, maxResponseLength).rawRPC) } diff --git a/project/Commons.scala b/project/Commons.scala index d20714536..00b0aa93e 100644 --- a/project/Commons.scala +++ b/project/Commons.scala @@ -29,7 +29,7 @@ object Commons extends ProjectGroup("commons") { val scalatestVersion = "3.2.19" val scalatestplusScalacheckVersion = "3.2.14.0" val scalacheckVersion = "1.18.0" - val jettyVersion = "10.0.22" + val jettyVersion = "12.0.12" val mongoVersion = "5.1.2" val springVersion = "5.3.37" val typesafeConfigVersion = "1.4.3" @@ -365,10 +365,8 @@ object Commons extends ProjectGroup("commons") { jvmCommonSettings, libraryDependencies ++= Seq( "org.eclipse.jetty" % "jetty-client" % jettyVersion, - "org.eclipse.jetty" % "jetty-server" % jettyVersion, + "org.eclipse.jetty.ee10" % "jetty-ee10-servlet" % jettyVersion, "com.typesafe.scala-logging" %% "scala-logging" % scalaLoggingVersion, - - "org.eclipse.jetty" % "jetty-servlet" % jettyVersion % Test, ), )