Skip to content

Commit

Permalink
x-retry-attempt to StreamClientInterceptor (#733)
Browse files Browse the repository at this point in the history
* x-retry-attempt to StreamClientInterceptor

* unit test for StreamClientInterceptor AttemptMetadata

---------

Co-authored-by: a.boklazhenko <[email protected]>
  • Loading branch information
Boklazhenko and a.boklazhenko authored Nov 28, 2024
1 parent ba6f8b9 commit 6aea589
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 2 deletions.
14 changes: 12 additions & 2 deletions interceptors/retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ func StreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientIntercepto
callOpts.onRetryCallback(parentCtx, attempt, lastErr)
}
var newStreamer grpc.ClientStream
newStreamer, lastErr = streamer(parentCtx, desc, cc, method, grpcOpts...)
newStreamer, lastErr = streamer(perStreamContext(parentCtx, callOpts, attempt), desc, cc, method, grpcOpts...)
if lastErr == nil {
retryingStreamer := &serverStreamingRetryingStream{
ClientStream: newStreamer,
callOpts: callOpts,
parentCtx: parentCtx,
streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
return streamer(ctx, desc, cc, method, grpcOpts...)
attempt++
return streamer(perStreamContext(ctx, callOpts, attempt), desc, cc, method, grpcOpts...)
},
}
return retryingStreamer, nil
Expand Down Expand Up @@ -296,6 +297,15 @@ func perCallContext(parentCtx context.Context, callOpts *options, attempt uint)
return ctx, cancel
}

func perStreamContext(parentCtx context.Context, callOpts *options, attempt uint) context.Context {
ctx := parentCtx
if attempt > 0 && callOpts.includeHeader {
mdClone := metadata.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, fmt.Sprintf("%d", attempt))
ctx = mdClone.ToOutgoing(ctx)
}
return ctx
}

func contextErrToGrpcErr(err error) error {
switch err {
case context.DeadlineExceeded:
Expand Down
72 changes: 72 additions & 0 deletions interceptors/retry/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package retry
import (
"context"
"io"
"strconv"
"strings"
"sync"
"testing"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -432,3 +434,73 @@ func TestJitterUp(t *testing.T) {
assert.True(t, highCount != 0, "at least one sample should reach to >%s", high)
assert.True(t, lowCount != 0, "at least one sample should to <%s", low)
}

type failingClientStream struct {
RecvMsgErr error
}

func (s *failingClientStream) Header() (metadata.MD, error) {
return nil, nil
}

func (s *failingClientStream) Trailer() metadata.MD {
return nil
}

func (s *failingClientStream) CloseSend() error {
return nil
}

func (s *failingClientStream) Context() context.Context {
return context.Background()
}

func (s *failingClientStream) SendMsg(m any) error {
return nil
}

func (s *failingClientStream) RecvMsg(m any) error {
return s.RecvMsgErr
}

func TestStreamClientInterceptorAttemptMetadata(t *testing.T) {
retryCount := 5
attempt := 0
recvMsgErr := status.Error(codes.Unavailable, "unavailable")

var testStreamer grpc.Streamer = func(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
if attempt > 0 {
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)

raw := md.Get(AttemptMetadataKey)
require.Len(t, raw, 1)

attemptMetadataValue, err := strconv.Atoi(raw[0])
require.NoError(t, err)

require.Equal(t, attempt, attemptMetadataValue)
}

attempt++

return &failingClientStream{
RecvMsgErr: recvMsgErr,
}, nil
}

streamClientInterceptor := StreamClientInterceptor(WithCodes(codes.Unavailable), WithMax(uint(retryCount)))
clientStream, err := streamClientInterceptor(context.Background(), &grpc.StreamDesc{}, nil, "some_method", testStreamer)
require.NoError(t, err)

err = clientStream.RecvMsg(nil)
require.ErrorIs(t, err, recvMsgErr)

require.Equal(t, retryCount, attempt)
}

0 comments on commit 6aea589

Please sign in to comment.