diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala index df25c4fefc..1fd67fa678 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala @@ -1,7 +1,11 @@ package sttp.tapir.server.netty.cats import cats.effect.IO +import cats.effect.kernel.Resource.ExitCase +import cats.effect.std.Queue +import cats.effect.unsafe.implicits.global import cats.syntax.all._ +import org.scalatest.EitherValues import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.fs2.Fs2Streams import sttp.client3._ @@ -13,42 +17,48 @@ import sttp.tapir.{CodecFormat, _} import java.nio.charset.StandardCharsets import scala.concurrent.duration._ -import cats.effect.std.Queue -import cats.effect.unsafe.implicits.global -class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: CreateServerTest[IO, Fs2Streams[IO], OPTIONS, ROUTE]) { +class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: CreateServerTest[IO, Fs2Streams[IO], OPTIONS, ROUTE]) + extends EitherValues { import createServerTest._ implicit val m: MonadError[IO] = new CatsMonadError[IO]() + def tests(): List[Test] = List({ - val buffer = Queue.unbounded[IO, Byte].unsafeRunSync() + val buffer = Queue.unbounded[IO, Option[Byte]].unsafeRunSync() + + def readBuffer: IO[List[Byte]] = + fs2.Stream.fromQueueNoneTerminated(buffer).compile.toList + val body_20_slowly_emitted_bytes = - fs2.Stream.awakeEvery[IO](100.milliseconds).map(_ => 42.toByte).evalMap(b => { buffer.offer(b) >> IO.pure(b) }).take(100) + fs2.Stream + .awakeEvery[IO](100.milliseconds) + .map(_ => 42.toByte) + .onFinalizeCase { + case ExitCase.Canceled => buffer.offer(None) + case _ => IO.unit + } + testServer( endpoint.get .in("streamCanceled") .out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8))), "Client cancelling streaming triggers cancellation on the server" )(_ => pureResult(body_20_slowly_emitted_bytes.asRight[Unit])) { (backend, baseUri) => - - val expectedMaxAccumulated = 3 - + // How this test works: + // 1. The endpoint emits a byte continuously every 100 millis + // 2. The client connects and reads bytes, putting them in a buffer + // 3. The client cancels and disconnects after 1 second (using .timeout on the stream draining operation) + // 4. The endpoint logic reacts to cancelation and signals the end of the buffer (by putting a None in it) + // 5. The client tries to read all bytes from the buffer, which would fail with a timeout if the None element from point 4. wasn't triggered correctly basicRequest .get(uri"$baseUri/streamCanceled") + .response(asStreamUnsafe(Fs2Streams[IO])) .send(backend) - .timeout(300.millis) - .attempt >> - IO.sleep(600.millis) - .flatMap(_ => - buffer.size.flatMap(accumulated => - IO( - assert( - accumulated <= expectedMaxAccumulated, - s"Buffer accumulated $accumulated elements. Expected < $expectedMaxAccumulated due to cancellation." - ) - ) - ) - ) + .flatMap(_.body.value.evalMap(b => buffer.offer(Some(b))).compile.drain) + .timeout(1000.millis) + .attempt + .void >> readBuffer.timeout(5.seconds).map(bytes => assert(bytes.length > 1)) } }) } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 88abed0a82..55b7df5162 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -190,10 +190,12 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( "empty client stream" )((_: Unit) => pureResult(emptyPipe.asRight[Unit])) { (backend, baseUri) => basicRequest - .response(asWebSocketAlways { (ws: WebSocket[IO]) => ws.eitherClose(ws.receiveText()) }) + .response(asWebSocketAlways { (ws: WebSocket[IO]) => + if (expectCloseResponse) ws.eitherClose(ws.receiveText()).map(Some(_)) else IO.pure(None) + }) .get(baseUri.scheme("ws")) .send(backend) - .map(_.body.left.map(_.statusCode) shouldBe Left(WebSocketFrame.close.statusCode)) + .map(r => assert(r.body.forall(_.left.map(_.statusCode) == Left(1000)))) }, testServer( endpoint