From 2774d560cf9e39ddd4881509f2575a52ac9be5ea Mon Sep 17 00:00:00 2001
From: tangjiafu <jiafu.tang@qq.com>
Date: Wed, 17 Apr 2024 16:42:16 +0800
Subject: [PATCH 1/3] share trailers through matval

---
 .../PekkoNettyGrpcClientGraphStage.scala      | 42 +++++++++++--------
 1 file changed, 25 insertions(+), 17 deletions(-)

diff --git a/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala b/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
index e129b5b3..d2d33de9 100644
--- a/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
+++ b/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
@@ -16,7 +16,7 @@ package org.apache.pekko.grpc.internal
 import org.apache.pekko
 import pekko.annotation.InternalApi
 import pekko.dispatch.ExecutionContexts
-import pekko.grpc.GrpcResponseMetadata
+import pekko.grpc.{ GrpcResponseMetadata, GrpcServiceException }
 import pekko.stream
 import pekko.stream.{ Attributes => _, _ }
 import pekko.stream.stage._
@@ -24,12 +24,13 @@ import pekko.util.FutureConverters._
 import io.grpc._
 
 import scala.concurrent.{ Future, Promise }
+import scala.util.Success
 
 @InternalApi
 private object PekkoNettyGrpcClientGraphStage {
   sealed trait ControlMessage
   case object ReadyForSending extends ControlMessage
-  case class Closed(status: Status, trailer: Metadata) extends ControlMessage
+  case class Closed(status: Status) extends ControlMessage
 }
 
 /**
@@ -64,8 +65,6 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
       inheritedAttributes: stream.Attributes): (GraphStageLogic, Future[GrpcResponseMetadata]) = {
     import PekkoNettyGrpcClientGraphStage._
     val matVal = Promise[GrpcResponseMetadata]()
-    val trailerPromise = Promise[Metadata]()
-
     val logic = new GraphStageLogic(shape) with InHandler with OutHandler {
       // this is here just to fail single response requests getting more responses
       // duplicating behavior in io.grpc.stub.ClientCalls
@@ -76,8 +75,8 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
       val callback = getAsyncCallback[Any] {
         case msg: ControlMessage =>
           msg match {
-            case ReadyForSending         => if (!isClosed(in) && !hasBeenPulled(in)) tryPull(in)
-            case Closed(status, trailer) => onCallClosed(status, trailer)
+            case ReadyForSending => if (!isClosed(in) && !hasBeenPulled(in)) tryPull(in)
+            case Closed(status)  => onCallClosed(status)
           }
         case element: O @unchecked =>
           if (!streamingResponse) {
@@ -94,29 +93,32 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
       val listener = new ClientCall.Listener[O] {
         override def onReady(): Unit =
           callback.invoke(ReadyForSending)
-        override def onHeaders(responseHeaders: Metadata): Unit =
+
+        override def onHeaders(responseHeaders: Metadata): Unit = {
           matVal.success(new GrpcResponseMetadata {
             private lazy val sMetadata = MetadataImpl.scalaMetadataFromGoogleGrpcMetadata(responseHeaders)
             private lazy val jMetadata = MetadataImpl.javaMetadataFromGoogleGrpcMetadata(responseHeaders)
             def headers = sMetadata
             def getHeaders() = jMetadata
 
-            private lazy val sTrailers =
-              trailerPromise.future.map(MetadataImpl.scalaMetadataFromGoogleGrpcMetadata)(ExecutionContexts.parasitic)
-            private lazy val jTrailers = trailerPromise.future
-              .map(MetadataImpl.javaMetadataFromGoogleGrpcMetadata)(ExecutionContexts.parasitic)
-              .asJava
+            private lazy val sTrailers = Future.successful(sMetadata)
+            private lazy val jTrailers = Future.successful(jMetadata).asJava
             def trailers = sTrailers
             def getTrailers() = jTrailers
           })
+        }
+
         override def onMessage(message: O): Unit =
           callback.invoke(message)
+
         override def onClose(status: Status, trailers: Metadata): Unit = {
-          trailerPromise.success(trailers)
-          callback.invoke(Closed(status, trailers))
+          onHeaders(trailers)
+          callback.invoke(Closed(status))
         }
       }
+
       override def preStart(): Unit = {
+
         call = channel.newCall(descriptor, options)
         call.start(listener, headers.toGoogleGrpcMetadata())
 
@@ -134,12 +136,14 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
         // request so pull early to get things going
         pull(in)
       }
+
       override def onPush(): Unit = {
         call.sendMessage(grab(in))
         if (call.isReady && !hasBeenPulled(in)) {
           pull(in)
         }
       }
+
       override def onUpstreamFinish(): Unit = {
         call.halfClose()
         if (isClosed(out)) {
@@ -148,6 +152,7 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
           completeStage()
         }
       }
+
       override def onUpstreamFailure(ex: Throwable): Unit = {
         call.cancel("Failure from upstream", ex)
         call = null
@@ -159,6 +164,7 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
           call.request(1)
           requested += 1
         }
+
       override def onDownstreamFinish(cause: Throwable): Unit =
         if (isClosed(out)) {
           call.cancel("Downstream cancelled", cause)
@@ -166,12 +172,14 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
           completeStage()
         }
 
-      def onCallClosed(status: Status, trailers: Metadata): Unit = {
+      def onCallClosed(status: Status): Unit = {
         if (status.isOk()) {
-          // FIXME share trailers through matval
           completeStage()
         } else {
-          failStage(status.asRuntimeException(trailers))
+          matVal.future.onComplete {
+            case Success(metadata) => failStage(new GrpcServiceException(status, metadata.headers))
+            case _                 => failStage(new GrpcServiceException(status))
+          }(ExecutionContexts.parasitic)
         }
         call = null
       }

From 4038f02a0aeff188fab660a05d8326095151d29d Mon Sep 17 00:00:00 2001
From: tangjiafu <jiafu.tang@qq.com>
Date: Wed, 17 Apr 2024 19:50:59 +0800
Subject: [PATCH 2/3] wip

---
 .../pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala  | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala b/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
index d2d33de9..f26a7346 100644
--- a/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
+++ b/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
@@ -112,7 +112,9 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
           callback.invoke(message)
 
         override def onClose(status: Status, trailers: Metadata): Unit = {
-          onHeaders(trailers)
+          if (!matVal.isCompleted) {
+            onHeaders(trailers)
+          }
           callback.invoke(Closed(status))
         }
       }

From 8162b87a6ae0398f6406b9cffa7490859d5e2cf2 Mon Sep 17 00:00:00 2001
From: tangjiafu <jiafu.tang@qq.com>
Date: Wed, 17 Apr 2024 20:17:38 +0800
Subject: [PATCH 3/3] wip

---
 .../grpc/internal/PekkoNettyGrpcClientGraphStage.scala    | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala b/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
index f26a7346..09edc2b1 100644
--- a/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
+++ b/runtime/src/main/scala/org/apache/pekko/grpc/internal/PekkoNettyGrpcClientGraphStage.scala
@@ -13,15 +13,15 @@
 
 package org.apache.pekko.grpc.internal
 
+import io.grpc._
 import org.apache.pekko
 import pekko.annotation.InternalApi
 import pekko.dispatch.ExecutionContexts
-import pekko.grpc.{ GrpcResponseMetadata, GrpcServiceException }
+import pekko.grpc.GrpcResponseMetadata
 import pekko.stream
 import pekko.stream.{ Attributes => _, _ }
 import pekko.stream.stage._
 import pekko.util.FutureConverters._
-import io.grpc._
 
 import scala.concurrent.{ Future, Promise }
 import scala.util.Success
@@ -179,8 +179,8 @@ private final class PekkoNettyGrpcClientGraphStage[I, O](
           completeStage()
         } else {
           matVal.future.onComplete {
-            case Success(metadata) => failStage(new GrpcServiceException(status, metadata.headers))
-            case _                 => failStage(new GrpcServiceException(status))
+            case Success(metadata) => failStage(status.asRuntimeException(metadata.headers.raw.orNull))
+            case _                 => failStage(status.asRuntimeException())
           }(ExecutionContexts.parasitic)
         }
         call = null