diff --git a/backend/client.go b/backend/client.go
index a78bf4b..5344023 100644
--- a/backend/client.go
+++ b/backend/client.go
@@ -73,7 +73,7 @@ func (c *backendClient) ScheduleNewOrchestration(ctx context.Context, orchestrat
 func (c *backendClient) FetchOrchestrationMetadata(ctx context.Context, id api.InstanceID) (*api.OrchestrationMetadata, error) {
     metadata, err := c.be.GetOrchestrationMetadata(ctx, id)
     if err != nil {
-        return nil, fmt.Errorf("Failed to fetch orchestration metadata: %w", err)
+        return nil, fmt.Errorf("failed to fetch orchestration metadata: %w", err)
     }
     return metadata, nil
 }
diff --git a/backend/executor.go b/backend/executor.go
index f32ed7f..9ef96ea 100644
--- a/backend/executor.go
+++ b/backend/executor.go
@@ -4,6 +4,7 @@ import (
     context "context"
     "errors"
     "fmt"
+    "strconv"
     "strings"
     "sync"
     "time"
@@ -28,17 +29,20 @@ var errShuttingDown error = status.Error(codes.Canceled, "shutting down")
 
 type ExecutionResults struct {
     Response *protos.OrchestratorResponse
-    complete chan interface{}
+    complete chan struct{}
+    pending  chan string
 }
 
 type activityExecutionResult struct {
     response *protos.ActivityResponse
-    complete chan interface{}
+    complete chan struct{}
+    pending  chan string
 }
 
 type Executor interface {
     ExecuteOrchestrator(ctx context.Context, iid api.InstanceID, oldEvents []*protos.HistoryEvent, newEvents []*protos.HistoryEvent) (*ExecutionResults, error)
     ExecuteActivity(context.Context, api.InstanceID, *protos.HistoryEvent) (*protos.HistoryEvent, error)
+    Shutdown(ctx context.Context) error
 }
 
 type grpcExecutor struct {
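Aside, for context on the hunk above: the old complete chan interface{} was signaled with a single value send, which releases at most one receiver and blocks forever if nobody is still listening. Closing a chan struct{} instead releases every current and future waiter, which is what lets the new Shutdown path and the per-stream cleanup further down unblock pending executions. A minimal, self-contained sketch of that close-to-broadcast idiom follows; none of the names in it come from this repository.

// Illustrative only: the close-to-broadcast idiom that the chan struct{} change relies on.
package main

import (
    "fmt"
    "sync"
)

func main() {
    done := make(chan struct{}) // zero-byte element type: the signal is the close itself
    var wg sync.WaitGroup

    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            <-done // every waiter unblocks once the channel is closed
            fmt.Printf("waiter %d released\n", id)
        }(i)
    }

    close(done) // a single close releases all current and future receivers
    wg.Wait()
}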
@@ -96,9 +100,8 @@ func NewGrpcExecutor(be Backend, logger Logger, opts ...grpcExecutorOptions) (ex
 
 // ExecuteOrchestrator implements Executor
 func (executor *grpcExecutor) ExecuteOrchestrator(ctx context.Context, iid api.InstanceID, oldEvents []*protos.HistoryEvent, newEvents []*protos.HistoryEvent) (*ExecutionResults, error) {
-    result := &ExecutionResults{complete: make(chan interface{})}
+    result := &ExecutionResults{complete: make(chan struct{})}
     executor.pendingOrchestrators.Store(iid, result)
-    defer executor.pendingOrchestrators.Delete(iid)
 
     workItem := &protos.WorkItem{
         Request: &protos.WorkItem_OrchestratorRequest{
@@ -121,12 +124,15 @@ func (executor *grpcExecutor) ExecuteOrchestrator(ctx context.Context, iid api.I
     }
 
     // Wait for the connected worker to signal that it's done executing the work-item
-    // TODO: Timeout logic - i.e. handle the case where we never hear back from the remote worker (due to a hang, etc.).
     select {
     case <-ctx.Done():
         executor.logger.Warnf("%s: context canceled before receiving orchestrator result", iid)
         return nil, ctx.Err()
     case <-result.complete:
+        executor.logger.Debugf("%s: orchestrator got result", iid)
+        if result.Response == nil {
+            return nil, errors.New("operation aborted")
+        }
     }
 
     return result, nil
@@ -135,9 +141,8 @@ func (executor *grpcExecutor) ExecuteOrchestrator(ctx context.Context, iid api.I
 
 // ExecuteActivity implements Executor
 func (executor *grpcExecutor) ExecuteActivity(ctx context.Context, iid api.InstanceID, e *protos.HistoryEvent) (*protos.HistoryEvent, error) {
     key := getActivityExecutionKey(string(iid), e.EventId)
-    result := &activityExecutionResult{complete: make(chan interface{})}
+    result := &activityExecutionResult{complete: make(chan struct{})}
     executor.pendingActivities.Store(key, result)
-    defer executor.pendingActivities.Delete(key)
 
     task := e.GetTaskScheduled()
     workItem := &protos.WorkItem{
@@ -162,12 +167,15 @@ func (executor *grpcExecutor) ExecuteActivity(ctx context.Context, iid api.Insta
     }
 
     // Wait for the connected worker to signal that it's done executing the work-item
-    // TODO: Timeout logic
     select {
     case <-ctx.Done():
         executor.logger.Warnf("%s/%s#%d: context canceled before receiving activity result", iid, task.Name, e.EventId)
         return nil, ctx.Err()
     case <-result.complete:
+        executor.logger.Debugf("%s: activity got result", key)
+        if result.response == nil {
+            return nil, errors.New("operation aborted")
+        }
     }
 
     var responseEvent *protos.HistoryEvent
@@ -181,9 +189,27 @@ func (executor *grpcExecutor) ExecuteActivity(ctx context.Context, iid api.Insta
 }
 
 // Shutdown implements Executor
-func (g *grpcExecutor) Shutdown(ctx context.Context) {
+func (g *grpcExecutor) Shutdown(ctx context.Context) error {
     // closing the work item queue is a signal for shutdown
     close(g.workItemQueue)
+
+    // Iterate through all pending items and close them to unblock the goroutines waiting on this
+    g.pendingActivities.Range(func(_, value any) bool {
+        p, ok := value.(*activityExecutionResult)
+        if ok {
+            close(p.complete)
+        }
+        return true
+    })
+    g.pendingOrchestrators.Range(func(_, value any) bool {
+        p, ok := value.(*ExecutionResults)
+        if ok {
+            close(p.complete)
+        }
+        return true
+    })
+
+    return nil
 }
 
 // Hello implements protos.TaskHubSidecarServiceServer
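Aside: the new Shutdown(ctx) contract is meant to be driven by whoever owns the executor, typically with a bounded context so a hung worker cannot stall process exit; the TestMain change near the end of this diff does exactly that. Below is a rough, self-contained sketch under that assumption; the shutdowner interface and noopExecutor type are illustrative stand-ins, not types from this repository.

// Illustrative only: bounding a Shutdown call with a timeout.
package main

import (
    "context"
    "log"
    "time"
)

type shutdowner interface {
    Shutdown(ctx context.Context) error
}

type noopExecutor struct{}

func (noopExecutor) Shutdown(ctx context.Context) error { return nil }

func main() {
    var ex shutdowner = noopExecutor{} // stand-in for the real executor

    // Bound the shutdown so a hung worker cannot block process exit forever.
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    if err := ex.Shutdown(ctx); err != nil {
        log.Printf("executor shutdown failed: %v", err)
    }
}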
@@ -202,22 +228,73 @@ func (g *grpcExecutor) GetWorkItems(req *protos.GetWorkItemsRequest, stream prot
     callback := g.onWorkItemConnection
     if callback != nil {
         if err := callback(stream.Context()); err != nil {
-            message := fmt.Sprint("unable to establish work item stream at this time: ", err)
+            message := "unable to establish work item stream at this time: " + err.Error()
             g.logger.Warn(message)
             return status.Errorf(codes.Unavailable, message)
         }
     }
 
+    // Collect all pending activities on this stream
+    // Note: we don't need sync.Maps here because access happens only on this goroutine
+    pendingActivities := make(map[string]struct{})
+    pendingActivityCh := make(chan string, 1)
+    pendingOrchestrators := make(map[string]struct{})
+    pendingOrchestratorCh := make(chan string, 1)
+    defer func() {
+        // If there are any pending items left, remove them
+        for key := range pendingActivities {
+            g.logger.Debugf("cleaning up pending activity: %s", key)
+            p, ok := g.pendingActivities.LoadAndDelete(key)
+            if ok {
+                pending := p.(*activityExecutionResult)
+                close(pending.complete)
+            }
+        }
+        for key := range pendingOrchestrators {
+            g.logger.Debugf("cleaning up pending orchestrator: %s", key)
+            p, ok := g.pendingOrchestrators.LoadAndDelete(api.InstanceID(key))
+            if ok {
+                pending := p.(*ExecutionResults)
+                close(pending.complete)
+            }
+        }
+    }()
+
     // The worker client invokes this method, which streams back work-items as they arrive.
     for {
         select {
         case <-stream.Context().Done():
-            g.logger.Infof("work item stream closed")
+            g.logger.Info("work item stream closed")
             return nil
-        case wi := <-g.workItemQueue:
+        case wi, ok := <-g.workItemQueue:
+            if !ok {
+                continue
+            }
+            switch x := wi.Request.(type) {
+            case *protos.WorkItem_OrchestratorRequest:
+                key := x.OrchestratorRequest.GetInstanceId()
+                pendingOrchestrators[key] = struct{}{}
+                p, ok := g.pendingOrchestrators.Load(api.InstanceID(key))
+                if ok {
+                    p.(*ExecutionResults).pending = pendingOrchestratorCh
+                }
+            case *protos.WorkItem_ActivityRequest:
+                key := getActivityExecutionKey(x.ActivityRequest.GetOrchestrationInstance().GetInstanceId(), x.ActivityRequest.GetTaskId())
+                pendingActivities[key] = struct{}{}
+                p, ok := g.pendingActivities.Load(key)
+                if ok {
+                    p.(*activityExecutionResult).pending = pendingActivityCh
+                }
+            }
+
             if err := stream.Send(wi); err != nil {
+                g.logger.Errorf("encountered an error while sending work item: %v", err)
                 return err
             }
+        case key := <-pendingActivityCh:
+            delete(pendingActivities, key)
+        case key := <-pendingOrchestratorCh:
+            delete(pendingOrchestrators, key)
         case <-g.streamShutdownChan:
             return errShuttingDown
         }
@@ -227,31 +304,57 @@ func (g *grpcExecutor) GetWorkItems(req *protos.GetWorkItemsRequest, stream prot
 
 // CompleteOrchestratorTask implements protos.TaskHubSidecarServiceServer
 func (g *grpcExecutor) CompleteOrchestratorTask(ctx context.Context, res *protos.OrchestratorResponse) (*protos.CompleteTaskResponse, error) {
     iid := api.InstanceID(res.InstanceId)
-    if p, ok := g.pendingOrchestrators.Load(iid); ok {
-        pending := p.(*ExecutionResults)
-        pending.Response = res
-        pending.complete <- true
+    if g.deletePendingOrchestrator(iid, res) {
         return emptyCompleteTaskResponse, nil
     }
 
     return emptyCompleteTaskResponse, fmt.Errorf("unknown instance ID: %s", res.InstanceId)
 }
 
+func (g *grpcExecutor) deletePendingOrchestrator(iid api.InstanceID, res *protos.OrchestratorResponse) bool {
+    p, ok := g.pendingOrchestrators.LoadAndDelete(iid)
+    if !ok {
+        return false
+    }
+
+    // Note that res can be nil in case of certain failures
+    pending := p.(*ExecutionResults)
+    pending.Response = res
+    if pending.pending != nil {
+        pending.pending <- string(iid)
+    }
+    close(pending.complete)
+    return true
+}
+
 // CompleteActivityTask implements protos.TaskHubSidecarServiceServer
 func (g *grpcExecutor) CompleteActivityTask(ctx context.Context, res *protos.ActivityResponse) (*protos.CompleteTaskResponse, error) {
     key := getActivityExecutionKey(res.InstanceId, res.TaskId)
-    if p, ok := g.pendingActivities.Load(key); ok {
-        pending := p.(*activityExecutionResult)
-        pending.response = res
-        pending.complete <- true
+    if g.deletePendingActivityTask(key, res) {
         return emptyCompleteTaskResponse, nil
     }
 
     return emptyCompleteTaskResponse, fmt.Errorf("unknown instance ID/task ID combo: %s", key)
 }
 
+func (g *grpcExecutor) deletePendingActivityTask(key string, res *protos.ActivityResponse) bool {
+    p, ok := g.pendingActivities.LoadAndDelete(key)
+    if !ok {
+        return false
+    }
+
+    // Note that res can be nil in case of certain failures
+    pending := p.(*activityExecutionResult)
+    pending.response = res
+    if pending.pending != nil {
+        pending.pending <- key
+    }
+    close(pending.complete)
+    return true
+}
+
 func getActivityExecutionKey(iid string, taskID int32) string {
-    return fmt.Sprintf("%s/%d", iid, taskID)
+    return iid + "/" + strconv.FormatInt(int64(taskID), 10)
 }
 
 // CreateTaskHub implements protos.TaskHubSidecarServiceServer
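Aside: the reason both completion RPCs and the per-stream cleanup go through sync.Map.LoadAndDelete is that only one caller can win the delete, so the complete channel is closed exactly once even when a worker reply races with a stream disconnect. A small, self-contained sketch of that pattern follows; the pendingResult type and the key are illustrative, not from this repository.

// Illustrative only: LoadAndDelete guarantees a single winner for completion.
package main

import (
    "fmt"
    "sync"
)

type pendingResult struct {
    complete chan struct{}
}

func main() {
    var pending sync.Map
    pending.Store("instance-1/0", &pendingResult{complete: make(chan struct{})})

    finish := func(who, key string) {
        // Whichever caller removes the entry first sees ok == true;
        // the loser sees ok == false and never touches the channel.
        if p, ok := pending.LoadAndDelete(key); ok {
            close(p.(*pendingResult).complete) // safe: reached at most once per key
            fmt.Println(who, "completed", key)
        }
    }

    var wg sync.WaitGroup
    wg.Add(2)
    go func() { defer wg.Done(); finish("completion RPC", "instance-1/0") }()
    go func() { defer wg.Done(); finish("stream cleanup", "instance-1/0") }()
    wg.Wait()
}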
diff --git a/client/worker_grpc.go b/client/worker_grpc.go
index 6532e8d..69c1bea 100644
--- a/client/worker_grpc.go
+++ b/client/worker_grpc.go
@@ -2,6 +2,7 @@ package client
 
 import (
     "context"
+    "errors"
     "fmt"
     "io"
     "time"
@@ -49,6 +50,14 @@ func (c *TaskHubGrpcClient) StartWorkItemListener(ctx context.Context, r *task.T
 
     go func() {
         c.logger.Info("starting background processor")
+        defer func() {
+            c.logger.Info("stopping background processor")
+            // We must use a background context here as the stream's context is likely canceled
+            shutdownErr := executor.Shutdown(context.Background())
+            if shutdownErr != nil {
+                c.logger.Warnf("error while shutting down background processor: %v", shutdownErr)
+            }
+        }()
         for {
             // TODO: Manage concurrency
             workItem, err := stream.Recv()
@@ -64,15 +73,13 @@ func (c *TaskHubGrpcClient) StartWorkItemListener(ctx context.Context, r *task.T
 
             c.logger.Errorf("background processor received stream error: %v", err)
 
-            if err == io.EOF {
+            if errors.Is(err, io.EOF) {
                retriable = true
             } else if grpcStatus, ok := status.FromError(err); ok {
                 c.logger.Warnf("received grpc error code %v", grpcStatus.Code().String())
                 switch grpcStatus.Code() {
-                case codes.Unavailable:
-                    fallthrough
-                case codes.Canceled:
-                    fallthrough
+                case codes.Unavailable, codes.Canceled:
+                    retriable = true
                 default:
                     retriable = true
                 }
diff --git a/task/executor.go b/task/executor.go
index c382daf..fbb5e74 100644
--- a/task/executor.go
+++ b/task/executor.go
@@ -91,6 +91,11 @@ func (te *taskExecutor) ExecuteOrchestrator(ctx context.Context, id api.Instance
     return results, nil
 }
 
+func (te taskExecutor) Shutdown(ctx context.Context) error {
+    // Nothing to do
+    return nil
+}
+
 func unmarshalData(data []byte, v any) error {
     if v == nil {
         return nil
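Aside: the worker loop above now classifies stream errors with errors.Is for io.EOF and status.FromError for gRPC status codes, and as written it treats every status code as retriable. A self-contained sketch of that classification follows, using only the standard grpc-go status and codes packages; the isRetriable helper is illustrative and not part of this change.

// Illustrative only: retry classification mirroring the worker loop's decision.
package main

import (
    "errors"
    "fmt"
    "io"

    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// isRetriable reports whether the receive loop should reconnect after err.
func isRetriable(err error) bool {
    if errors.Is(err, io.EOF) {
        return true // server closed the stream cleanly
    }
    if s, ok := status.FromError(err); ok {
        switch s.Code() {
        case codes.Unavailable, codes.Canceled:
            return true
        default:
            return true // the change above treats every other status code as retriable too
        }
    }
    return false
}

func main() {
    fmt.Println(isRetriable(io.EOF))                                  // true
    fmt.Println(isRetriable(status.Error(codes.Unavailable, "down"))) // true
}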
diff --git a/tests/grpc/grpc_test.go b/tests/grpc/grpc_test.go
index 70c4185..c0b698e 100644
--- a/tests/grpc/grpc_test.go
+++ b/tests/grpc/grpc_test.go
@@ -69,6 +69,13 @@ func TestMain(m *testing.M) {
 
     timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
     defer cancel()
+    err = grpcExecutor.Shutdown(timeoutCtx)
+    if err != nil {
+        log.Fatalf("failed to shutdown grpc Executor: %v", err)
+    }
+
+    timeoutCtx, cancel = context.WithTimeout(ctx, 5*time.Second)
+    defer cancel()
     if err := taskHubWorker.Shutdown(timeoutCtx); err != nil {
         log.Fatalf("failed to shutdown worker: %v", err)
     }
@@ -82,6 +89,63 @@ func startGrpcListener(t *testing.T, r *task.TaskRegistry) context.CancelFunc {
     return cancel
 }
 
+func Test_Grpc_WaitForInstanceStart_Timeout(t *testing.T) {
+    r := task.NewTaskRegistry()
+    r.AddOrchestratorN("WaitForInstanceStartThrowsException", func(ctx *task.OrchestrationContext) (any, error) {
+        // sleep 5 seconds
+        time.Sleep(5 * time.Second)
+        return 42, nil
+    })
+
+    cancelListener := startGrpcListener(t, r)
+    defer cancelListener()
+
+    id, err := grpcClient.ScheduleNewOrchestration(ctx, "WaitForInstanceStartThrowsException", api.WithInput("世界"))
+    require.NoError(t, err)
+    timeoutCtx, cancelTimeout := context.WithTimeout(ctx, time.Second)
+    defer cancelTimeout()
+    _, err = grpcClient.WaitForOrchestrationStart(timeoutCtx, id, api.WithFetchPayloads(true))
+    if assert.Error(t, err) {
+        assert.Contains(t, err.Error(), "context deadline exceeded")
+    }
+    time.Sleep(1 * time.Second)
+}
+
+func Test_Grpc_WaitForInstanceStart_ConnectionResume(t *testing.T) {
+    r := task.NewTaskRegistry()
+    r.AddOrchestratorN("WaitForInstanceStartThrowsException", func(ctx *task.OrchestrationContext) (any, error) {
+        // sleep 5 seconds
+        time.Sleep(5 * time.Second)
+        return 42, nil
+    })
+
+    cancelListener := startGrpcListener(t, r)
+
+    id, err := grpcClient.ScheduleNewOrchestration(ctx, "WaitForInstanceStartThrowsException", api.WithInput("世界"))
+    require.NoError(t, err)
+    timeoutCtx, cancelTimeout := context.WithTimeout(ctx, time.Second)
+    defer cancelTimeout()
+    _, err = grpcClient.WaitForOrchestrationStart(timeoutCtx, id, api.WithFetchPayloads(true))
+    if assert.Error(t, err) {
+        assert.Contains(t, err.Error(), "context deadline exceeded")
+    }
+    cancelListener()
+    time.Sleep(2 * time.Second)
+
+    // reconnect
+    cancelListener = startGrpcListener(t, r)
+    defer cancelListener()
+
+    // The work item should be retried and completed.
+    timeoutCtx, cancelTimeout = context.WithTimeout(ctx, 30*time.Second)
+    defer cancelTimeout()
+    metadata, err := grpcClient.WaitForOrchestrationCompletion(timeoutCtx, id, api.WithFetchPayloads(true))
+    require.NoError(t, err)
+    assert.Equal(t, true, metadata.IsComplete())
+    assert.Equal(t, "42", metadata.SerializedOutput)
+    time.Sleep(1 * time.Second)
+}
+
 func Test_Grpc_HelloOrchestration(t *testing.T) {
     r := task.NewTaskRegistry()
     r.AddOrchestratorN("SingleActivity", func(ctx *task.OrchestrationContext) (any, error) {