Skip to content

Commit

Permalink
Merge pull request #3291 from kyri-petrou/zio-http-custom-ws-config
Browse files Browse the repository at this point in the history
Add option to provide custom WS config in zio-http
  • Loading branch information
adamw authored Nov 2, 2023
2 parents 4eb2647 + d0104d5 commit 21e13c4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,18 @@ trait ZioHttpInterpreter[R] {
zioHttpServerOptions.deleteFile
)

val serverRequest = ZioHttpServerRequest(request)

interpreter
.apply(ZioHttpServerRequest(request))
.apply(serverRequest)
.foldCauseZIO(
cause => ZIO.logErrorCause(cause) *> ZIO.fail(Response.internalServerError(cause.squash.getMessage)),
{
case RequestResult.Response(resp) =>
resp.body match {
case None => handleHttpResponse(resp, None)
case Some(Right(body)) => handleHttpResponse(resp, Some(body))
case Some(Left(body)) => handleWebSocketResponse(body)
case Some(Left(body)) => handleWebSocketResponse(body, zioHttpServerOptions.customWebSocketConfig(serverRequest))
}

case RequestResult.Failure(_) =>
Expand All @@ -63,15 +65,19 @@ trait ZioHttpInterpreter[R] {
)
}

private def handleWebSocketResponse(webSocketHandler: WebSocketHandler): ZIO[Any, Nothing, Response] = {
Handler.webSocket { channel =>
private def handleWebSocketResponse(
webSocketHandler: WebSocketHandler,
webSocketConfig: Option[WebSocketConfig]
): ZIO[Any, Nothing, Response] = {
val app = Handler.webSocket { channel =>
for {
channelEventsQueue <- zio.Queue.unbounded[WebSocketChannelEvent]
messageReceptionFiber <- channel.receiveAll { message => channelEventsQueue.offer(message) }.fork
webSocketStream <- webSocketHandler(stream.ZStream.fromQueue(channelEventsQueue))
_ <- webSocketStream.mapZIO(channel.send).runDrain
} yield messageReceptionFiber.join
}.toResponse
}
webSocketConfig.fold(app)(app.withConfig).toResponse
}

private def handleHttpResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@ import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interceptor.log.DefaultServerLog
import sttp.tapir.server.interceptor.{CustomiseInterceptors, Interceptor}
import sttp.tapir.{Defaults, TapirFile}
import zio.http.WebSocketConfig
import zio.{Cause, RIO, Task, ZIO}

case class ZioHttpServerOptions[R](
createFile: ServerRequest => Task[TapirFile],
deleteFile: TapirFile => RIO[R, Unit],
interceptors: List[Interceptor[RIO[R, *]]]
interceptors: List[Interceptor[RIO[R, *]]],
customWebSocketConfig: ServerRequest => Option[WebSocketConfig]
) {
def prependInterceptor(i: Interceptor[RIO[R, *]]): ZioHttpServerOptions[R] =
copy(interceptors = i :: interceptors)
def appendInterceptor(i: Interceptor[RIO[R, *]]): ZioHttpServerOptions[R] =
copy(interceptors = interceptors :+ i)
def withCustomWebSocketConfig(f: ServerRequest => WebSocketConfig): ZioHttpServerOptions[R] =
copy(customWebSocketConfig = f.andThen(Some(_)))

def widen[R2 <: R]: ZioHttpServerOptions[R2] = this.asInstanceOf[ZioHttpServerOptions[R2]]
}
Expand All @@ -28,7 +32,8 @@ object ZioHttpServerOptions {
ZioHttpServerOptions(
defaultCreateFile,
defaultDeleteFile,
ci.interceptors
ci.interceptors,
_ => None
)
).serverLog(defaultServerLog[R])

Expand Down

0 comments on commit 21e13c4

Please sign in to comment.