From 3f9ccc8b5e4d5941d53f08653bab2e8bc2ce9ef3 Mon Sep 17 00:00:00 2001 From: Adam Warski Date: Tue, 2 Jul 2024 15:10:23 +0200 Subject: [PATCH] Run ZIO HTTP middlewares only once (#3856) --- doc/server/ziohttp.md | 8 -- .../server/ziohttp/ZioHttpInterpreter.scala | 103 +++++++++++++----- .../server/ziohttp/ZioHttpServerTest.scala | 22 ++++ 3 files changed, 100 insertions(+), 33 deletions(-) diff --git a/doc/server/ziohttp.md b/doc/server/ziohttp.md index 10066e46ee..256e82b345 100644 --- a/doc/server/ziohttp.md +++ b/doc/server/ziohttp.md @@ -56,14 +56,6 @@ val countCharactersHttp: Routes[Any, Response] = ZioHttpInterpreter().toHttp(countCharactersEndpoint.zServerLogic(countCharacters)) ``` -```{note} -A single ZIO-Http application can contain both tapir-managed and ZIO-Http-managed routes. However, because of the -routing implementation in ZIO Http, the shape of the paths that tapir and other ZIO Http handlers serve should not -overlap. The shape of the path includes exact path segments, single- and multi-wildcards. Otherwise, request handling -will throw an exception. We don't expect users to encounter this as a problem, however the implementation here -diverges a bit comparing to other interpreters. -``` - ## Server logic When defining the business logic for an endpoint, the following methods are available, which replace the diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index 87aa855748..2561b0a3bd 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -4,12 +4,15 @@ import sttp.capabilities.WebSockets import sttp.capabilities.zio.ZioStreams import sttp.model.{Header => SttpHeader} import sttp.monad.MonadError +import sttp.tapir.EndpointInput +import sttp.tapir.internal.RichEndpointInput import sttp.tapir.server.interceptor.RequestResult import sttp.tapir.server.interceptor.reject.RejectInterceptor -import sttp.tapir.server.interpreter.{FilterServerEndpoints, ServerInterpreter} +import sttp.tapir.server.interpreter.ServerInterpreter import sttp.tapir.server.model.ServerResponse import sttp.tapir.ztapir._ import zio._ +import zio.http.codec.PathCodec import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _} trait ZioHttpInterpreter[R] { @@ -27,40 +30,90 @@ trait ZioHttpInterpreter[R] { val zioHttpResponseBody = new ZioHttpToResponseBody val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes) - Routes.singleton { - Handler.fromFunctionZIO[(Path, Request)] { case (_: Path, request: Request) => + def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams with WebSockets]]) = + Handler.fromZIO { val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioResponseBody, ZioStreams]( - FilterServerEndpoints(widenedSes), + _ => filteredEndpoints, zioHttpRequestBody, zioHttpResponseBody, interceptors, zioHttpServerOptions.deleteFile ) + val serverRequest = ZioHttpServerRequest(req) - if (request.url.encode.trim.isEmpty) { - ZIO.logError("Received an apparently empty request URI, not handling: " + request) *> - ZIO.fail(Response.internalServerError("Empty request URI")) - } else { - val serverRequest = ZioHttpServerRequest(request) - interpreter - .apply(serverRequest) - .foldCauseZIO( - cause => ZIO.logErrorCause(cause) *> ZIO.fail(Response.internalServerError(cause.squash.getMessage)), - { - case RequestResult.Response(resp) => - resp.body match { - case None => handleHttpResponse(resp, None) - case Some(Right(body)) => handleHttpResponse(resp, Some(body)) - case Some(Left(body)) => handleWebSocketResponse(body, zioHttpServerOptions.customWebSocketConfig(serverRequest)) - } - - case RequestResult.Failure(_) => - ZIO.fail(Response.notFound) - } - ) + interpreter + .apply(serverRequest) + .foldCauseZIO( + cause => ZIO.logErrorCause(cause) *> ZIO.fail(Response.internalServerError(cause.squash.getMessage)), + { + case RequestResult.Response(resp) => + resp.body match { + case None => handleHttpResponse(resp, None) + case Some(Right(body)) => handleHttpResponse(resp, Some(body)) + case Some(Left(body)) => handleWebSocketResponse(body, zioHttpServerOptions.customWebSocketConfig(serverRequest)) + } + + case RequestResult.Failure(_) => ZIO.succeed(Response.notFound) + } + ) + } + + // Grouping the endpoints by path prefix template (fixed path components & single path captures). This way, if + // there are multiple endpoints - with/without trailing slash, with from-request extraction, or with path wildcards, + // they will be interpreted and disambiguated by the tapir logic, instead of ZIO HTTP's routing. Also, this covers + // multiple endpoints with different methods, and allows us to handle invalid methods. + val widenedSesGroupedByPathPrefixTemplate = widenedSes.groupBy { se => + val e = se.endpoint + val inputs = e.securityInput.and(e.input).asVectorOfBasicInputs() + val x = inputs.foldLeft("") { case (p, component) => + component match { + case _: EndpointInput.PathCapture[_] => p + "/?" + case i: EndpointInput.FixedPath[_] => p + "/" + i.s + case _ => p } } + x + } + + val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathPrefixTemplate.toList.map { case (_, sesForPathTemplate) => + // The pattern that we generate should be the same for all endpoints in a group + val e = sesForPathTemplate.head.endpoint + val inputs = e.securityInput.and(e.input).asVectorOfBasicInputs() + + val hasPath = inputs.exists { + case _: EndpointInput.PathCapture[_] => true + case _: EndpointInput.PathsCapture[_] => true + case _: EndpointInput.FixedPath[_] => true + case _ => false + } + + val pattern = if (hasPath) { + val initialPattern = RoutePattern(Method.ANY, PathCodec.empty).asInstanceOf[RoutePattern[Any]] + // The second tuple parameter specifies if PathCodec.trailing has already been added to the route's pattern - + // it can only be added once. It's possible that an endpoint contains both ExtractFromRequest & PathsCapture, + // which would cause it to be added twice. + inputs + .foldLeft((initialPattern, false)) { case ((p, trailingAdded), component) => + component match { + case i: EndpointInput.PathCapture[_] => + ((p / PathCodec.string(i.name.getOrElse("?"))).asInstanceOf[RoutePattern[Any]], trailingAdded) + case _: EndpointInput.ExtractFromRequest[_] if !trailingAdded => + ((p / PathCodec.trailing).asInstanceOf[RoutePattern[Any]], true) + case _: EndpointInput.PathsCapture[_] if !trailingAdded => ((p / PathCodec.trailing).asInstanceOf[RoutePattern[Any]], true) + case i: EndpointInput.FixedPath[_] => (p / PathCodec.literal(i.s), trailingAdded) + case _ => (p, trailingAdded) + } + } + ._1 + } else { + // if there are no path inputs, we return a catch-all + RoutePattern(Method.ANY, PathCodec.trailing).asInstanceOf[RoutePattern[Any]] + } + + Route.handled(pattern)(Handler.fromFunctionHandler { (request: Request) => handleRequest(request, sesForPathTemplate) }) } + + Routes(Chunk.fromIterable(handlers)) } private def handleWebSocketResponse( diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index e462a81630..e7789272b4 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -141,6 +141,28 @@ class ZioHttpServerTest extends TestSuite { Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test)) }, + Test("zio http middlewares only run once, with two endpoints") { + val test: UIO[Assertion] = for { + ref <- Ref.make("") + ep1 = endpoint.get.in("p1").out(stringBody).zServerLogic[Any](_ => ref.updateAndGet(_ + "1")) + ep2 = endpoint.get.in("p2").out(stringBody).zServerLogic[Any](_ => ref.updateAndGet(_ + "2")) + route = ZioHttpInterpreter().toHttp(ep1) ++ ZioHttpInterpreter().toHttp(ep2) + app = route @@ Middleware.allowZIO((_: Request) => ref.update(_ + "M").as(true)) + _ <- app + .runZIO(Request.get(url = URL(Path.empty / "p1"))) + .flatMap(response => response.body.asString) + .map(_ shouldBe "M1") + .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) + _ <- ref.set("") + result <- app + .runZIO(Request.get(url = URL(Path.empty / "p2"))) + .flatMap(response => response.body.asString) + .map(_ shouldBe "M2") + .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) + } yield result + + Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test)) + }, // https://github.com/softwaremill/tapir/issues/2849 Test("Streaming works through the stub backend") { // given