Skip to content

Commit

Permalink
Run ZIO HTTP middlewares only once (#3856)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw authored Jul 2, 2024
1 parent 76a4667 commit 3f9ccc8
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 33 deletions.
8 changes: 0 additions & 8 deletions doc/server/ziohttp.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,6 @@ val countCharactersHttp: Routes[Any, Response] =
ZioHttpInterpreter().toHttp(countCharactersEndpoint.zServerLogic(countCharacters))
```

```{note}
A single ZIO-Http application can contain both tapir-managed and ZIO-Http-managed routes. However, because of the
routing implementation in ZIO Http, the shape of the paths that tapir and other ZIO Http handlers serve should not
overlap. The shape of the path includes exact path segments, single- and multi-wildcards. Otherwise, request handling
will throw an exception. We don't expect users to encounter this as a problem, however the implementation here
diverges a bit comparing to other interpreters.
```

## Server logic

When defining the business logic for an endpoint, the following methods are available, which replace the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import sttp.capabilities.WebSockets
import sttp.capabilities.zio.ZioStreams
import sttp.model.{Header => SttpHeader}
import sttp.monad.MonadError
import sttp.tapir.EndpointInput
import sttp.tapir.internal.RichEndpointInput
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interceptor.reject.RejectInterceptor
import sttp.tapir.server.interpreter.{FilterServerEndpoints, ServerInterpreter}
import sttp.tapir.server.interpreter.ServerInterpreter
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.ztapir._
import zio._
import zio.http.codec.PathCodec
import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _}

trait ZioHttpInterpreter[R] {
Expand All @@ -27,40 +30,90 @@ trait ZioHttpInterpreter[R] {
val zioHttpResponseBody = new ZioHttpToResponseBody
val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes)

Routes.singleton {
Handler.fromFunctionZIO[(Path, Request)] { case (_: Path, request: Request) =>
def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams with WebSockets]]) =
Handler.fromZIO {
val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioResponseBody, ZioStreams](
FilterServerEndpoints(widenedSes),
_ => filteredEndpoints,
zioHttpRequestBody,
zioHttpResponseBody,
interceptors,
zioHttpServerOptions.deleteFile
)
val serverRequest = ZioHttpServerRequest(req)

if (request.url.encode.trim.isEmpty) {
ZIO.logError("Received an apparently empty request URI, not handling: " + request) *>
ZIO.fail(Response.internalServerError("Empty request URI"))
} else {
val serverRequest = ZioHttpServerRequest(request)
interpreter
.apply(serverRequest)
.foldCauseZIO(
cause => ZIO.logErrorCause(cause) *> ZIO.fail(Response.internalServerError(cause.squash.getMessage)),
{
case RequestResult.Response(resp) =>
resp.body match {
case None => handleHttpResponse(resp, None)
case Some(Right(body)) => handleHttpResponse(resp, Some(body))
case Some(Left(body)) => handleWebSocketResponse(body, zioHttpServerOptions.customWebSocketConfig(serverRequest))
}

case RequestResult.Failure(_) =>
ZIO.fail(Response.notFound)
}
)
interpreter
.apply(serverRequest)
.foldCauseZIO(
cause => ZIO.logErrorCause(cause) *> ZIO.fail(Response.internalServerError(cause.squash.getMessage)),
{
case RequestResult.Response(resp) =>
resp.body match {
case None => handleHttpResponse(resp, None)
case Some(Right(body)) => handleHttpResponse(resp, Some(body))
case Some(Left(body)) => handleWebSocketResponse(body, zioHttpServerOptions.customWebSocketConfig(serverRequest))
}

case RequestResult.Failure(_) => ZIO.succeed(Response.notFound)
}
)
}

// Grouping the endpoints by path prefix template (fixed path components & single path captures). This way, if
// there are multiple endpoints - with/without trailing slash, with from-request extraction, or with path wildcards,
// they will be interpreted and disambiguated by the tapir logic, instead of ZIO HTTP's routing. Also, this covers
// multiple endpoints with different methods, and allows us to handle invalid methods.
val widenedSesGroupedByPathPrefixTemplate = widenedSes.groupBy { se =>
val e = se.endpoint
val inputs = e.securityInput.and(e.input).asVectorOfBasicInputs()
val x = inputs.foldLeft("") { case (p, component) =>
component match {
case _: EndpointInput.PathCapture[_] => p + "/?"
case i: EndpointInput.FixedPath[_] => p + "/" + i.s
case _ => p
}
}
x
}

val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathPrefixTemplate.toList.map { case (_, sesForPathTemplate) =>
// The pattern that we generate should be the same for all endpoints in a group
val e = sesForPathTemplate.head.endpoint
val inputs = e.securityInput.and(e.input).asVectorOfBasicInputs()

val hasPath = inputs.exists {
case _: EndpointInput.PathCapture[_] => true
case _: EndpointInput.PathsCapture[_] => true
case _: EndpointInput.FixedPath[_] => true
case _ => false
}

val pattern = if (hasPath) {
val initialPattern = RoutePattern(Method.ANY, PathCodec.empty).asInstanceOf[RoutePattern[Any]]
// The second tuple parameter specifies if PathCodec.trailing has already been added to the route's pattern -
// it can only be added once. It's possible that an endpoint contains both ExtractFromRequest & PathsCapture,
// which would cause it to be added twice.
inputs
.foldLeft((initialPattern, false)) { case ((p, trailingAdded), component) =>
component match {
case i: EndpointInput.PathCapture[_] =>
((p / PathCodec.string(i.name.getOrElse("?"))).asInstanceOf[RoutePattern[Any]], trailingAdded)
case _: EndpointInput.ExtractFromRequest[_] if !trailingAdded =>
((p / PathCodec.trailing).asInstanceOf[RoutePattern[Any]], true)
case _: EndpointInput.PathsCapture[_] if !trailingAdded => ((p / PathCodec.trailing).asInstanceOf[RoutePattern[Any]], true)
case i: EndpointInput.FixedPath[_] => (p / PathCodec.literal(i.s), trailingAdded)
case _ => (p, trailingAdded)
}
}
._1
} else {
// if there are no path inputs, we return a catch-all
RoutePattern(Method.ANY, PathCodec.trailing).asInstanceOf[RoutePattern[Any]]
}

Route.handled(pattern)(Handler.fromFunctionHandler { (request: Request) => handleRequest(request, sesForPathTemplate) })
}

Routes(Chunk.fromIterable(handlers))
}

private def handleWebSocketResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,28 @@ class ZioHttpServerTest extends TestSuite {

Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test))
},
Test("zio http middlewares only run once, with two endpoints") {
val test: UIO[Assertion] = for {
ref <- Ref.make("")
ep1 = endpoint.get.in("p1").out(stringBody).zServerLogic[Any](_ => ref.updateAndGet(_ + "1"))
ep2 = endpoint.get.in("p2").out(stringBody).zServerLogic[Any](_ => ref.updateAndGet(_ + "2"))
route = ZioHttpInterpreter().toHttp(ep1) ++ ZioHttpInterpreter().toHttp(ep2)
app = route @@ Middleware.allowZIO((_: Request) => ref.update(_ + "M").as(true))
_ <- app
.runZIO(Request.get(url = URL(Path.empty / "p1")))
.flatMap(response => response.body.asString)
.map(_ shouldBe "M1")
.catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response")))
_ <- ref.set("")
result <- app
.runZIO(Request.get(url = URL(Path.empty / "p2")))
.flatMap(response => response.body.asString)
.map(_ shouldBe "M2")
.catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response")))
} yield result

Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test))
},
// https://github.com/softwaremill/tapir/issues/2849
Test("Streaming works through the stub backend") {
// given
Expand Down

0 comments on commit 3f9ccc8

Please sign in to comment.