From 97de28d106fcc14f0550fdfa027768da977f9101 Mon Sep 17 00:00:00 2001 From: Arman Bilge Date: Mon, 15 Jan 2024 05:38:23 +0000 Subject: [PATCH 1/3] Implement workaround for init order errors in JVM lambdas --- build.sbt | 1 + .../scala/feral/lambda/IOLambdaPlatform.scala | 27 ++++++++++++++++++- .../scala/feral/lambda/IOLambdaJvmSuite.scala | 25 +++++++++++++++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 76e8ba15..46985906 100644 --- a/build.sbt +++ b/build.sbt @@ -109,6 +109,7 @@ lazy val lambda = crossProject(JSPlatform, JVMPlatform) ) ) .jvmSettings( + Test / fork := true, libraryDependencies ++= Seq( "com.amazonaws" % "aws-lambda-java-core" % "1.2.3", "co.fs2" %%% "fs2-io" % fs2Version diff --git a/lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala b/lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala index 8c82a6e6..0adf5276 100644 --- a/lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala +++ b/lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala @@ -16,8 +16,11 @@ package feral.lambda +import cats.effect.Async import cats.effect.IO +import cats.effect.Resource import cats.effect.std.Dispatcher +import cats.effect.syntax.all._ import cats.syntax.all._ import com.amazonaws.services.lambda.{runtime => lambdaRuntime} import io.circe.Printer @@ -29,11 +32,33 @@ import java.io.OutputStream import java.io.OutputStreamWriter import java.nio.channels.Channels import scala.concurrent.duration._ +import scala.util.control.NonFatal private[lambda] abstract class IOLambdaPlatform[Event, Result] extends lambdaRuntime.RequestStreamHandler { this: IOLambda[Event, Result] => private[this] val (dispatcher, handle) = { + val handler = { + val h = + try this.handler + catch { case ex if NonFatal(ex) => null } + + if (h ne null) { h.map(IO.pure(_)) } + else { + val lambdaName = getClass().getSimpleName() + val msg = + s"""|There was an error initializing `$lambdaName` during startup. + |Falling back to initialize-during-first-invocation strategy. + |To fix, try replacing any `val`s in `$lambdaName` with `def`s.""".stripMargin + System.err.println(msg) + + Async[Resource[IO, *]].defer(this.handler).memoize.map { + case Resource.Eval(ioa) => ioa + case _ => throw new AssertionError + } + } + } + Dispatcher .parallel[IO](await = false) .product(handler) @@ -50,7 +75,7 @@ private[lambda] abstract class IOLambdaPlatform[Event, Result] val context = Context.fromJava[IO](runtimeContext) dispatcher .unsafeRunTimed( - handle(Invocation.pure(event, context)), + handle.flatMap(_(Invocation.pure(event, context))), runtimeContext.getRemainingTimeInMillis().millis ) .foreach { result => diff --git a/lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala b/lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala index dfc0e438..f9840af1 100644 --- a/lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala +++ b/lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala @@ -83,6 +83,31 @@ class IOLambdaJvmSuite extends FunSuite { ) } + test("gracefully handles broken initialization due to `val`") { + val os = new ByteArrayOutputStream + + val lambda1 = new IOLambda[Unit, Unit] { + val handler = Resource.pure(_ => IO(None)) + } + + lambda1.handleRequest( + new ByteArrayInputStream("{}".getBytes()), + os, + DummyContext + ) + + val lambda2 = new IOLambda[Unit, Unit] { + def handler = resource.as(_ => IO(None)) + val resource = Resource.unit[IO] + } + + lambda2.handleRequest( + new ByteArrayInputStream("{}".getBytes()), + os, + DummyContext + ) + } + object DummyContext extends runtime.Context { override def getAwsRequestId(): String = "" override def getLogGroupName(): String = "" From d49477d0abcf7ec40c17450c436f093061d19caa Mon Sep 17 00:00:00 2001 From: Arman Bilge Date: Mon, 15 Jan 2024 14:20:34 +0000 Subject: [PATCH 2/3] More rigorous assertions about when init occurs --- .../scala/feral/lambda/IOLambdaJvmSuite.scala | 74 +++++++++---------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala b/lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala index f9840af1..6f2246d9 100644 --- a/lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala +++ b/lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala @@ -31,7 +31,19 @@ import java.util.concurrent.atomic.AtomicInteger class IOLambdaJvmSuite extends FunSuite { - test("initializes handler once") { + implicit class HandleOps[A, B](lambda: IOLambda[A, B]) { + def handleRequestHelper(in: String): String = { + val os = new ByteArrayOutputStream + lambda.handleRequest( + new ByteArrayInputStream(in.getBytes()), + os, + DummyContext + ) + new String(os.toByteArray()) + } + } + + test("initializes handler once during construction") { val allocationCounter = new AtomicInteger val invokeCounter = new AtomicInteger @@ -41,18 +53,12 @@ class IOLambdaJvmSuite extends FunSuite { .as(_.event.map(Some(_)) <* IO(invokeCounter.getAndIncrement())) } + assertEquals(allocationCounter.get(), 1) + val chars = 'A' to 'Z' chars.foreach { c => - val os = new ByteArrayOutputStream - val json = s""""$c"""" - lambda.handleRequest( - new ByteArrayInputStream(json.getBytes()), - os, - DummyContext - ) - - assertEquals(new String(os.toByteArray()), json) + assertEquals(lambda.handleRequestHelper(json), json) } assertEquals(allocationCounter.get(), 1) @@ -68,44 +74,36 @@ class IOLambdaJvmSuite extends FunSuite { def handler = Resource.pure(_ => IO(Some(output))) } - val os = new ByteArrayOutputStream - - lambda.handleRequest( - new ByteArrayInputStream(input.toString.getBytes()), - os, - DummyContext - ) - assertEquals( - jawn.parseByteArray(os.toByteArray()), - Right(output), - new String(os.toByteArray()) + jawn.parse(lambda.handleRequestHelper(input.noSpaces)), + Right(output) ) } test("gracefully handles broken initialization due to `val`") { - val os = new ByteArrayOutputStream - val lambda1 = new IOLambda[Unit, Unit] { - val handler = Resource.pure(_ => IO(None)) + def go(mkLambda: AtomicInteger => IOLambda[Unit, Unit]): Unit = { + val counter = new AtomicInteger + val lambda = mkLambda(counter) + assertEquals(counter.get(), 0) // init failed + lambda.handleRequestHelper("{}") + assertEquals(counter.get(), 1) // inited + lambda.handleRequestHelper("{}") + assertEquals(counter.get(), 1) // did not re-init } - lambda1.handleRequest( - new ByteArrayInputStream("{}".getBytes()), - os, - DummyContext - ) - - val lambda2 = new IOLambda[Unit, Unit] { - def handler = resource.as(_ => IO(None)) - val resource = Resource.unit[IO] + go { counter => + new IOLambda[Unit, Unit] { + val handler = Resource.eval(IO(counter.getAndIncrement())).as(_ => IO(None)) + } } - lambda2.handleRequest( - new ByteArrayInputStream("{}".getBytes()), - os, - DummyContext - ) + go { counter => + new IOLambda[Unit, Unit] { + def handler = resource.as(_ => IO(None)) + val resource = Resource.eval(IO(counter.getAndIncrement())) + } + } } object DummyContext extends runtime.Context { From b6463d616b21d491b26760abfbe15ffb77fc2cbe Mon Sep 17 00:00:00 2001 From: Arman Bilge Date: Mon, 15 Jan 2024 18:37:14 +0000 Subject: [PATCH 3/3] Use a safer `memoize` workaround --- .../src/main/scala/feral/lambda/IOLambdaPlatform.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala b/lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala index 0adf5276..163e2efa 100644 --- a/lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala +++ b/lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala @@ -43,8 +43,9 @@ private[lambda] abstract class IOLambdaPlatform[Event, Result] try this.handler catch { case ex if NonFatal(ex) => null } - if (h ne null) { h.map(IO.pure(_)) } - else { + if (h ne null) { + h.map(IO.pure(_)) + } else { val lambdaName = getClass().getSimpleName() val msg = s"""|There was an error initializing `$lambdaName` during startup. @@ -52,10 +53,7 @@ private[lambda] abstract class IOLambdaPlatform[Event, Result] |To fix, try replacing any `val`s in `$lambdaName` with `def`s.""".stripMargin System.err.println(msg) - Async[Resource[IO, *]].defer(this.handler).memoize.map { - case Resource.Eval(ioa) => ioa - case _ => throw new AssertionError - } + Async[Resource[IO, *]].defer(this.handler).memoize.map(_.allocated.map(_._1)) } }