diff --git a/backend/activity.go b/backend/activity.go index 16fd2dd..86b1cdc 100644 --- a/backend/activity.go +++ b/backend/activity.go @@ -13,24 +13,55 @@ import ( ) type activityProcessor struct { - be Backend - executor ActivityExecutor + be Backend + executor ActivityExecutor + logger Logger + maxOutputSizeInBytes int +} + +type activityWorkerConfig struct { + workerOptions []NewTaskWorkerOptions + maxOutputSizeInBytes int +} + +// ActivityWorkerOption is a function that configures an activity worker. +type ActivityWorkerOption func(*activityWorkerConfig) + +// WithMaxConcurrentActivityInvocations sets the maximum number of concurrent activity invocations +// that the worker can process. If this limit is exceeded, the worker will block until the number of +// concurrent invocations drops below the limit. +func WithMaxConcurrentActivityInvocations(n int32) ActivityWorkerOption { + return func(o *activityWorkerConfig) { + o.workerOptions = append(o.workerOptions, WithMaxParallelism(n)) + } +} + +// WithMaxActivityOutputSizeInKB sets the maximum size of an activity's output. +// If an activity's output exceeds this size, the activity execution will fail with an error. +func WithMaxActivityOutputSizeInKB(n int) ActivityWorkerOption { + return func(o *activityWorkerConfig) { + o.maxOutputSizeInBytes = n * 1024 + } } type ActivityExecutor interface { ExecuteActivity(context.Context, api.InstanceID, *protos.HistoryEvent) (*protos.HistoryEvent, error) } -func NewActivityTaskWorker(be Backend, executor ActivityExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker { - processor := newActivityProcessor(be, executor) - return NewTaskWorker(be, processor, logger, opts...) -} +func NewActivityTaskWorker(be Backend, executor ActivityExecutor, logger Logger, opts ...ActivityWorkerOption) TaskWorker { + config := &activityWorkerConfig{} + for _, configure := range opts { + configure(config) + } -func newActivityProcessor(be Backend, executor ActivityExecutor) TaskProcessor { - return &activityProcessor{ - be: be, - executor: executor, + processor := &activityProcessor{ + be: be, + executor: executor, + logger: logger, + maxOutputSizeInBytes: config.maxOutputSizeInBytes, } + + return NewTaskWorker(be, processor, logger, config.workerOptions...) 
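Note (illustrative only, not part of this patch): the activity worker options introduced above are plain functional options. A minimal sketch of how a caller might wire them up, assuming be, executor, and logger are supplied by the host application and that the helper name is hypothetical:

package example

import "github.com/microsoft/durabletask-go/backend"

// newActivityWorker is a hypothetical helper showing how the new options compose.
// be, executor, and logger are assumed to be provided by the host application.
func newActivityWorker(be backend.Backend, executor backend.ActivityExecutor, logger backend.Logger) backend.TaskWorker {
	return backend.NewActivityTaskWorker(
		be,
		executor,
		logger,
		backend.WithMaxConcurrentActivityInvocations(16), // cap concurrent activity executions at 16
		backend.WithMaxActivityOutputSizeInKB(1024),      // fail any activity whose output exceeds ~1 MB
	)
}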
} // Name implements TaskProcessor @@ -77,22 +108,21 @@ func (p *activityProcessor) ProcessWorkItem(ctx context.Context, wi WorkItem) er } return err } - - awi.Result = result - return nil -} - -// CompleteWorkItem implements TaskDispatcher -func (ap *activityProcessor) CompleteWorkItem(ctx context.Context, wi WorkItem) error { - awi := wi.(*ActivityWorkItem) - if awi.Result == nil { + if result == nil { return fmt.Errorf("can't complete work item '%s' with nil result", wi.Description()) } - if awi.Result.GetTaskCompleted() == nil && awi.Result.GetTaskFailed() == nil { + if result.GetTaskCompleted() == nil && result.GetTaskFailed() == nil { return fmt.Errorf("can't complete work item '%s', which isn't TaskCompleted or TaskFailed", wi.Description()) } - return ap.be.CompleteActivityWorkItem(ctx, awi) + if p.maxOutputSizeInBytes > 0 && helpers.GetProtoSize(result) > p.maxOutputSizeInBytes { + err = fmt.Errorf("activity output size %d exceeds limit of %d bytes", helpers.GetProtoSize(result), p.maxOutputSizeInBytes) + awi.Result = helpers.NewTaskFailedEvent(awi.NewEvent.EventId, helpers.NewTaskFailureDetails(err)) + } else { + awi.Result = result + } + + return p.be.CompleteActivityWorkItem(ctx, awi) } // AbandonWorkItem implements TaskDispatcher diff --git a/backend/backend.go b/backend/backend.go index adcd7b0..e4f5ae3 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -5,9 +5,11 @@ import ( "errors" "fmt" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/wrapperspb" + "github.com/microsoft/durabletask-go/api" "github.com/microsoft/durabletask-go/internal/protos" - "google.golang.org/protobuf/proto" ) var ( @@ -23,6 +25,25 @@ type ( TaskFailureDetails = protos.TaskFailureDetails ) +type OrchestrationStateChanges struct { + NewEvents []*HistoryEvent + NewTasks []*HistoryEvent + NewTimers []*HistoryEvent + NewMessages []OrchestratorMessage + CustomStatus *wrapperspb.StringValue + RuntimeStatus protos.OrchestrationStatus + ContinuedAsNew bool + IsPartial bool + HistoryStartIndex int +} + +func (c *OrchestrationStateChanges) IsEmpty() bool { + return len(c.NewEvents) == 0 && + len(c.NewTasks) == 0 && + len(c.NewTimers) == 0 && + len(c.NewMessages) == 0 +} + type OrchestrationIdReusePolicyOptions func(*protos.OrchestrationIdReusePolicy) error func WithOrchestrationIdReusePolicy(policy *protos.OrchestrationIdReusePolicy) OrchestrationIdReusePolicyOptions { @@ -35,6 +56,7 @@ func WithOrchestrationIdReusePolicy(policy *protos.OrchestrationIdReusePolicy) O } } +// Backend is the interface that must be implemented by all task hub backends. type Backend interface { // CreateTaskHub creates a new task hub for the current backend. Task hub creation must be idempotent. // @@ -74,8 +96,23 @@ type Backend interface { // CompleteOrchestrationWorkItem completes a work item by saving the updated runtime state to durable storage. // + // The [OrchestrationStateChanges] parameter contains the changes to the orchestration state that should be + // saved to durable storage. The [HistoryStartIndex] field indicates the index of the first history event + // in the [OrchestrationStateChanges.NewEvents] slice. This is used to determine the index of the first + // history event in the [OrchestrationRuntimeState.History] slice, which is useful for backends that store + // the history events as an append log. 
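Note (illustrative only, not part of this patch): under the chunking logic added in this change, if a work item produces roughly 1,200 new history events and MaxHistoryEventCount is 500, the orchestration worker commits three chunks with HistoryStartIndex values of 0, 500, and 1000, and only the final chunk carries IsPartial == false. An append-log backend can then derive absolute sequence numbers as in the following sketch, which assumes the same convention the sqlite backend uses later in this diff:

package example

import "github.com/microsoft/durabletask-go/backend"

// historySequenceNumbers is a hypothetical helper that assigns absolute history sequence
// numbers to the events in one committed chunk. Events already persisted by earlier chunks
// of the same work item are accounted for by changes.HistoryStartIndex.
func historySequenceNumbers(wi *backend.OrchestrationWorkItem, changes backend.OrchestrationStateChanges) []int {
	next := len(wi.State.OldEvents()) + changes.HistoryStartIndex
	seqs := make([]int, 0, len(changes.NewEvents))
	for range changes.NewEvents {
		seqs = append(seqs, next)
		next++
	}
	return seqs
}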
+ // + // The [OrchestrationStateChanges.IsPartial] field indicates whether this is a partial completion operation, + // in which case more calls to this function are expected to follow with the same work item. Partial completion + // is used to commit state updates in chunks to avoid overly large transactions. The final chunk will be committed + // with [OrchestrationStateChanges.IsPartial] set to [false]. + // + // Implementations of this function should not attempt to delete the work item from storage until the final chunk + // is committed (i.e., until [OrchestrationStateChanges.IsPartial] is [false]) to ensure that the work item can be + // recovered if the process crashes before the final chunk is committed. + // // Returns [ErrWorkItemLockLost] if the work-item couldn't be completed due to a lock-lost conflict (e.g., split-brain). - CompleteOrchestrationWorkItem(context.Context, *OrchestrationWorkItem) error + CompleteOrchestrationWorkItem(context.Context, *OrchestrationWorkItem, OrchestrationStateChanges) error // AbandonOrchestrationWorkItem undoes any state changes and returns the work item to the work item queue. // diff --git a/backend/client.go b/backend/client.go index b38cfd8..371c590 100644 --- a/backend/client.go +++ b/backend/client.go @@ -73,7 +73,7 @@ func (c *backendClient) ScheduleNewOrchestration(ctx context.Context, orchestrat func (c *backendClient) FetchOrchestrationMetadata(ctx context.Context, id api.InstanceID) (*api.OrchestrationMetadata, error) { metadata, err := c.be.GetOrchestrationMetadata(ctx, id) if err != nil { - return nil, fmt.Errorf("Failed to fetch orchestration metadata: %w", err) + return nil, fmt.Errorf("failed to fetch orchestration metadata: %w", err) } return metadata, nil } diff --git a/backend/executor.go b/backend/executor.go index 7d7f573..66efa85 100644 --- a/backend/executor.go +++ b/backend/executor.go @@ -404,6 +404,8 @@ loop: } // mustEmbedUnimplementedTaskHubSidecarServiceServer implements protos.TaskHubSidecarServiceServer +// +//lint:ignore U1000 because this is a required gRPC method func (grpcExecutor) mustEmbedUnimplementedTaskHubSidecarServiceServer() { } diff --git a/backend/orchestration.go b/backend/orchestration.go index 37f7b3b..623175c 100644 --- a/backend/orchestration.go +++ b/backend/orchestration.go @@ -2,7 +2,10 @@ package backend import ( "context" + "errors" "fmt" + "strconv" + "strings" "time" "go.opentelemetry.io/otel/attribute" @@ -23,19 +26,57 @@ type OrchestratorExecutor interface { newEvents []*protos.HistoryEvent) (*ExecutionResults, error) } +// OrchestrationWorkerOption is a function that configures an orchestrator worker. +type OrchestrationWorkerOption func(*orchestrationWorkerConfig) + +type orchestrationWorkerConfig struct { + workerOptions []NewTaskWorkerOptions + chunkingConfig ChunkingConfiguration +} + +// WithChunkingConfiguration sets the chunking configuration for the orchestrator worker. +// Use this option to configure how the orchestrator worker will chunk large history lists. +func WithChunkingConfiguration(config ChunkingConfiguration) OrchestrationWorkerOption { + return func(p *orchestrationWorkerConfig) { + p.chunkingConfig = config + } +} + +// WithMaxConcurrentOrchestratorInvocations sets the maximum number of orchestrations that can be invoked +// concurrently by the orchestrator worker. If this limit is exceeded, the worker will block until the number +// of concurrent orchestrations drops below the limit. 
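Note (illustrative only, not part of this patch): WithChunkingConfiguration and WithMaxConcurrentOrchestratorInvocations might be wired up as in this sketch; be, executor, and logger are assumed to come from the host application, and the limits shown are arbitrary examples:

package example

import "github.com/microsoft/durabletask-go/backend"

// newOrchestrationWorker is a hypothetical helper showing how the new orchestration
// worker options compose; the values passed to the options are examples only.
func newOrchestrationWorker(be backend.Backend, executor backend.OrchestratorExecutor, logger backend.Logger) backend.TaskWorker {
	return backend.NewOrchestrationWorker(
		be,
		executor,
		logger,
		backend.WithMaxConcurrentOrchestratorInvocations(8), // invoke at most 8 orchestrators at a time
		backend.WithChunkingConfiguration(backend.ChunkingConfiguration{
			MaxHistoryEventCount:    500,  // commit at most 500 history events per chunk
			MaxHistoryEventSizeInKB: 2048, // keep each committed chunk under ~2 MB
		}),
	)
}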
+// +// Note that this limit is applied to the number of orchestrations that are being invoked concurrently, not +// the number of orchestrations that are in a running state. For example, if an orchestration is waiting for +// an external event, a timer, or an activity, it will not count against this limit. +// +// If this value is set to 0 or less, then the number of parallel orchestrations is unlimited. +func WithMaxConcurrentOrchestratorInvocations(n int32) OrchestrationWorkerOption { + return func(p *orchestrationWorkerConfig) { + p.workerOptions = append(p.workerOptions, WithMaxParallelism(n)) + } +} + type orchestratorProcessor struct { - be Backend - executor OrchestratorExecutor - logger Logger + be Backend + executor OrchestratorExecutor + logger Logger + chunkingConfig ChunkingConfiguration } -func NewOrchestrationWorker(be Backend, executor OrchestratorExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker { +func NewOrchestrationWorker(be Backend, executor OrchestratorExecutor, logger Logger, opts ...OrchestrationWorkerOption) TaskWorker { + config := &orchestrationWorkerConfig{} + for _, configure := range opts { + configure(config) + } + processor := &orchestratorProcessor{ - be: be, - executor: executor, - logger: logger, + be: be, + executor: executor, + logger: logger, + chunkingConfig: config.chunkingConfig, } - return NewTaskWorker(be, processor, logger, opts...) + return NewTaskWorker(be, processor, logger, config.workerOptions...) } // Name implements TaskProcessor @@ -72,11 +113,12 @@ func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, cwi WorkIte w.endOrchestratorSpan(ctx, wi, span, false) }() + continueAsNewLoop: for continueAsNewCount := 0; ; continueAsNewCount++ { if continueAsNewCount > 0 { w.logger.Debugf("%v: continuing-as-new with %d event(s): %s", wi.InstanceID, len(wi.State.NewEvents()), helpers.HistoryListSummary(wi.State.NewEvents())) } else { - w.logger.Debugf("%v: invoking orchestrator", wi.InstanceID) + w.logger.Debugf("%v: invoking orchestrator with %d event(s): %s", wi.InstanceID, len(wi.State.NewEvents()), helpers.HistoryListSummary(wi.State.NewEvents())) } // Run the user orchestrator code, providing the old history and new events together. @@ -86,41 +128,87 @@ func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, cwi WorkIte } w.logger.Debugf("%v: orchestrator returned %d action(s): %s", wi.InstanceID, len(results.Response.Actions), helpers.ActionListSummary(results.Response.Actions)) - // Apply the orchestrator outputs to the orchestration state. - continuedAsNew, err := wi.State.ApplyActions(results.Response.Actions, helpers.TraceContextFromSpan(span)) - if err != nil { - return fmt.Errorf("failed to apply the execution result actions: %w", err) - } - wi.State.CustomStatus = results.Response.CustomStatus - - // When continuing-as-new, we re-execute the orchestrator from the beginning with a truncated state in a tight loop - // until the orchestrator performs some non-continue-as-new action. - if continuedAsNew { - const MaxContinueAsNewCount = 20 - if continueAsNewCount >= MaxContinueAsNewCount { - return fmt.Errorf("exceeded tight-loop continue-as-new limit of %d iterations", MaxContinueAsNewCount) + // Apply the results of the orchestrator execution to the orchestration state. + wi.State.AddActions(results.Response.Actions) + + // Now we need to take all the changes we just made to the state and persist them to the backend storage. 
+ // This is done in a loop because the list of actions may be too large to be committed in a single transaction. + // If this happens, we'll commit the changes in chunks until we've committed all of them. + addedEvents := 0 + for { + tc := helpers.TraceContextFromSpan(span) + changes, err := wi.State.ProcessChanges(w.chunkingConfig, tc, w.logger) + if errors.Is(err, ErrContinuedAsNew) { + // The orchestrator did a continue-as-new, which means we should re-run the orchestrator with the new state. + // No changes are committed to the backend until the orchestration returns a non-continue-as-new result. + // Safety check: see if the user code might be an infinite continue-as-new loop. 10K is the arbitrary threshold we use. + const MaxContinueAsNewCount = 10000 + if continueAsNewCount >= MaxContinueAsNewCount { + // Fail the orchestration since we don't want it to be stuck in an infinite loop + continueAsNewError := fmt.Errorf("exceeded tight-loop continue-as-new limit of %d iterations", MaxContinueAsNewCount) + w.logger.Warnf("%v: terminating orchestration: %v", wi.InstanceID, continueAsNewError) + return w.failOrchestration(ctx, wi, continueAsNewError, tc) + } + + // We create a new trace span for every continue-as-new + w.endOrchestratorSpan(ctx, wi, span, true) + ctx, span = w.startOrResumeOrchestratorSpan(ctx, wi) + continue continueAsNewLoop + } else if err != nil { + // Any other error is assumed to be non-recoverable, so we fail the orchestration + return w.failOrchestration(ctx, wi, err, tc) } - // We create a new trace span for every continue-as-new - w.endOrchestratorSpan(ctx, wi, span, true) - ctx, span = w.startOrResumeOrchestratorSpan(ctx, wi) - continue - } + if !changes.IsEmpty() { + // Commit the changes to the backend + if len(changes.NewEvents) > 0 { + w.logger.Debugf("%v: committing %d new history event(s) to the backend (partial=%v): %v", wi.InstanceID, len(changes.NewEvents), changes.IsPartial, helpers.HistoryListSummary(changes.NewEvents)) + } + if len(changes.NewTasks) > 0 { + w.logger.Debugf("%v: committing %d new scheduled task(s) to the backend (partial=%v): %v", wi.InstanceID, len(changes.NewTasks), changes.IsPartial, helpers.HistoryListSummary(changes.NewTasks)) + } + if len(changes.NewTimers) > 0 { + w.logger.Debugf("%v: committing %d new timer(s) to the backend (partial=%v): %v", wi.InstanceID, len(changes.NewTimers), changes.IsPartial, helpers.HistoryListSummary(changes.NewTimers)) + } + if len(changes.NewMessages) > 0 { + w.logger.Debugf("%v: committing %d new message(s) to the backend (partial=%v): %v", wi.InstanceID, len(changes.NewMessages), changes.IsPartial, messageListSummary(changes.NewMessages)) + } + + changes.HistoryStartIndex = addedEvents + if err := w.be.CompleteOrchestrationWorkItem(ctx, wi, changes); err != nil { + return fmt.Errorf("failed to complete orchestration work item: %w", err) + } + + addedEvents += len(changes.NewEvents) + + // Keep looping until we've committed all the changes + continue + } - if wi.State.IsCompleted() { - name, _ := wi.State.Name() - w.logger.Infof("%v: '%s' completed with a %s status.", wi.InstanceID, name, helpers.ToRuntimeStatusString(wi.State.RuntimeStatus())) + if wi.State.IsCompleted() { + name, _ := wi.State.Name() + w.logger.Infof("%v: '%s' completed with a %s status.", wi.InstanceID, name, helpers.ToRuntimeStatusString(wi.State.RuntimeStatus())) + } + + break continueAsNewLoop // break out of the process results loop if no errors } - break } } } return nil } -// CompleteWorkItem implements TaskProcessor -func (p
*orchestratorProcessor) CompleteWorkItem(ctx context.Context, wi WorkItem) error { - owi := wi.(*OrchestrationWorkItem) - return p.be.CompleteOrchestrationWorkItem(ctx, owi) +func (p *orchestratorProcessor) failOrchestration(ctx context.Context, wi *OrchestrationWorkItem, err error, tc *protos.TraceContext) error { + p.logger.Warnf("%v: setting orchestration as failed: %v", wi.InstanceID, err) + wi.State.SetFailed(err) + + changes, err := wi.State.ProcessChanges(p.chunkingConfig, tc, p.logger) + if err != nil { + // This is assumed to be non-recoverable, so we swallow it and log a message + p.logger.Errorf("%v: failed to fail orchestration: %v", wi.InstanceID, err) + return nil + } + + return p.be.CompleteOrchestrationWorkItem(ctx, wi, changes) } // AbandonWorkItem implements TaskProcessor @@ -129,13 +217,21 @@ func (p *orchestratorProcessor) AbandonWorkItem(ctx context.Context, wi WorkItem return p.be.AbandonOrchestrationWorkItem(ctx, owi) } +// applyWorkItem adds the new events from the work item to the orchestration state. +// +// The returned context will contain a new distributed tracing span that should be used for all +// subsequent operations. The returned span will be nil if the work item was dropped. +// +// The returned boolean will be false if the work item was dropped. +// +// The caller is responsible for calling endOrchestratorSpan on the returned span. func (w *orchestratorProcessor) applyWorkItem(ctx context.Context, wi *OrchestrationWorkItem) (context.Context, trace.Span, bool) { // Ignore work items for orchestrations that are completed or are in a corrupted state. if !wi.State.IsValid() { w.logger.Warnf("%v: orchestration state is invalid; dropping work item", wi.InstanceID) return nil, nil, false } else if wi.State.IsCompleted() { - w.logger.Warnf("%v: orchestration already completed; dropping work item", wi.InstanceID) + w.logger.Infof("%v: dropping work item(s) for %s orchestration: %s", wi.InstanceID, helpers.ToRuntimeStatusString(wi.State.RuntimeStatus()), helpers.HistoryListSummary(wi.NewEvents)) return nil, nil, false } else if len(wi.NewEvents) == 0 { w.logger.Warnf("%v: the work item had no events!", wi.InstanceID) @@ -307,3 +403,26 @@ func addNotableEventsToSpan(events []*protos.HistoryEvent, span trace.Span) { } } } + +func messageListSummary(messages []OrchestratorMessage) string { + var sb strings.Builder + sb.WriteString("[") + for i, m := range messages { + if i > 0 { + sb.WriteString(", ") + } + if i >= 10 { + sb.WriteString("...") + break + } + name := helpers.GetHistoryEventTypeName(m.HistoryEvent) + sb.WriteString(name) + taskID := helpers.GetTaskId(m.HistoryEvent) + if taskID >= 0 { + sb.WriteRune('#') + sb.WriteString(strconv.FormatInt(int64(taskID), 10)) + } + } + sb.WriteString("]") + return sb.String() +} diff --git a/backend/runtimestate.go b/backend/runtimestate.go index ac23421..c4bd663 100644 --- a/backend/runtimestate.go +++ b/backend/runtimestate.go @@ -5,7 +5,6 @@ import ( "fmt" "time" - "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" "github.com/microsoft/durabletask-go/api" @@ -13,23 +12,27 @@ import ( "github.com/microsoft/durabletask-go/internal/protos" ) -var ErrDuplicateEvent = errors.New("duplicate event") +var ( + ErrDuplicateEvent = errors.New("duplicate event") + ErrContinuedAsNew = errors.New("orchestration continued-as-new") +) type OrchestrationRuntimeState struct { - instanceID api.InstanceID - newEvents []*protos.HistoryEvent - oldEvents []*protos.HistoryEvent - 
pendingTasks []*protos.HistoryEvent - pendingTimers []*protos.HistoryEvent - pendingMessages []OrchestratorMessage - - startEvent *protos.ExecutionStartedEvent - completedEvent *protos.ExecutionCompletedEvent - createdTime time.Time - lastUpdatedTime time.Time - completedTime time.Time - continuedAsNew bool - isSuspended bool + instanceID api.InstanceID + oldEvents []*protos.HistoryEvent + pendingEvents []*protos.HistoryEvent + newEvents []*protos.HistoryEvent + newActions []*protos.OrchestratorAction + + startEvent *protos.ExecutionStartedEvent + completedEvent *protos.ExecutionCompletedEvent + createdTime time.Time + lastUpdatedTime time.Time + completedTime time.Time + continuedAsNew bool + isSuspended bool + startedTaskIDs map[int32]bool + completedTaskIDs map[int32]bool CustomStatus *wrapperspb.StringValue } @@ -39,11 +42,27 @@ type OrchestratorMessage struct { TargetInstanceID string } +// ChunkingConfiguration specifies the maximum size of a single chunk of history events. +// See https://github.com/microsoft/durabletask-go/issues/44 for more details. +type ChunkingConfiguration struct { + // MaxHistoryEventCount is the maximum number of history events that can be stored in a single chunk. + // If the number of history events exceeds this value, the history events will be saved in multiple chunks that are each less than or equal to this value. + // A value of 0 or less means that there is no limit. + MaxHistoryEventCount int + // MaxHistoryEventSizeInKB is the maximum size of a single chunk in kilobytes. + // For example, a max size of 2MB would be specified as 2048. + // If the aggregate size of a batch of history events exceeds this value, the history events will be saved in multiple chunks that are each less than or equal to this size. + // If the size of a single history event exceeds this value, the orchestration will fail. + // A value of 0 or less means that there is no limit. + MaxHistoryEventSizeInKB int +} + func NewOrchestrationRuntimeState(instanceID api.InstanceID, existingHistory []*HistoryEvent) *OrchestrationRuntimeState { s := &OrchestrationRuntimeState{ - instanceID: instanceID, - oldEvents: make([]*HistoryEvent, 0, len(existingHistory)), - newEvents: make([]*HistoryEvent, 0, 10), + instanceID: instanceID, + oldEvents: make([]*HistoryEvent, 0, len(existingHistory)+10), + startedTaskIDs: make(map[int32]bool), + completedTaskIDs: make(map[int32]bool), } for _, e := range existingHistory { @@ -54,6 +73,8 @@ func NewOrchestrationRuntimeState(instanceID api.InstanceID, existingHistory []* } // AddEvent appends a new history event to the orchestration history +// +// Returns [ErrDuplicateEvent] if the event is known to be a duplicate. func (s *OrchestrationRuntimeState) AddEvent(e *HistoryEvent) error { return s.addEvent(e, true) } @@ -75,11 +96,27 @@ func (s *OrchestrationRuntimeState) addEvent(e *HistoryEvent, isNew bool) error s.isSuspended = true } else if e.GetExecutionResumed() != nil { s.isSuspended = false - } else { - // TODO: Check for other possible duplicates using task IDs + } else if e.GetTaskScheduled() != nil || e.GetSubOrchestrationInstanceCreated() != nil || e.GetTimerCreated() != nil { + // Filter out duplicate task started events. This is never expected unless there is a bug + // in a Durable Task SDK library. We filter here to prevent state store insert problems. 
+ if _, exists := s.startedTaskIDs[e.EventId]; exists { + return ErrDuplicateEvent + } + s.startedTaskIDs[e.EventId] = true + } else if e.GetTaskCompleted() != nil || e.GetTaskFailed() != nil || + e.GetSubOrchestrationInstanceCompleted() != nil || e.GetSubOrchestrationInstanceFailed() != nil || + e.GetTimerFired() != nil { + // Filter out duplicate task completed events. This can happen in failure recovery cases where + // completion events get played into the history multiple times. + taskID := helpers.GetTaskId(e) + if _, exists := s.completedTaskIDs[taskID]; exists { + return ErrDuplicateEvent + } + s.completedTaskIDs[taskID] = true } if isNew { + s.pendingEvents = append(s.pendingEvents, e) s.newEvents = append(s.newEvents, e) } else { s.oldEvents = append(s.oldEvents, e) @@ -90,7 +127,7 @@ func (s *OrchestrationRuntimeState) addEvent(e *HistoryEvent, isNew bool) error } func (s *OrchestrationRuntimeState) IsValid() bool { - if len(s.oldEvents) == 0 && len(s.newEvents) == 0 { + if len(s.oldEvents) == 0 && len(s.pendingEvents) == 0 && len(s.newEvents) == 0 { // empty orchestration state return true } else if s.startEvent != nil { @@ -100,65 +137,168 @@ func (s *OrchestrationRuntimeState) IsValid() bool { return false } -// ApplyActions takes a set of actions and updates its internal state, including populating the outbox. -func (s *OrchestrationRuntimeState) ApplyActions(actions []*protos.OrchestratorAction, currentTraceContext *protos.TraceContext) (bool, error) { - for _, action := range actions { +func (s *OrchestrationRuntimeState) AddActions(actions []*protos.OrchestratorAction) { + s.newActions = append(s.newActions, actions...) +} + +// ProcessChanges processes all the changes that were added to the orchestration state and returns the changes that were applied +// based on the specified chunking configuration. ProcessChanges should be called continuously until it returns a non-partial +// result. If the result is partial, the caller should commit the changes and then call ProcessChanges again with the same +// state object to process any remaining changes. +// +// Returns [ErrContinuedAsNew] if the orchestration has continued-as-new. +func (s *OrchestrationRuntimeState) ProcessChanges(c ChunkingConfiguration, currentTraceContext *protos.TraceContext, log Logger) (OrchestrationStateChanges, error) { + var currentChunkLength int + var currentChunkSizeInBytes int + var changes OrchestrationStateChanges + + resetChanges := func(continuedAsNew bool) { + currentChunkLength = 0 + currentChunkSizeInBytes = 0 + changes = OrchestrationStateChanges{ + IsPartial: true, + ContinuedAsNew: s.continuedAsNew || continuedAsNew, + CustomStatus: s.CustomStatus, + RuntimeStatus: s.RuntimeStatus(), + } + } + resetChanges(false) + + // verifyAndAddPayloadSize is a helper function that processes a single history event and returns true if it should be + // included in the current chunk, false if it should be excluded, or an error if the item cannot be processed. We need + // to run this function for each item in the list of new events, new messages, new timers, and new tasks. 
+ verifyAndAddPayloadSize := func(e *HistoryEvent) (bool, error) { + if c.MaxHistoryEventSizeInKB > 0 { + eventSize := helpers.GetProtoSize(e) + maxChunkSizeInBytes := c.MaxHistoryEventSizeInKB * 1024 + if eventSize > maxChunkSizeInBytes { + // This is a fatal error; we can't split a single event into multiple chunks + return false, fmt.Errorf("orchestration event size of %d bytes exceeds the maximum allowable size of %d bytes", eventSize, maxChunkSizeInBytes) + } + + currentChunkSizeInBytes += eventSize + if currentChunkSizeInBytes > maxChunkSizeInBytes { + // Can't fit this event into the current chunk; return the changes that were applied so far + return false, nil + } + } + + if c.MaxHistoryEventCount > 0 { + currentChunkLength += 1 + if currentChunkLength > c.MaxHistoryEventCount { + // Can't fit any more events into the current chunk; return the changes that were applied so far + return false, nil + } + } + + return true, nil + } + + // Process all the inbox events that were added to the state. + for len(s.pendingEvents) > 0 { + e := s.pendingEvents[0] + if ok, err := verifyAndAddPayloadSize(e); err != nil { + return OrchestrationStateChanges{}, err + } else if !ok { + return changes, nil + } + changes.NewEvents = append(changes.NewEvents, e) + s.pendingEvents = s.pendingEvents[1:] + } + + // Process all the orchestrator actions that were added to the state. + for len(s.newActions) > 0 { + action := s.newActions[0] if completedAction := action.GetCompleteOrchestration(); completedAction != nil { + // Continue-as-new requires us to reset all changes and start over with a new state object. if completedAction.OrchestrationStatus == protos.OrchestrationStatus_ORCHESTRATION_STATUS_CONTINUED_AS_NEW { - newState := NewOrchestrationRuntimeState(s.instanceID, []*protos.HistoryEvent{}) - newState.continuedAsNew = true - newState.AddEvent(helpers.NewOrchestratorStartedEvent()) - - // Duplicate the start event info, updating just the input - newState.AddEvent( - helpers.NewExecutionStartedEvent( - s.startEvent.Name, - string(s.instanceID), - completedAction.Result, - s.startEvent.ParentInstance, - s.startEvent.ParentTraceContext, - ), + resetChanges(true) + + orchestratorStartedEvent := helpers.NewOrchestratorStartedEvent() + verifyAndAddPayloadSize(orchestratorStartedEvent) + + // Create a new start event based on the old start event, updating just the input + executionStartedEvent := helpers.NewExecutionStartedEvent( + s.startEvent.Name, + string(s.instanceID), + completedAction.Result, + s.startEvent.ParentInstance, + s.startEvent.ParentTraceContext, ) + if ok, err := verifyAndAddPayloadSize(executionStartedEvent); err != nil { + return OrchestrationStateChanges{}, err + } else if !ok { + return OrchestrationStateChanges{}, fmt.Errorf("unable to fit both orchestratorStarted and executionStarted events into a single chunk") + } - // Unprocessed "carryover" events + // Replace the current state with a new state. All unprocessed actions that came after this continue-as-new will be lost. 
+ newState := NewOrchestrationRuntimeState(s.instanceID, []*protos.HistoryEvent{}) + newState.AddEvent(orchestratorStartedEvent) + newState.AddEvent(executionStartedEvent) + newState.continuedAsNew = true for _, e := range completedAction.CarryoverEvents { - newState.AddEvent(e) + if err := newState.AddEvent(e); err != nil { + return OrchestrationStateChanges{}, err + } + } + if len(s.newActions) > 1 { + log.Warnf("%v: Discarding %d orchestrator action(s) because they were scheduled after the orchestration continued-as-new", s.instanceID, len(s.newActions)-1) } - - // Overwrite the current state object with a new one *s = *newState - // ignore all remaining actions - return true, nil + // Return ErrContinuedAsNew to indicate that the caller should start a new orchestrator invocation + // with the updated state. + return OrchestrationStateChanges{}, ErrContinuedAsNew } else { - s.AddEvent(helpers.NewExecutionCompletedEvent(action.Id, completedAction.OrchestrationStatus, completedAction.Result, completedAction.FailureDetails)) + changes.RuntimeStatus = completedAction.OrchestrationStatus + changes.IsPartial = false + + completedEvent := helpers.NewExecutionCompletedEvent(action.Id, completedAction.OrchestrationStatus, completedAction.Result, completedAction.FailureDetails) + if ok, err := verifyAndAddPayloadSize(completedEvent); err != nil { + return OrchestrationStateChanges{}, err + } else if !ok { + return changes, nil + } + changes.NewEvents = append(changes.NewEvents, completedEvent) if s.startEvent.GetParentInstance() != nil { msg := OrchestratorMessage{ - HistoryEvent: &protos.HistoryEvent{EventId: -1, Timestamp: timestamppb.Now()}, TargetInstanceID: s.startEvent.GetParentInstance().OrchestrationInstance.GetInstanceId(), } if completedAction.OrchestrationStatus == protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED { - msg.HistoryEvent.EventType = &protos.HistoryEvent_SubOrchestrationInstanceCompleted{ - SubOrchestrationInstanceCompleted: &protos.SubOrchestrationInstanceCompletedEvent{ - TaskScheduledId: s.startEvent.ParentInstance.TaskScheduledId, - Result: completedAction.Result, - }, - } + msg.HistoryEvent = helpers.NewSubOrchestrationCompletedEvent( + s.startEvent.ParentInstance.TaskScheduledId, + completedAction.Result, + ) } else { - // TODO: What is the expected result for termination? 
- msg.HistoryEvent.EventType = &protos.HistoryEvent_SubOrchestrationInstanceFailed{ - SubOrchestrationInstanceFailed: &protos.SubOrchestrationInstanceFailedEvent{ - TaskScheduledId: s.startEvent.ParentInstance.TaskScheduledId, - FailureDetails: completedAction.FailureDetails, - }, - } + msg.HistoryEvent = helpers.NewSubOrchestrationFailedEvent( + s.startEvent.ParentInstance.TaskScheduledId, + completedAction.FailureDetails, + ) } - s.pendingMessages = append(s.pendingMessages, msg) + changes.NewMessages = append(changes.NewMessages, msg) } + + s.newEvents = append(s.newEvents, completedEvent) + s.completedEvent = completedEvent.GetExecutionCompleted() + s.completedTime = completedEvent.Timestamp.AsTime() } } else if createtimer := action.GetCreateTimer(); createtimer != nil { - s.AddEvent(helpers.NewTimerCreatedEvent(action.Id, createtimer.FireAt)) - s.pendingTimers = append(s.pendingTimers, helpers.NewTimerFiredEvent(action.Id, createtimer.FireAt, currentTraceContext)) + timerCreatedEvent := helpers.NewTimerCreatedEvent(action.Id, createtimer.FireAt) + if _, ok := s.startedTaskIDs[timerCreatedEvent.EventId]; ok { + log.Debugf("%v: Ignoring duplicate timer created event: %v", s.instanceID, timerCreatedEvent) + } else if s.IsCompleted() { + log.Warnf("%v: Dropping timer creation action because the orchestration is %s", s.instanceID, helpers.ToRuntimeStatusString(s.RuntimeStatus())) + } else { + if ok, err := verifyAndAddPayloadSize(timerCreatedEvent); err != nil { + return OrchestrationStateChanges{}, err + } else if !ok { + return changes, nil + } + + s.newEvents = append(s.newEvents, timerCreatedEvent) + changes.NewEvents = append(changes.NewEvents, timerCreatedEvent) + changes.NewTimers = append(changes.NewTimers, helpers.NewTimerFiredEvent(action.Id, createtimer.FireAt)) + } } else if scheduleTask := action.GetScheduleTask(); scheduleTask != nil { scheduledEvent := helpers.NewTaskScheduledEvent( action.Id, @@ -167,46 +307,105 @@ func (s *OrchestrationRuntimeState) ApplyActions(actions []*protos.OrchestratorA scheduleTask.Input, currentTraceContext, ) - s.AddEvent(scheduledEvent) - s.pendingTasks = append(s.pendingTasks, scheduledEvent) + if _, ok := s.startedTaskIDs[scheduledEvent.EventId]; ok { + log.Debugf("%v: Ignoring duplicate task scheduled event: %v", s.instanceID, scheduledEvent) + } else if s.IsCompleted() { + log.Warnf("%v: Dropping task schedule action because the orchestration is %s", s.instanceID, helpers.ToRuntimeStatusString(s.RuntimeStatus())) + } else { + if ok, err := verifyAndAddPayloadSize(scheduledEvent); err != nil { + return OrchestrationStateChanges{}, err + } else if !ok { + return changes, nil + } + + s.newEvents = append(s.newEvents, scheduledEvent) + changes.NewEvents = append(changes.NewEvents, scheduledEvent) + changes.NewTasks = append(changes.NewTasks, scheduledEvent) + } } else if createSO := action.GetCreateSubOrchestration(); createSO != nil { // Autogenerate an instance ID for the sub-orchestration if none is provided, using a // deterministic algorithm based on the parent instance ID to help enable de-duplication. 
if createSO.InstanceId == "" { createSO.InstanceId = fmt.Sprintf("%s:%04x", s.instanceID, action.Id) } - s.AddEvent(helpers.NewSubOrchestrationCreatedEvent( + createdEvent := helpers.NewSubOrchestrationCreatedEvent( action.Id, createSO.Name, createSO.Version, createSO.Input, createSO.InstanceId, - currentTraceContext)) - startEvent := helpers.NewExecutionStartedEvent( - createSO.Name, - createSO.InstanceId, - createSO.Input, - helpers.NewParentInfo(action.Id, s.startEvent.Name, string(s.instanceID)), currentTraceContext, ) - s.pendingMessages = append(s.pendingMessages, OrchestratorMessage{HistoryEvent: startEvent, TargetInstanceID: createSO.InstanceId}) + if _, ok := s.startedTaskIDs[createdEvent.EventId]; ok { + log.Debugf("%v: Ignoring duplicate sub-orchestration created event: %v", s.instanceID, createdEvent) + } else if s.IsCompleted() { + log.Warnf("%v: Dropping sub-orchestration creation action because the orchestration is %s", s.instanceID, helpers.ToRuntimeStatusString(s.RuntimeStatus())) + } else { + if ok, err := verifyAndAddPayloadSize(createdEvent); err != nil { + return OrchestrationStateChanges{}, err + } else if !ok { + return changes, nil + } + + s.newEvents = append(s.newEvents, createdEvent) + changes.NewEvents = append(changes.NewEvents, createdEvent) + startEvent := helpers.NewExecutionStartedEvent( + createSO.Name, + createSO.InstanceId, + createSO.Input, + helpers.NewParentInfo(action.Id, s.startEvent.Name, string(s.instanceID)), + currentTraceContext, + ) + changes.NewMessages = append(changes.NewMessages, OrchestratorMessage{HistoryEvent: startEvent, TargetInstanceID: createSO.InstanceId}) + } } else if sendEvent := action.GetSendEvent(); sendEvent != nil { e := helpers.NewSendEventEvent(action.Id, sendEvent.Instance.InstanceId, sendEvent.Name, sendEvent.Data) - s.AddEvent(e) - s.pendingMessages = append(s.pendingMessages, OrchestratorMessage{HistoryEvent: e, TargetInstanceID: sendEvent.Instance.InstanceId}) + if ok, err := verifyAndAddPayloadSize(e); err != nil { + return OrchestrationStateChanges{}, err + } else if !ok { + return changes, nil + } + + s.newEvents = append(s.newEvents, e) + changes.NewEvents = append(changes.NewEvents, e) + changes.NewMessages = append(changes.NewMessages, OrchestratorMessage{HistoryEvent: e, TargetInstanceID: sendEvent.Instance.InstanceId}) } else if terminate := action.GetTerminateOrchestration(); terminate != nil { // Send a message to terminate the target orchestration - msg := OrchestratorMessage{ - TargetInstanceID: terminate.InstanceId, - HistoryEvent: helpers.NewExecutionTerminatedEvent(terminate.Reason, terminate.Recurse), - } - s.pendingMessages = append(s.pendingMessages, msg) + e := helpers.NewExecutionTerminatedEvent(terminate.Reason, terminate.Recurse) + msg := OrchestratorMessage{TargetInstanceID: terminate.InstanceId, HistoryEvent: e} + s.newEvents = append(s.newEvents, e) + changes.NewMessages = append(changes.NewMessages, msg) } else { - return false, fmt.Errorf("unknown action type: %v", action) + return OrchestrationStateChanges{}, fmt.Errorf("unknown action type: %v", action) } + s.newActions = s.newActions[1:] } - return false, nil + // All changes were applied, so we set the IsPartial flag to false. + changes.IsPartial = false + return changes, nil +} + +// SetFailed adds a failure action to the orchestration state and removes all other pending actions. +func (s *OrchestrationRuntimeState) SetFailed(err error) { + // Clear the list of pending events since we don't care about these anymore. 
+ s.pendingEvents = nil + + // Add a fake "execution started" event if one doesn't already exist. + if s.startEvent == nil { + s.newEvents = append(s.newEvents, helpers.NewExecutionStartedEvent( + "(Unknown)", + string(s.instanceID), + )) + + // Apply an "orchestration failed" action to the current state. + s.newActions = []*protos.OrchestratorAction{helpers.NewCompleteOrchestrationAction( + -1, + protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, + nil, + nil, + helpers.NewTaskFailureDetails(err), + )} } func (s *OrchestrationRuntimeState) InstanceID() api.InstanceID { @@ -297,18 +496,6 @@ func (s *OrchestrationRuntimeState) FailureDetails() (*TaskFailureDetails, error return s.completedEvent.FailureDetails, nil } -func (s *OrchestrationRuntimeState) PendingTimers() []*HistoryEvent { - return s.pendingTimers -} - -func (s *OrchestrationRuntimeState) PendingTasks() []*HistoryEvent { - return s.pendingTasks -} - -func (s *OrchestrationRuntimeState) PendingMessages() []OrchestratorMessage { - return s.pendingMessages -} - func (s *OrchestrationRuntimeState) ContinuedAsNew() bool { return s.continuedAsNew } diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index 3ce4dbc..43cdacb 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -189,7 +189,7 @@ func (be *sqliteBackend) AbandonOrchestrationWorkItem(ctx context.Context, wi *b } // CompleteOrchestrationWorkItem implements backend.Backend -func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi *backend.OrchestrationWorkItem) error { +func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi *backend.OrchestrationWorkItem, changes backend.OrchestrationStateChanges) error { if err := be.ensureDB(); err != nil { return err } @@ -210,7 +210,7 @@ func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi * isCreated := false isCompleted := false - for _, e := range wi.State.NewEvents() { + for _, e := range changes.NewEvents { if es := e.GetExecutionStarted(); es != nil { if isCreated { // TODO: Log warning about duplicate start event @@ -239,19 +239,19 @@ func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi * sqlUpdateArgs = append(sqlUpdateArgs, nil) } } - // TODO: Execution suspended & resumed } - if wi.State.CustomStatus != nil { + if changes.CustomStatus != nil { sqlSB.WriteString("[CustomStatus] = ?, ") - sqlUpdateArgs = append(sqlUpdateArgs, wi.State.CustomStatus.Value) + sqlUpdateArgs = append(sqlUpdateArgs, changes.CustomStatus.Value) } // TODO: Support for stickiness, which would extend the LockExpiration sqlSB.WriteString("[RuntimeStatus] = ?, [LastUpdatedTime] = ?, [LockExpiration] = NULL WHERE [InstanceID] = ? AND [LockedBy] = ?") - sqlUpdateArgs = append(sqlUpdateArgs, helpers.ToRuntimeStatusString(wi.State.RuntimeStatus()), now, string(wi.InstanceID), wi.LockedBy) + sqlUpdateArgs = append(sqlUpdateArgs, helpers.ToRuntimeStatusString(changes.RuntimeStatus), now, string(wi.InstanceID), wi.LockedBy) - result, err := tx.ExecContext(ctx, sqlSB.String(), sqlUpdateArgs...) + updateSql := sqlSB.String() + result, err := tx.ExecContext(ctx, updateSql, sqlUpdateArgs...) 
if err != nil { return fmt.Errorf("failed to update Instances table: %w", err) } @@ -264,21 +264,23 @@ func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi * } // If continue-as-new, delete all existing history - if wi.State.ContinuedAsNew() { + if changes.ContinuedAsNew { + be.logger.Debugf("%v: Deleting all existing history events as part of continue-as-new", wi.InstanceID) if _, err := tx.ExecContext(ctx, "DELETE FROM History WHERE InstanceID = ?", string(wi.InstanceID)); err != nil { return fmt.Errorf("failed to delete from History table: %w", err) } } // Save new history events - newHistoryCount := len(wi.State.NewEvents()) + newHistoryCount := len(changes.NewEvents) if newHistoryCount > 0 { query := "INSERT INTO History ([InstanceID], [SequenceNumber], [EventPayload]) VALUES (?, ?, ?)" + strings.Repeat(", (?, ?, ?)", newHistoryCount-1) args := make([]interface{}, 0, newHistoryCount*3) - nextSequenceNumber := len(wi.State.OldEvents()) - for _, e := range wi.State.NewEvents() { + startIndex := len(wi.State.OldEvents()) + changes.HistoryStartIndex + nextSequenceNumber := startIndex + for _, e := range changes.NewEvents { eventPayload, err := backend.MarshalHistoryEvent(e) if err != nil { return err @@ -288,6 +290,13 @@ func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi * nextSequenceNumber++ } + be.logger.Debugf( + "%v: Inserting %d new history events with sequence numbers %d thru %d", + wi.InstanceID, + newHistoryCount, + startIndex, + nextSequenceNumber-1) + _, err = tx.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("failed to insert into the History table: %w", err) @@ -295,13 +304,13 @@ func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi * } // Save outbound activity tasks - newActivityCount := len(wi.State.PendingTasks()) + newActivityCount := len(changes.NewTasks) if newActivityCount > 0 { insertSql := "INSERT INTO NewTasks ([InstanceID], [EventPayload]) VALUES (?, ?)" + strings.Repeat(", (?, ?)", newActivityCount-1) sqlInsertArgs := make([]interface{}, 0, newActivityCount*2) - for _, e := range wi.State.PendingTasks() { + for _, e := range changes.NewTasks { eventPayload, err := backend.MarshalHistoryEvent(e) if err != nil { return err @@ -317,13 +326,14 @@ func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi * } // Save outbound orchestrator events - newEventCount := len(wi.State.PendingTimers()) + len(wi.State.PendingMessages()) + newEventCount := len(changes.NewTimers) + len(changes.NewMessages) if newEventCount > 0 { + be.logger.Debugf("%v: Inserting %d rows into the NewEvents table", wi.InstanceID, newEventCount) insertSql := "INSERT INTO NewEvents ([InstanceID], [EventPayload], [VisibleTime]) VALUES (?, ?, ?)" + strings.Repeat(", (?, ?, ?)", newEventCount-1) sqlInsertArgs := make([]interface{}, 0, newEventCount*3) - for _, e := range wi.State.PendingTimers() { + for _, e := range changes.NewTimers { eventPayload, err := backend.MarshalHistoryEvent(e) if err != nil { return err @@ -333,7 +343,7 @@ func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi * sqlInsertArgs = append(sqlInsertArgs, string(wi.InstanceID), eventPayload, visibileTime) } - for _, msg := range wi.State.PendingMessages() { + for _, msg := range changes.NewMessages { if es := msg.HistoryEvent.GetExecutionStarted(); es != nil { // Need to insert a new row into the DB if _, err := be.createOrchestrationInstanceInternal(ctx, msg.HistoryEvent, tx); err != nil { 
@@ -362,22 +372,26 @@ func (be *sqliteBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi * } } - // Delete inbound events - dbResult, err := tx.ExecContext( - ctx, - "DELETE FROM NewEvents WHERE [InstanceID] = ? AND [LockedBy] = ?", - string(wi.InstanceID), - wi.LockedBy, - ) - if err != nil { - return fmt.Errorf("failed to delete from NewEvents table: %w", err) - } + // Delete inbound events only on the final "chunk", if completion is being done in batches + if !changes.IsPartial { + dbResult, err := tx.ExecContext( + ctx, + "DELETE FROM NewEvents WHERE [InstanceID] = ? AND [LockedBy] = ?", + string(wi.InstanceID), + wi.LockedBy, + ) + if err != nil { + return fmt.Errorf("failed to delete from NewEvents table: %w", err) + } - rowsAffected, err := dbResult.RowsAffected() - if err != nil { - return fmt.Errorf("failed get rows affected by delete statement: %w", err) - } else if rowsAffected == 0 { - return backend.ErrWorkItemLockLost + rowsAffected, err := dbResult.RowsAffected() + if err != nil { + return fmt.Errorf("failed get rows affected by delete statement: %w", err) + } else if rowsAffected == 0 { + // Log a warning about no events to delete. This is not severe enough to be considered an error. + be.logger.Warnf("%v: no incoming events to delete", wi.InstanceID) + } + be.logger.Debugf("%v: Deleted %d row(s) from the NewEvents table", wi.InstanceID, rowsAffected) } if err != nil { diff --git a/backend/worker.go b/backend/worker.go index 3ccca40..deb0d6d 100644 --- a/backend/worker.go +++ b/backend/worker.go @@ -28,7 +28,6 @@ type TaskProcessor interface { FetchWorkItem(context.Context) (WorkItem, error) ProcessWorkItem(context.Context, WorkItem) error AbandonWorkItem(context.Context, WorkItem) error - CompleteWorkItem(context.Context, WorkItem) error } type worker struct { @@ -219,12 +218,5 @@ func (w *worker) processWorkItem(ctx context.Context, wi WorkItem) { return } - if err := w.processor.CompleteWorkItem(ctx, wi); err != nil { - w.logger.Errorf("%v: failed to complete work item: %v", w.Name(), err) - if err := w.processor.AbandonWorkItem(ctx, wi); err != nil { - w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) - } - } - w.logger.Debugf("%v: work item processed successfully", w.Name()) } diff --git a/internal/helpers/history.go b/internal/helpers/history.go index e5d8cda..72af197 100644 --- a/internal/helpers/history.go +++ b/internal/helpers/history.go @@ -7,6 +7,8 @@ import ( "time" "github.com/google/uuid" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" @@ -52,6 +54,10 @@ func NewExecutionCompletedEvent(eventID int32, status protos.OrchestrationStatus } } +func NewExecutionFailedEvent(eventID int32, failureDetails *protos.TaskFailureDetails) *protos.HistoryEvent { + return NewExecutionCompletedEvent(eventID, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, nil, failureDetails) +} + func NewExecutionTerminatedEvent(rawReason *wrapperspb.StringValue, recurse bool) *protos.HistoryEvent { return &protos.HistoryEvent{ EventId: -1, @@ -142,11 +148,7 @@ func NewTimerCreatedEvent(eventID int32, fireAt *timestamppb.Timestamp) *protos. 
} } -func NewTimerFiredEvent( - timerID int32, - fireAt *timestamppb.Timestamp, - parentTraceContext *protos.TraceContext, -) *protos.HistoryEvent { +func NewTimerFiredEvent(timerID int32, fireAt *timestamppb.Timestamp) *protos.HistoryEvent { return &protos.HistoryEvent{ EventId: -1, Timestamp: timestamppb.New(time.Now()), @@ -182,6 +184,32 @@ func NewSubOrchestrationCreatedEvent( } } +func NewSubOrchestrationCompletedEvent(taskID int32, result *wrapperspb.StringValue) *protos.HistoryEvent { + return &protos.HistoryEvent{ + EventId: -1, + Timestamp: timestamppb.New(time.Now()), + EventType: &protos.HistoryEvent_SubOrchestrationInstanceCompleted{ + SubOrchestrationInstanceCompleted: &protos.SubOrchestrationInstanceCompletedEvent{ + TaskScheduledId: taskID, + Result: result, + }, + }, + } +} + +func NewSubOrchestrationFailedEvent(taskID int32, failureDetails *protos.TaskFailureDetails) *protos.HistoryEvent { + return &protos.HistoryEvent{ + EventId: -1, + Timestamp: timestamppb.New(time.Now()), + EventType: &protos.HistoryEvent_SubOrchestrationInstanceFailed{ + SubOrchestrationInstanceFailed: &protos.SubOrchestrationInstanceFailedEvent{ + TaskScheduledId: taskID, + FailureDetails: failureDetails, + }, + }, + } +} + func NewSendEventEvent(eventID int32, instanceID string, name string, rawInput *wrapperspb.StringValue) *protos.HistoryEvent { return &protos.HistoryEvent{ EventId: eventID, @@ -328,6 +356,10 @@ func NewTaskFailureDetails(err error) *protos.TaskFailureDetails { } } +func GetProtoSize(m protoreflect.ProtoMessage) int { + return proto.Size(m) +} + func HistoryListSummary(list []*protos.HistoryEvent) string { var sb strings.Builder sb.WriteString("[") @@ -339,13 +371,18 @@ func HistoryListSummary(list []*protos.HistoryEvent) string { sb.WriteString("...") break } - name := getHistoryEventTypeName(e) + name := GetHistoryEventTypeName(e) sb.WriteString(name) taskID := GetTaskId(e) - if taskID > -0 { + if taskID >= 0 { sb.WriteRune('#') sb.WriteString(strconv.FormatInt(int64(taskID), 10)) } + if x := e.GetTimerFired(); x != nil { + sb.WriteString(" (") + sb.WriteString(x.FireAt.AsTime().Format(time.RFC3339)) + sb.WriteString(")") + } } sb.WriteString("]") return sb.String() @@ -368,6 +405,16 @@ func ActionListSummary(actions []*protos.OrchestratorAction) string { sb.WriteRune('#') sb.WriteString(strconv.FormatInt(int64(a.Id), 10)) } + if x := a.GetCompleteOrchestration(); x != nil { + sb.WriteString(" (") + sb.WriteString(ToRuntimeStatusString(x.OrchestrationStatus)) + sb.WriteString(")") + } + if x := a.GetCreateTimer(); x != nil { + sb.WriteString(" (") + sb.WriteString(x.FireAt.AsTime().Format(time.RFC3339)) + sb.WriteString(")") + } } sb.WriteString("]") return sb.String() @@ -403,12 +450,18 @@ func FromRuntimeStatusString(status string) protos.OrchestrationStatus { return protos.OrchestrationStatus(protos.OrchestrationStatus_value[runtimeStatus]) } -func getHistoryEventTypeName(e *protos.HistoryEvent) string { +func GetHistoryEventTypeName(e *protos.HistoryEvent) string { + if e.EventType == nil { + return "Unknown" + } // PERFORMANCE: Replace this with a switch statement or a map lookup to avoid this use of reflection return reflect.TypeOf(e.EventType).Elem().Name()[len("HistoryEvent_"):] } func getActionTypeName(a *protos.OrchestratorAction) string { + if a.OrchestratorActionType == nil { + return "Unknown" + } // PERFORMANCE: Replace this with a switch statement or a map lookup to avoid this use of reflection return 
reflect.TypeOf(a.OrchestratorActionType).Elem().Name()[len("OrchestratorAction_"):] } diff --git a/task/orchestrator.go b/task/orchestrator.go index d51041a..0efeba1 100644 --- a/task/orchestrator.go +++ b/task/orchestrator.go @@ -222,46 +222,54 @@ func (octx *OrchestrationContext) GetInput(v any) error { // parameter can be either the name of an activity as a string or can be a pointer to the function // that implements the activity, in which case the name is obtained via reflection. func (ctx *OrchestrationContext) CallActivity(activity interface{}, opts ...callActivityOption) Task { + activityName := helpers.GetTaskFunctionName(activity) + sequentNumber := ctx.getNextSequenceNumber() + description := fmt.Sprintf("%s#%d", activityName, sequentNumber) + options := new(callActivityOptions) for _, configure := range opts { if err := configure(options); err != nil { - failedTask := newTask(ctx) + failedTask := newTask(ctx, description) failedTask.fail(helpers.NewTaskFailureDetails(err)) return failedTask } } scheduleTaskAction := helpers.NewScheduleTaskAction( - ctx.getNextSequenceNumber(), + sequentNumber, helpers.GetTaskFunctionName(activity), options.rawInput) ctx.pendingActions[scheduleTaskAction.Id] = scheduleTaskAction - task := newTask(ctx) + task := newTask(ctx, description) ctx.pendingTasks[scheduleTaskAction.Id] = task return task } func (ctx *OrchestrationContext) CallSubOrchestrator(orchestrator interface{}, opts ...subOrchestratorOption) Task { + name := helpers.GetTaskFunctionName(orchestrator) + sequentNumber := ctx.getNextSequenceNumber() + description := fmt.Sprintf("%s#%d", name, sequentNumber) + options := new(callSubOrchestratorOptions) for _, configure := range opts { if err := configure(options); err != nil { - failedTask := newTask(ctx) + failedTask := newTask(ctx, description) failedTask.fail(helpers.NewTaskFailureDetails(err)) return failedTask } } createSubOrchestrationAction := helpers.NewCreateSubOrchestrationAction( - ctx.getNextSequenceNumber(), - helpers.GetTaskFunctionName(orchestrator), + sequentNumber, + name, options.instanceID, options.rawInput, ) ctx.pendingActions[createSubOrchestrationAction.Id] = createSubOrchestrationAction - task := newTask(ctx) + task := newTask(ctx, description) ctx.pendingTasks[createSubOrchestrationAction.Id] = task return task } @@ -276,7 +284,8 @@ func (ctx *OrchestrationContext) createTimerInternal(delay time.Duration) *compl timerAction := helpers.NewCreateTimerAction(ctx.getNextSequenceNumber(), fireAt) ctx.pendingActions[timerAction.Id] = timerAction - task := newTask(ctx) + description := fmt.Sprintf("timer#%d", timerAction.Id) + task := newTask(ctx, description) ctx.pendingTasks[timerAction.Id] = task return task } @@ -295,7 +304,8 @@ func (ctx *OrchestrationContext) createTimerInternal(delay time.Duration) *compl // // Note that event names are case-insensitive. func (ctx *OrchestrationContext) WaitForSingleEvent(eventName string, timeout time.Duration) Task { - task := newTask(ctx) + description := fmt.Sprintf("event:%s", eventName) + task := newTask(ctx, description) key := strings.ToUpper(eventName) if eventList, ok := ctx.bufferedExternalEvents[key]; ok { // An event with this name arrived already and can be consumed immediately. 
diff --git a/task/task.go b/task/task.go index 4e690c1..3b2f890 100644 --- a/task/task.go +++ b/task/task.go @@ -23,6 +23,7 @@ type Task interface { type completableTask struct { orchestrationCtx *OrchestrationContext + description string isCompleted bool isCanceled bool rawResult []byte @@ -30,9 +31,10 @@ type completableTask struct { completedCallback func() } -func newTask(ctx *OrchestrationContext) *completableTask { +func newTask(ctx *OrchestrationContext, description string) *completableTask { return &completableTask{ orchestrationCtx: ctx, + description: description, } } @@ -49,13 +51,13 @@ func (t *completableTask) Await(v any) error { for { if t.isCompleted { if t.failureDetails != nil { - return fmt.Errorf("task failed with an error: %v", t.failureDetails.ErrorMessage) + return fmt.Errorf("task '%s' failed with an error: %v", t.description, t.failureDetails.ErrorMessage) } else if t.isCanceled { return ErrTaskCanceled } if v != nil && len(t.rawResult) > 0 { if err := unmarshalData(t.rawResult, v); err != nil { - return fmt.Errorf("failed to decode task result: %w", err) + return fmt.Errorf("failed to decode task '%s' result: %w", t.description, err) } } return nil diff --git a/tests/backend_test.go b/tests/backend_test.go index 2b3f414..56f47e4 100644 --- a/tests/backend_test.go +++ b/tests/backend_test.go @@ -14,6 +14,7 @@ import ( "github.com/microsoft/durabletask-go/internal/helpers" "github.com/microsoft/durabletask-go/internal/protos" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -184,7 +185,7 @@ func Test_ScheduleActivityTasks(t *testing.T) { for i, be := range backends { initTest(t, be, i, true) - wi, err := be.GetActivityWorkItem(ctx) + _, err := be.GetActivityWorkItem(ctx) if !assert.ErrorIs(t, err, backend.ErrNoWorkItems) { continue } @@ -209,7 +210,7 @@ func Test_ScheduleActivityTasks(t *testing.T) { assert.ErrorIs(t, err, backend.ErrNoWorkItems) // However, there should be an activity work item - wi, err = be.GetActivityWorkItem(ctx) + wi, err := be.GetActivityWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) { assert.Equal(t, expectedName, wi.NewEvent.GetTaskScheduled().GetName()) assert.Equal(t, expectedInput, wi.NewEvent.GetTaskScheduled().GetInput().GetValue()) @@ -321,6 +322,7 @@ func Test_AbandonActivityWorkItem(t *testing.T) { if err := be.AbandonActivityWorkItem(ctx, wi); assert.NoError(t, err) { // Re-fetch the abandoned activity work item wi, err = be.GetActivityWorkItem(ctx) + require.NoError(t, err) assert.Equal(t, "MyActivity", wi.NewEvent.GetTaskScheduled().GetName()) assert.Equal(t, int32(123), wi.NewEvent.EventId) assert.Nil(t, wi.NewEvent.GetTaskScheduled().GetInput()) @@ -335,7 +337,7 @@ func Test_UninitializedBackend(t *testing.T) { err := be.AbandonOrchestrationWorkItem(ctx, nil) assert.Equal(t, err, backend.ErrNotInitialized) - err = be.CompleteOrchestrationWorkItem(ctx, nil) + err = be.CompleteOrchestrationWorkItem(ctx, nil, backend.OrchestrationStateChanges{}) assert.Equal(t, err, backend.ErrNotInitialized) err = be.CreateOrchestrationInstance(ctx, nil) assert.Equal(t, err, backend.ErrNotInitialized) @@ -443,32 +445,34 @@ func workItemProcessingTestLogic( } actions := getOrchestratorActions() - _, err := state.ApplyActions(actions, nil) - if assert.NoError(t, err) { - wi.State = state - err := be.CompleteOrchestrationWorkItem(ctx, wi) + state.AddActions(actions) + + wi.State = state + changes, 
err := state.ProcessChanges(defaultChunkingConfig, nil, logger) + require.NoError(t, err) + + err = be.CompleteOrchestrationWorkItem(ctx, wi, changes) + require.NoError(t, err) + + // Validate runtime state + if state, ok = getOrchestrationRuntimeState(t, be, wi); ok { + createdTime, err := state.CreatedTime() if assert.NoError(t, err) { - // Validate runtime state - if state, ok = getOrchestrationRuntimeState(t, be, wi); ok { - createdTime, err := state.CreatedTime() - if assert.NoError(t, err) { - assert.GreaterOrEqual(t, createdTime, startTime) - } - - // State should be initialized with only "old" events - assert.Empty(t, state.NewEvents()) - assert.NotEmpty(t, state.OldEvents()) - - // Validate orchestration metadata - if metadata, ok := getOrchestrationMetadata(t, be, state.InstanceID()); ok { - assert.Equal(t, defaultName, metadata.Name) - assert.Equal(t, defaultInput, metadata.SerializedInput) - assert.Equal(t, createdTime, metadata.CreatedAt) - assert.Equal(t, state.RuntimeStatus(), metadata.RuntimeStatus) - - validateMetadata(metadata) - } - } + assert.GreaterOrEqual(t, createdTime, startTime) + } + + // State should be initialized with only "old" events + assert.Empty(t, state.NewEvents()) + assert.NotEmpty(t, state.OldEvents()) + + // Validate orchestration metadata + if metadata, ok := getOrchestrationMetadata(t, be, state.InstanceID()); ok { + assert.Equal(t, defaultName, metadata.Name) + assert.Equal(t, defaultInput, metadata.SerializedInput) + assert.Equal(t, createdTime, metadata.CreatedAt) + assert.Equal(t, state.RuntimeStatus(), metadata.RuntimeStatus) + + validateMetadata(metadata) } } } diff --git a/tests/grpc/grpc_test.go b/tests/grpc/grpc_test.go index 8f0de3c..aae9523 100644 --- a/tests/grpc/grpc_test.go +++ b/tests/grpc/grpc_test.go @@ -173,7 +173,7 @@ func Test_Grpc_Terminate_Recursive(t *testing.T) { r := task.NewTaskRegistry() r.AddOrchestratorN("Root", func(ctx *task.OrchestrationContext) (any, error) { tasks := []task.Task{} - for i := 0; i < 5; i++ { + for i := 0; i < 3; i++ { task := ctx.CallSubOrchestrator("L1") tasks = append(tasks, task) } diff --git a/tests/mocks/Backend.go b/tests/mocks/Backend.go index a59f0f1..827e638 100644 --- a/tests/mocks/Backend.go +++ b/tests/mocks/Backend.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.37.0. DO NOT EDIT. 
package mocks @@ -30,10 +30,6 @@ func (_m *Backend) EXPECT() *Backend_Expecter { func (_m *Backend) AbandonActivityWorkItem(_a0 context.Context, _a1 *backend.ActivityWorkItem) error { ret := _m.Called(_a0, _a1) - if len(ret) == 0 { - panic("no return value specified for AbandonActivityWorkItem") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context, *backend.ActivityWorkItem) error); ok { r0 = rf(_a0, _a1) @@ -77,10 +73,6 @@ func (_c *Backend_AbandonActivityWorkItem_Call) RunAndReturn(run func(context.Co func (_m *Backend) AbandonOrchestrationWorkItem(_a0 context.Context, _a1 *backend.OrchestrationWorkItem) error { ret := _m.Called(_a0, _a1) - if len(ret) == 0 { - panic("no return value specified for AbandonOrchestrationWorkItem") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context, *backend.OrchestrationWorkItem) error); ok { r0 = rf(_a0, _a1) @@ -124,10 +116,6 @@ func (_c *Backend_AbandonOrchestrationWorkItem_Call) RunAndReturn(run func(conte func (_m *Backend) AddNewOrchestrationEvent(_a0 context.Context, _a1 api.InstanceID, _a2 *protos.HistoryEvent) error { ret := _m.Called(_a0, _a1, _a2) - if len(ret) == 0 { - panic("no return value specified for AddNewOrchestrationEvent") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context, api.InstanceID, *protos.HistoryEvent) error); ok { r0 = rf(_a0, _a1, _a2) @@ -172,10 +160,6 @@ func (_c *Backend_AddNewOrchestrationEvent_Call) RunAndReturn(run func(context.C func (_m *Backend) CompleteActivityWorkItem(_a0 context.Context, _a1 *backend.ActivityWorkItem) error { ret := _m.Called(_a0, _a1) - if len(ret) == 0 { - panic("no return value specified for CompleteActivityWorkItem") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context, *backend.ActivityWorkItem) error); ok { r0 = rf(_a0, _a1) @@ -215,17 +199,13 @@ func (_c *Backend_CompleteActivityWorkItem_Call) RunAndReturn(run func(context.C return _c } -// CompleteOrchestrationWorkItem provides a mock function with given fields: _a0, _a1 -func (_m *Backend) CompleteOrchestrationWorkItem(_a0 context.Context, _a1 *backend.OrchestrationWorkItem) error { - ret := _m.Called(_a0, _a1) - - if len(ret) == 0 { - panic("no return value specified for CompleteOrchestrationWorkItem") - } +// CompleteOrchestrationWorkItem provides a mock function with given fields: _a0, _a1, _a2 +func (_m *Backend) CompleteOrchestrationWorkItem(_a0 context.Context, _a1 *backend.OrchestrationWorkItem, _a2 backend.OrchestrationStateChanges) error { + ret := _m.Called(_a0, _a1, _a2) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *backend.OrchestrationWorkItem) error); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(context.Context, *backend.OrchestrationWorkItem, backend.OrchestrationStateChanges) error); ok { + r0 = rf(_a0, _a1, _a2) } else { r0 = ret.Error(0) } @@ -241,13 +221,14 @@ type Backend_CompleteOrchestrationWorkItem_Call struct { // CompleteOrchestrationWorkItem is a helper method to define mock.On call // - _a0 context.Context // - _a1 *backend.OrchestrationWorkItem -func (_e *Backend_Expecter) CompleteOrchestrationWorkItem(_a0 interface{}, _a1 interface{}) *Backend_CompleteOrchestrationWorkItem_Call { - return &Backend_CompleteOrchestrationWorkItem_Call{Call: _e.mock.On("CompleteOrchestrationWorkItem", _a0, _a1)} +// - _a2 backend.OrchestrationStateChanges +func (_e *Backend_Expecter) CompleteOrchestrationWorkItem(_a0 interface{}, _a1 interface{}, _a2 interface{}) *Backend_CompleteOrchestrationWorkItem_Call { + return 
&Backend_CompleteOrchestrationWorkItem_Call{Call: _e.mock.On("CompleteOrchestrationWorkItem", _a0, _a1, _a2)} } -func (_c *Backend_CompleteOrchestrationWorkItem_Call) Run(run func(_a0 context.Context, _a1 *backend.OrchestrationWorkItem)) *Backend_CompleteOrchestrationWorkItem_Call { +func (_c *Backend_CompleteOrchestrationWorkItem_Call) Run(run func(_a0 context.Context, _a1 *backend.OrchestrationWorkItem, _a2 backend.OrchestrationStateChanges)) *Backend_CompleteOrchestrationWorkItem_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*backend.OrchestrationWorkItem)) + run(args[0].(context.Context), args[1].(*backend.OrchestrationWorkItem), args[2].(backend.OrchestrationStateChanges)) }) return _c } @@ -257,7 +238,7 @@ func (_c *Backend_CompleteOrchestrationWorkItem_Call) Return(_a0 error) *Backend return _c } -func (_c *Backend_CompleteOrchestrationWorkItem_Call) RunAndReturn(run func(context.Context, *backend.OrchestrationWorkItem) error) *Backend_CompleteOrchestrationWorkItem_Call { +func (_c *Backend_CompleteOrchestrationWorkItem_Call) RunAndReturn(run func(context.Context, *backend.OrchestrationWorkItem, backend.OrchestrationStateChanges) error) *Backend_CompleteOrchestrationWorkItem_Call { _c.Call.Return(run) return _c } @@ -273,10 +254,6 @@ func (_m *Backend) CreateOrchestrationInstance(_a0 context.Context, _a1 *protos. _ca = append(_ca, _va...) ret := _m.Called(_ca...) - if len(ret) == 0 { - panic("no return value specified for CreateOrchestrationInstance") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context, *protos.HistoryEvent, ...backend.OrchestrationIdReusePolicyOptions) error); ok { r0 = rf(_a0, _a1, _a2...) @@ -328,10 +305,6 @@ func (_c *Backend_CreateOrchestrationInstance_Call) RunAndReturn(run func(contex func (_m *Backend) CreateTaskHub(_a0 context.Context) error { ret := _m.Called(_a0) - if len(ret) == 0 { - panic("no return value specified for CreateTaskHub") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(_a0) @@ -374,10 +347,6 @@ func (_c *Backend_CreateTaskHub_Call) RunAndReturn(run func(context.Context) err func (_m *Backend) DeleteTaskHub(_a0 context.Context) error { ret := _m.Called(_a0) - if len(ret) == 0 { - panic("no return value specified for DeleteTaskHub") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(_a0) @@ -420,10 +389,6 @@ func (_c *Backend_DeleteTaskHub_Call) RunAndReturn(run func(context.Context) err func (_m *Backend) GetActivityWorkItem(_a0 context.Context) (*backend.ActivityWorkItem, error) { ret := _m.Called(_a0) - if len(ret) == 0 { - panic("no return value specified for GetActivityWorkItem") - } - var r0 *backend.ActivityWorkItem var r1 error if rf, ok := ret.Get(0).(func(context.Context) (*backend.ActivityWorkItem, error)); ok { @@ -478,10 +443,6 @@ func (_c *Backend_GetActivityWorkItem_Call) RunAndReturn(run func(context.Contex func (_m *Backend) GetOrchestrationMetadata(_a0 context.Context, _a1 api.InstanceID) (*api.OrchestrationMetadata, error) { ret := _m.Called(_a0, _a1) - if len(ret) == 0 { - panic("no return value specified for GetOrchestrationMetadata") - } - var r0 *api.OrchestrationMetadata var r1 error if rf, ok := ret.Get(0).(func(context.Context, api.InstanceID) (*api.OrchestrationMetadata, error)); ok { @@ -537,10 +498,6 @@ func (_c *Backend_GetOrchestrationMetadata_Call) RunAndReturn(run func(context.C func (_m *Backend) GetOrchestrationRuntimeState(_a0 context.Context, _a1 
*backend.OrchestrationWorkItem) (*backend.OrchestrationRuntimeState, error) { ret := _m.Called(_a0, _a1) - if len(ret) == 0 { - panic("no return value specified for GetOrchestrationRuntimeState") - } - var r0 *backend.OrchestrationRuntimeState var r1 error if rf, ok := ret.Get(0).(func(context.Context, *backend.OrchestrationWorkItem) (*backend.OrchestrationRuntimeState, error)); ok { @@ -596,10 +553,6 @@ func (_c *Backend_GetOrchestrationRuntimeState_Call) RunAndReturn(run func(conte func (_m *Backend) GetOrchestrationWorkItem(_a0 context.Context) (*backend.OrchestrationWorkItem, error) { ret := _m.Called(_a0) - if len(ret) == 0 { - panic("no return value specified for GetOrchestrationWorkItem") - } - var r0 *backend.OrchestrationWorkItem var r1 error if rf, ok := ret.Get(0).(func(context.Context) (*backend.OrchestrationWorkItem, error)); ok { @@ -654,10 +607,6 @@ func (_c *Backend_GetOrchestrationWorkItem_Call) RunAndReturn(run func(context.C func (_m *Backend) PurgeOrchestrationState(_a0 context.Context, _a1 api.InstanceID) error { ret := _m.Called(_a0, _a1) - if len(ret) == 0 { - panic("no return value specified for PurgeOrchestrationState") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context, api.InstanceID) error); ok { r0 = rf(_a0, _a1) @@ -701,10 +650,6 @@ func (_c *Backend_PurgeOrchestrationState_Call) RunAndReturn(run func(context.Co func (_m *Backend) Start(_a0 context.Context) error { ret := _m.Called(_a0) - if len(ret) == 0 { - panic("no return value specified for Start") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(_a0) @@ -747,10 +692,6 @@ func (_c *Backend_Start_Call) RunAndReturn(run func(context.Context) error) *Bac func (_m *Backend) Stop(_a0 context.Context) error { ret := _m.Called(_a0) - if len(ret) == 0 { - panic("no return value specified for Stop") - } - var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(_a0) diff --git a/tests/mocks/Executor.go b/tests/mocks/Executor.go index 659ac36..03365c5 100644 --- a/tests/mocks/Executor.go +++ b/tests/mocks/Executor.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.37.0. DO NOT EDIT. 
package mocks @@ -30,10 +30,6 @@ func (_m *Executor) EXPECT() *Executor_Expecter { func (_m *Executor) ExecuteActivity(_a0 context.Context, _a1 api.InstanceID, _a2 *protos.HistoryEvent) (*protos.HistoryEvent, error) { ret := _m.Called(_a0, _a1, _a2) - if len(ret) == 0 { - panic("no return value specified for ExecuteActivity") - } - var r0 *protos.HistoryEvent var r1 error if rf, ok := ret.Get(0).(func(context.Context, api.InstanceID, *protos.HistoryEvent) (*protos.HistoryEvent, error)); ok { @@ -90,10 +86,6 @@ func (_c *Executor_ExecuteActivity_Call) RunAndReturn(run func(context.Context, func (_m *Executor) ExecuteOrchestrator(ctx context.Context, iid api.InstanceID, oldEvents []*protos.HistoryEvent, newEvents []*protos.HistoryEvent) (*backend.ExecutionResults, error) { ret := _m.Called(ctx, iid, oldEvents, newEvents) - if len(ret) == 0 { - panic("no return value specified for ExecuteOrchestrator") - } - var r0 *backend.ExecutionResults var r1 error if rf, ok := ret.Get(0).(func(context.Context, api.InstanceID, []*protos.HistoryEvent, []*protos.HistoryEvent) (*backend.ExecutionResults, error)); ok { diff --git a/tests/mocks/TaskWorker.go b/tests/mocks/TaskWorker.go index 7ee0f2c..878aa82 100644 --- a/tests/mocks/TaskWorker.go +++ b/tests/mocks/TaskWorker.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.37.0. DO NOT EDIT. package mocks @@ -25,10 +25,6 @@ func (_m *TaskWorker) EXPECT() *TaskWorker_Expecter { func (_m *TaskWorker) ProcessNext(_a0 context.Context) (bool, error) { ret := _m.Called(_a0) - if len(ret) == 0 { - panic("no return value specified for ProcessNext") - } - var r0 bool var r1 error if rf, ok := ret.Get(0).(func(context.Context) (bool, error)); ok { diff --git a/tests/orchestrations_test.go b/tests/orchestrations_test.go index fde0390..6977f95 100644 --- a/tests/orchestrations_test.go +++ b/tests/orchestrations_test.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "sort" + "strconv" + "strings" "testing" "time" @@ -18,6 +20,27 @@ import ( "github.com/microsoft/durabletask-go/task" ) +type orchestrationTestConfig struct { + orchestrationWorkerOptions []backend.OrchestrationWorkerOption + activityWorkerOptions []backend.ActivityWorkerOption +} + +type testOption func(*orchestrationTestConfig) + +func withMaxParallelism(maxParallelWorkItems int32) testOption { + return func(cfg *orchestrationTestConfig) { + cfg.orchestrationWorkerOptions = append(cfg.orchestrationWorkerOptions, backend.WithMaxConcurrentOrchestratorInvocations(maxParallelWorkItems)) + cfg.activityWorkerOptions = append(cfg.activityWorkerOptions, backend.WithMaxConcurrentActivityInvocations(maxParallelWorkItems)) + } +} + +func withChunkingConfig(chunkingConfig backend.ChunkingConfiguration) testOption { + return func(cfg *orchestrationTestConfig) { + cfg.orchestrationWorkerOptions = append(cfg.orchestrationWorkerOptions, backend.WithChunkingConfiguration(chunkingConfig)) + cfg.activityWorkerOptions = append(cfg.activityWorkerOptions, backend.WithMaxActivityOutputSizeInKB(chunkingConfig.MaxHistoryEventSizeInKB)) + } +} + func Test_EmptyOrchestration(t *testing.T) { // Registration r := task.NewTaskRegistry() @@ -284,7 +307,7 @@ func Test_ActivityFanOut(t *testing.T) { // Initialization ctx := context.Background() exporter := initTracing() - client, worker := initTaskHubWorker(ctx, r, backend.WithMaxParallelism(10)) + client, worker := initTaskHubWorker(ctx, r, withMaxParallelism(10)) defer worker.Shutdown(ctx) // Run the orchestration @@ -455,22 +478,24 @@ func 
Test_ContinueAsNew_Events(t *testing.T) { return nil, nil }) - // Initialization + // Initialization - configure chunking for this test as a way to increase code coverage ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) + chunkingConfig := backend.ChunkingConfiguration{MaxHistoryEventCount: 1000} + client, worker := initTaskHubWorker(ctx, r, withChunkingConfig(chunkingConfig)) defer worker.Shutdown(ctx) // Run the orchestration + eventCount := 100 id, err := client.ScheduleNewOrchestration(ctx, "ContinueAsNewTest", api.WithInput(0)) require.NoError(t, err) - for i := 0; i < 10; i++ { + for i := 0; i < eventCount; i++ { require.NoError(t, client.RaiseEvent(ctx, id, "MyEvent", api.WithEventPayload(false))) } require.NoError(t, client.RaiseEvent(ctx, id, "MyEvent", api.WithEventPayload(true))) metadata, err := client.WaitForOrchestrationCompletion(ctx, id) require.NoError(t, err) assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, metadata.RuntimeStatus) - assert.Equal(t, `10`, metadata.SerializedOutput) + assert.Equal(t, strconv.Itoa(eventCount), metadata.SerializedOutput) } func Test_ExternalEventContention(t *testing.T) { @@ -546,23 +571,20 @@ func Test_ExternalEventOrchestration(t *testing.T) { // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ExternalEventOrchestration", api.WithInput(0)) - if assert.NoError(t, err) { - for i := 0; i < eventCount; i++ { - opts := api.WithEventPayload(i) - require.NoError(t, client.RaiseEvent(ctx, id, "MyEvent", opts)) - } - - timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - metadata, err := client.WaitForOrchestrationCompletion(timeoutCtx, id) - require.NoError(t, err) - require.True(t, metadata.IsComplete()) - require.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, metadata.RuntimeStatus) + require.NoError(t, err) + for i := 0; i < eventCount; i++ { + opts := api.WithEventPayload(i) + require.NoError(t, client.RaiseEvent(ctx, id, "MyEvent", opts)) } + metadata, err := client.WaitForOrchestrationCompletion(ctx, id) + require.NoError(t, err) + require.True(t, metadata.IsComplete()) + require.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, metadata.RuntimeStatus) + require.Equal(t, strconv.FormatBool(true), metadata.SerializedOutput) // Validate the exported OTel traces eventSizeInBytes := 1 + require.NoError(t, sharedSpanProcessor.ForceFlush(ctx)) spans := exporter.GetSpans().Snapshots() assertSpanSequence(t, spans, assertOrchestratorCreated("ExternalEventOrchestration", id), @@ -875,6 +897,251 @@ func Test_RecreateCompletedOrchestration(t *testing.T) { ) } +func Test_ContinueAsNew_InfiniteLoop(t *testing.T) { + // Count how many times we run the orchestrator function. We set a hard cap of 1M iterations, + // but the orchestrator should fail before we reach this number. 
+	maxTestIterations := 1000000
+	iteration := 0
+
+	// Registration
+	r := task.NewTaskRegistry()
+	r.AddOrchestratorN("InfiniteLoop", func(ctx *task.OrchestrationContext) (any, error) {
+		iteration++
+		if iteration <= maxTestIterations {
+			ctx.ContinueAsNew(nil)
+		}
+		return nil, nil
+	})
+
+	// Initialization
+	ctx := context.Background()
+	client, worker := initTaskHubWorker(ctx, r)
+	defer worker.Shutdown(ctx)
+
+	// Run the orchestration
+	id, err := client.ScheduleNewOrchestration(ctx, "InfiniteLoop")
+	require.NoError(t, err)
+
+	// Wait for the orchestration to complete
+	metadata, err := client.WaitForOrchestrationCompletion(ctx, id)
+	require.NoError(t, err)
+	assert.True(t, metadata.IsComplete())
+	assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, metadata.RuntimeStatus)
+	if assert.NotNil(t, metadata.FailureDetails) {
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "exceeded tight-loop continue-as-new limit")
+	}
+}
+
+func Test_OrchestrationOutputTooLarge(t *testing.T) {
+	// Registration
+	r := task.NewTaskRegistry()
+	r.AddOrchestratorN("OrchestrationOutputTooLarge", func(ctx *task.OrchestrationContext) (any, error) {
+		// Return a payload that's larger than the configured max chunk size
+		return strings.Repeat("a", 3*1024), nil
+	})
+
+	// Initialization
+	ctx := context.Background()
+	client, worker := initTaskHubWorker(ctx, r, withChunkingConfig(backend.ChunkingConfiguration{MaxHistoryEventSizeInKB: 2}))
+	defer worker.Shutdown(ctx)
+
+	// Run the orchestration, which returns an output that's larger than the max chunk size
+	id, err := client.ScheduleNewOrchestration(ctx, "OrchestrationOutputTooLarge")
+	require.NoError(t, err)
+
+	// Wait for the orchestration to fail
+	metadata, err := client.WaitForOrchestrationCompletion(ctx, id)
+	require.NoError(t, err)
+	assert.True(t, metadata.IsComplete())
+	assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, metadata.RuntimeStatus)
+	if assert.NotNil(t, metadata.FailureDetails) {
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "exceeds")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "maximum")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "size")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "2048")
+	}
+}
+
+func Test_ActivityInputTooLarge(t *testing.T) {
+	// Registration
+	r := task.NewTaskRegistry()
+	r.AddOrchestratorN("ActivityInputTooLarge", func(ctx *task.OrchestrationContext) (any, error) {
+		// Call an activity with an input that's larger than the configured max chunk size
+		largeInput := strings.Repeat("a", 3*1024)
+		return nil, ctx.CallActivity("SayHello", task.WithActivityInput(largeInput)).Await(nil)
+	})
+	r.AddActivityN("SayHello", func(ctx task.ActivityContext) (any, error) {
+		return nil, nil
+	})
+
+	// Initialization
+	ctx := context.Background()
+	client, worker := initTaskHubWorker(ctx, r, withChunkingConfig(backend.ChunkingConfiguration{MaxHistoryEventSizeInKB: 2}))
+	defer worker.Shutdown(ctx)
+
+	// Run the orchestration, which schedules an activity with an input that's larger than the max chunk size
+	id, err := client.ScheduleNewOrchestration(ctx, "ActivityInputTooLarge")
+	require.NoError(t, err)
+
+	// Wait for the orchestration to fail
+	metadata, err := client.WaitForOrchestrationCompletion(ctx, id)
+	require.NoError(t, err)
+	assert.True(t, metadata.IsComplete())
+	assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, metadata.RuntimeStatus)
+	if assert.NotNil(t, metadata.FailureDetails) {
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "exceeds")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "maximum")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "size")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "2048")
+	}
+}
+
+func Test_ActivityOutputTooLarge(t *testing.T) {
+	// Registration
+	r := task.NewTaskRegistry()
+	r.AddOrchestratorN("ActivityOutputTooLarge", func(ctx *task.OrchestrationContext) (any, error) {
+		if err := ctx.CallActivity("SayHello", task.WithActivityInput("世界")).Await(nil); err != nil {
+			return nil, err
+		}
+		return nil, nil
+	})
+	r.AddActivityN("SayHello", func(ctx task.ActivityContext) (any, error) {
+		// Return a payload that's larger than the configured max chunk size
+		return strings.Repeat("a", 3*1024), nil
+	})
+
+	// Initialization
+	ctx := context.Background()
+	client, worker := initTaskHubWorker(ctx, r, withChunkingConfig(backend.ChunkingConfiguration{MaxHistoryEventSizeInKB: 2}))
+	defer worker.Shutdown(ctx)
+
+	// Run the orchestration, which calls an activity that returns an output larger than the max chunk size
+	id, err := client.ScheduleNewOrchestration(ctx, "ActivityOutputTooLarge")
+	require.NoError(t, err)
+
+	// Wait for the orchestration to fail
+	metadata, err := client.WaitForOrchestrationCompletion(ctx, id)
+	require.NoError(t, err)
+	assert.True(t, metadata.IsComplete())
+	assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, metadata.RuntimeStatus)
+	if assert.NotNil(t, metadata.FailureDetails) {
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "SayHello")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "exceeds")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "limit")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "size")
+		assert.Contains(t, metadata.FailureDetails.ErrorMessage, "2048")
+	}
+}
+
+// Test_ChunkActivityFanOut_MaxEventCount tests the case where the degree of fan-out exceeds the chunking configuration's maximum event count.
+// The point of this test is to make sure that the orchestration completes successfully. It does not test the chunking behavior itself.
+func Test_ChunkActivityFanOut_MaxEventCount(t *testing.T) {
+	// Registration
+	r := task.NewTaskRegistry()
+	r.AddOrchestratorN("ActivityFanOut", func(ctx *task.OrchestrationContext) (any, error) {
+		tasks := []task.Task{}
+		for i := 0; i < 100; i++ {
+			tasks = append(tasks, ctx.CallActivity("PlusOne", task.WithActivityInput(0)))
+		}
+		results := []int{}
+		for _, t := range tasks {
+			var result int
+			if err := t.Await(&result); err != nil {
+				return nil, err
+			}
+			results = append(results, result)
+		}
+		sum := 0
+		for _, r := range results {
+			sum += r
+		}
+		return sum, nil
+	})
+	r.AddActivityN("PlusOne", func(ctx task.ActivityContext) (any, error) {
+		var input int
+		if err := ctx.GetInput(&input); err != nil {
+			return nil, err
+		}
+		return input + 1, nil
+	})
+
+	// Force the orchestrator to chunk the history into batches of 10 events
+	chunkingConfig := backend.ChunkingConfiguration{MaxHistoryEventCount: 10}
+
+	// Initialization
+	ctx := context.Background()
+	client, worker := initTaskHubWorker(ctx, r, withChunkingConfig(chunkingConfig))
+	defer worker.Shutdown(ctx)
+
+	// Run the orchestration
+	id, err := client.ScheduleNewOrchestration(ctx, "ActivityFanOut")
+	require.NoError(t, err)
+	timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+	defer cancel()
+	metadata, err := client.WaitForOrchestrationCompletion(timeoutCtx, id)
+	require.NoError(t, err)
+	assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, metadata.RuntimeStatus)
+	assert.Equal(t, `100`, metadata.SerializedOutput)
+}
+
+// Test_ChunkContinueAsNew tests the case where an orchestration receives 30 events and then continues as new
+// after each one is processed. The chunk size is set to 10, so we expect at least 3 chunks.
+// We don't measure the chunking behavior here (we can't), but we do want to make sure that the orchestration
+// completes successfully with the expected output, which is the number of events processed.
+func Test_ChunkContinueAsNew(t *testing.T) {
+	targetEventCount := 30
+
+	// Registration
+	r := task.NewTaskRegistry()
+	r.AddOrchestratorN("ContinueAsNewTest", func(ctx *task.OrchestrationContext) (any, error) {
+		var currentValue int
+		if err := ctx.GetInput(&currentValue); err != nil {
+			return nil, err
+		}
+
+		if currentValue == 0 {
+			// Wait for 1 second to give the client a chance to raise all the events
+			time.Sleep(1 * time.Second)
+		}
+
+		// Break out of the loop once we've received all N events
+		if currentValue == targetEventCount {
+			return currentValue, nil
+		}
+
+		// Wait for an event
+		if err := ctx.WaitForSingleEvent("MyEvent", 5*time.Second).Await(nil); err != nil {
+			return nil, err
+		}
+
+		// Loop until we've received N events
+		ctx.ContinueAsNew(currentValue+1, task.WithKeepUnprocessedEvents())
+		return -1, nil
+	})
+
+	// Force the orchestrator to chunk the history into batches of 10 events
+	chunkingConfig := backend.ChunkingConfiguration{MaxHistoryEventCount: 10}
+
+	// Initialization
+	ctx := context.Background()
+	client, worker := initTaskHubWorker(ctx, r, withChunkingConfig(chunkingConfig))
+	defer worker.Shutdown(ctx)
+
+	// Run the orchestration
+	id, err := client.ScheduleNewOrchestration(ctx, "ContinueAsNewTest", api.WithInput(0))
+	require.NoError(t, err)
+
+	// Raise N events to the orchestration
+	for i := 0; i < int(targetEventCount); i++ {
+		require.NoError(t, client.RaiseEvent(ctx, id, "MyEvent"))
+	}
+
+	metadata, err := client.WaitForOrchestrationCompletion(ctx, id)
+	require.NoError(t, err)
+	assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, metadata.RuntimeStatus)
+	assert.Equal(t, fmt.Sprintf("%d", targetEventCount), metadata.SerializedOutput)
+}
+
 func Test_SingleActivity_ReuseInstanceIDIgnore(t *testing.T) {
 	// Registration
 	r := task.NewTaskRegistry()
@@ -1007,19 +1274,24 @@ func Test_SingleActivity_ReuseInstanceIDError(t *testing.T) {
 	// Run the orchestration
 	id, err := client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("世界"), api.WithInstanceID(instanceID))
 	require.NoError(t, err)
-	id, err = client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("World"), api.WithInstanceID(id))
+	_, err = client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("World"), api.WithInstanceID(id))
 	if assert.Error(t, err) {
 		assert.Contains(t, err.Error(), "orchestration instance already exists")
 	}
 }
 
-func initTaskHubWorker(ctx context.Context, r *task.TaskRegistry, opts ...backend.NewTaskWorkerOptions) (backend.TaskHubClient, backend.TaskHubWorker) {
+func initTaskHubWorker(ctx context.Context, r *task.TaskRegistry, opts ...testOption) (backend.TaskHubClient, backend.TaskHubWorker) {
+	var config orchestrationTestConfig
+	for _, configure := range opts {
+		configure(&config)
+	}
+
 	// TODO: Switch to options pattern
 	logger := backend.DefaultLogger()
 	be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger)
 	executor := task.NewTaskExecutor(r)
-	orchestrationWorker := backend.NewOrchestrationWorker(be, executor, logger, opts...)
-	activityWorker := backend.NewActivityTaskWorker(be, executor, logger, opts...)
+	orchestrationWorker := backend.NewOrchestrationWorker(be, executor, logger, config.orchestrationWorkerOptions...)
+	activityWorker := backend.NewActivityTaskWorker(be, executor, logger, config.activityWorkerOptions...)
taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) if err := taskHubWorker.Start(ctx); err != nil { panic(err) diff --git a/tests/runtimestate_test.go b/tests/runtimestate_test.go index 2f3eb94..089a409 100644 --- a/tests/runtimestate_test.go +++ b/tests/runtimestate_test.go @@ -1,6 +1,7 @@ package tests import ( + "errors" "testing" "time" @@ -9,10 +10,13 @@ import ( "github.com/microsoft/durabletask-go/internal/helpers" "github.com/microsoft/durabletask-go/internal/protos" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) +var defaultChunkingConfig = backend.ChunkingConfiguration{} + // Verifies runtime state created from an ExecutionStarted event func Test_NewOrchestration(t *testing.T) { const iid = "abc" @@ -127,28 +131,32 @@ func Test_CompletedSubOrchestration(t *testing.T) { []*protos.HistoryEvent{}, nil), } - - continuedAsNew, err := s.ApplyActions(actions, nil) - if assert.NoError(t, err) && assert.False(t, continuedAsNew) { - if assert.Len(t, s.NewEvents(), 1) { - e := s.NewEvents()[0] - assert.NotNil(t, e.Timestamp) - if ec := e.GetExecutionCompleted(); assert.NotNil(t, ec) { - assert.Equal(t, expectedTaskID, e.EventId) - assert.Equal(t, status, ec.OrchestrationStatus) - assert.Equal(t, expectedOutput, ec.Result.GetValue()) - assert.Nil(t, ec.FailureDetails) - } - } - if assert.Len(t, s.PendingMessages(), 1) { - e := s.PendingMessages()[0] - assert.NotNil(t, e.HistoryEvent.Timestamp) - if soc := e.HistoryEvent.GetSubOrchestrationInstanceCompleted(); assert.NotNil(t, soc) { - assert.Equal(t, expectedTaskID, soc.TaskScheduledId) - assert.Equal(t, expectedOutput, soc.Result.GetValue()) - } - } - } + s.AddActions(actions) + + changes, err := s.ProcessChanges(defaultChunkingConfig, nil, logger) + require.NoError(t, err) + assert.False(t, changes.ContinuedAsNew) + assert.False(t, changes.IsPartial) + + require.Len(t, changes.NewEvents, 1) + e := changes.NewEvents[0] + assert.NotNil(t, e.Timestamp) + + ec := e.GetExecutionCompleted() + require.NotNil(t, ec) + assert.Equal(t, expectedTaskID, e.EventId) + assert.Equal(t, status, ec.OrchestrationStatus) + assert.Equal(t, expectedOutput, ec.Result.GetValue()) + assert.Nil(t, ec.FailureDetails) + + require.Len(t, changes.NewMessages, 1) + m := changes.NewMessages[0] + assert.NotNil(t, m.HistoryEvent.GetTimestamp()) + + soc := m.HistoryEvent.GetSubOrchestrationInstanceCompleted() + require.NotNil(t, soc) + assert.Equal(t, expectedTaskID, soc.TaskScheduledId) + assert.Equal(t, expectedOutput, soc.Result.GetValue()) } func Test_RuntimeState_ContinueAsNew(t *testing.T) { @@ -172,34 +180,10 @@ func Test_RuntimeState_ContinueAsNew(t *testing.T) { carryoverEvents, nil), } + state.AddActions(actions) - continuedAsNew, err := state.ApplyActions(actions, nil) - if assert.NoError(t, err) && assert.True(t, continuedAsNew) { - if assert.Len(t, state.NewEvents(), 3) { - assert.NotNil(t, state.NewEvents()[0].Timestamp) - assert.NotNil(t, state.NewEvents()[0].GetOrchestratorStarted()) - assert.NotNil(t, state.NewEvents()[1].Timestamp) - if ec := state.NewEvents()[1].GetExecutionStarted(); assert.NotNil(t, ec) { - assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_RUNNING, state.RuntimeStatus()) - assert.Equal(t, string(state.InstanceID()), ec.OrchestrationInstance.InstanceId) - if name, err := state.Name(); assert.NoError(t, err) { - assert.Equal(t, expectedName, name) - 
assert.Equal(t, expectedName, ec.Name) - } - if input, err := state.Input(); assert.NoError(t, err) { - assert.Equal(t, continueAsNewInput, input) - } - } - assert.NotNil(t, state.NewEvents()[2].Timestamp) - if er := state.NewEvents()[2].GetEventRaised(); assert.NotNil(t, er) { - assert.Equal(t, eventName, er.Name) - assert.Equal(t, eventPayload, er.Input.GetValue()) - } - } - assert.Empty(t, state.PendingMessages()) - assert.Empty(t, state.PendingTasks()) - assert.Empty(t, state.PendingTimers()) - } + _, err := state.ProcessChanges(defaultChunkingConfig, nil, logger) + require.ErrorIs(t, err, backend.ErrContinuedAsNew) } func Test_CreateTimer(t *testing.T) { @@ -215,25 +199,26 @@ func Test_CreateTimer(t *testing.T) { for i := 1; i <= timerCount; i++ { actions = append(actions, helpers.NewCreateTimerAction(int32(i), expectedFireAt)) } + s.AddActions(actions) - continuedAsNew, err := s.ApplyActions(actions, nil) - if assert.NoError(t, err) && assert.False(t, continuedAsNew) { - if assert.Len(t, s.NewEvents(), timerCount) { - for _, e := range s.NewEvents() { - assert.NotNil(t, e.Timestamp) - if timerCreated := e.GetTimerCreated(); assert.NotNil(t, timerCreated) { - assert.WithinDuration(t, expectedFireAt, timerCreated.FireAt.AsTime(), 0) - } + changes, err := s.ProcessChanges(defaultChunkingConfig, nil, logger) + require.NoError(t, err) + + if assert.Len(t, changes.NewEvents, timerCount) { + for _, e := range changes.NewEvents { + assert.NotNil(t, e.Timestamp) + if timerCreated := e.GetTimerCreated(); assert.NotNil(t, timerCreated) { + assert.WithinDuration(t, expectedFireAt, timerCreated.FireAt.AsTime(), 0) } } - if assert.Len(t, s.PendingTimers(), timerCount) { - for i, e := range s.PendingTimers() { - assert.NotNil(t, e.Timestamp) - if timerFired := e.GetTimerFired(); assert.NotNil(t, timerFired) { - expectedTimerID := int32(i + 1) - assert.WithinDuration(t, expectedFireAt, timerFired.FireAt.AsTime(), 0) - assert.Equal(t, expectedTimerID, timerFired.TimerId) - } + } + if assert.Len(t, changes.NewTimers, timerCount) { + for i, e := range changes.NewTimers { + assert.NotNil(t, e.Timestamp) + if timerFired := e.GetTimerFired(); assert.NotNil(t, timerFired) { + expectedTimerID := int32(i + 1) + assert.WithinDuration(t, expectedFireAt, timerFired.FireAt.AsTime(), 0) + assert.Equal(t, expectedTimerID, timerFired.TimerId) } } } @@ -252,32 +237,33 @@ func Test_ScheduleTask(t *testing.T) { actions := []*protos.OrchestratorAction{ helpers.NewScheduleTaskAction(expectedTaskID, expectedName, wrapperspb.String(expectedInput)), } + state.AddActions(actions) tc := &protos.TraceContext{TraceParent: "trace", TraceState: wrapperspb.String("state")} - continuedAsNew, err := state.ApplyActions(actions, tc) - if assert.NoError(t, err) && assert.False(t, continuedAsNew) { - if assert.Len(t, state.NewEvents(), 1) { - e := state.NewEvents()[0] - if taskScheduled := e.GetTaskScheduled(); assert.NotNil(t, taskScheduled) { - assert.Equal(t, expectedTaskID, e.EventId) - assert.Equal(t, expectedName, taskScheduled.Name) - assert.Equal(t, expectedInput, taskScheduled.Input.GetValue()) - if assert.NotNil(t, taskScheduled.ParentTraceContext) { - assert.Equal(t, "trace", taskScheduled.ParentTraceContext.TraceParent) - assert.Equal(t, "state", taskScheduled.ParentTraceContext.TraceState.GetValue()) - } + changes, err := state.ProcessChanges(defaultChunkingConfig, tc, logger) + require.NoError(t, err) + + if assert.Len(t, changes.NewEvents, 1) { + e := changes.NewEvents[0] + if taskScheduled := e.GetTaskScheduled(); 
assert.NotNil(t, taskScheduled) { + assert.Equal(t, expectedTaskID, e.EventId) + assert.Equal(t, expectedName, taskScheduled.Name) + assert.Equal(t, expectedInput, taskScheduled.Input.GetValue()) + if assert.NotNil(t, taskScheduled.ParentTraceContext) { + assert.Equal(t, "trace", taskScheduled.ParentTraceContext.TraceParent) + assert.Equal(t, "state", taskScheduled.ParentTraceContext.TraceState.GetValue()) } } - if assert.Len(t, state.PendingTasks(), 1) { - e := state.PendingTasks()[0] - if taskScheduled := e.GetTaskScheduled(); assert.NotNil(t, taskScheduled) { - assert.Equal(t, expectedTaskID, e.EventId) - assert.Equal(t, expectedName, taskScheduled.Name) - assert.Equal(t, expectedInput, taskScheduled.Input.GetValue()) - if assert.NotNil(t, taskScheduled.ParentTraceContext) { - assert.Equal(t, "trace", taskScheduled.ParentTraceContext.TraceParent) - assert.Equal(t, "state", taskScheduled.ParentTraceContext.TraceState.GetValue()) - } + } + if assert.Len(t, changes.NewTasks, 1) { + e := changes.NewTasks[0] + if taskScheduled := e.GetTaskScheduled(); assert.NotNil(t, taskScheduled) { + assert.Equal(t, expectedTaskID, e.EventId) + assert.Equal(t, expectedName, taskScheduled.Name) + assert.Equal(t, expectedInput, taskScheduled.Input.GetValue()) + if assert.NotNil(t, taskScheduled.ParentTraceContext) { + assert.Equal(t, "trace", taskScheduled.ParentTraceContext.TraceParent) + assert.Equal(t, "state", taskScheduled.ParentTraceContext.TraceState.GetValue()) } } } @@ -299,46 +285,47 @@ func Test_CreateSubOrchestration(t *testing.T) { actions := []*protos.OrchestratorAction{ helpers.NewCreateSubOrchestrationAction(expectedTaskID, expectedName, expectedInstanceID, expectedInput), } + state.AddActions(actions) tc := &protos.TraceContext{ TraceParent: expectedTraceParent, TraceState: wrapperspb.String(expectedTraceState), } - continuedAsNew, err := state.ApplyActions(actions, tc) - if assert.NoError(t, err) && assert.False(t, continuedAsNew) { - if assert.Len(t, state.NewEvents(), 1) { - e := state.NewEvents()[0] - if orchCreated := e.GetSubOrchestrationInstanceCreated(); assert.NotNil(t, orchCreated) { - assert.Equal(t, expectedTaskID, e.EventId) - assert.Equal(t, expectedInstanceID, orchCreated.InstanceId) - assert.Equal(t, expectedName, orchCreated.Name) - assert.Equal(t, expectedInput.GetValue(), orchCreated.Input.GetValue()) - if assert.NotNil(t, orchCreated.ParentTraceContext) { - assert.Equal(t, expectedTraceParent, orchCreated.ParentTraceContext.TraceParent) - assert.Equal(t, expectedTraceState, orchCreated.ParentTraceContext.TraceState.GetValue()) - } + changes, err := state.ProcessChanges(defaultChunkingConfig, tc, logger) + require.NoError(t, err) + + if assert.Len(t, changes.NewEvents, 1) { + e := changes.NewEvents[0] + if orchCreated := e.GetSubOrchestrationInstanceCreated(); assert.NotNil(t, orchCreated) { + assert.Equal(t, expectedTaskID, e.EventId) + assert.Equal(t, expectedInstanceID, orchCreated.InstanceId) + assert.Equal(t, expectedName, orchCreated.Name) + assert.Equal(t, expectedInput.GetValue(), orchCreated.Input.GetValue()) + if assert.NotNil(t, orchCreated.ParentTraceContext) { + assert.Equal(t, expectedTraceParent, orchCreated.ParentTraceContext.TraceParent) + assert.Equal(t, expectedTraceState, orchCreated.ParentTraceContext.TraceState.GetValue()) } } - if assert.Len(t, state.PendingMessages(), 1) { - msg := state.PendingMessages()[0] - if executionStarted := msg.HistoryEvent.GetExecutionStarted(); assert.NotNil(t, executionStarted) { - assert.Equal(t, int32(-1), 
msg.HistoryEvent.EventId) - assert.Equal(t, expectedInstanceID, executionStarted.OrchestrationInstance.InstanceId) - assert.NotEmpty(t, executionStarted.OrchestrationInstance.ExecutionId) - assert.Equal(t, expectedName, executionStarted.Name) - assert.Equal(t, expectedInput.GetValue(), executionStarted.Input.GetValue()) - if assert.NotNil(t, executionStarted.ParentInstance) { - assert.Equal(t, "Parent", executionStarted.ParentInstance.Name.GetValue()) - assert.Equal(t, expectedTaskID, executionStarted.ParentInstance.TaskScheduledId) - if assert.NotNil(t, executionStarted.ParentInstance.OrchestrationInstance) { - assert.Equal(t, iid, executionStarted.ParentInstance.OrchestrationInstance.InstanceId) - } - } - if assert.NotNil(t, executionStarted.ParentTraceContext) { - assert.Equal(t, expectedTraceParent, executionStarted.ParentTraceContext.TraceParent) - assert.Equal(t, expectedTraceState, executionStarted.ParentTraceContext.TraceState.GetValue()) + } + if assert.Len(t, changes.NewMessages, 1) { + msg := changes.NewMessages[0] + if executionStarted := msg.HistoryEvent.GetExecutionStarted(); assert.NotNil(t, executionStarted) { + assert.Equal(t, int32(-1), msg.HistoryEvent.EventId) + assert.Equal(t, expectedInstanceID, executionStarted.OrchestrationInstance.InstanceId) + assert.NotEmpty(t, executionStarted.OrchestrationInstance.ExecutionId) + assert.Equal(t, expectedName, executionStarted.Name) + assert.Equal(t, expectedInput.GetValue(), executionStarted.Input.GetValue()) + if assert.NotNil(t, executionStarted.ParentInstance) { + assert.Equal(t, "Parent", executionStarted.ParentInstance.Name.GetValue()) + assert.Equal(t, expectedTaskID, executionStarted.ParentInstance.TaskScheduledId) + if assert.NotNil(t, executionStarted.ParentInstance.OrchestrationInstance) { + assert.Equal(t, iid, executionStarted.ParentInstance.OrchestrationInstance.InstanceId) } } + if assert.NotNil(t, executionStarted.ParentTraceContext) { + assert.Equal(t, expectedTraceParent, executionStarted.ParentTraceContext.TraceParent) + assert.Equal(t, expectedTraceState, executionStarted.ParentTraceContext.TraceState.GetValue()) + } } } } @@ -355,24 +342,25 @@ func Test_SendEvent(t *testing.T) { actions := []*protos.OrchestratorAction{ helpers.NewSendEventAction(expectedInstanceID, expectedEventName, wrapperspb.String(expectedInput)), } + s.AddActions(actions) - continuedAsNew, err := s.ApplyActions(actions, nil) - if assert.NoError(t, err) && assert.False(t, continuedAsNew) { - if assert.Len(t, s.NewEvents(), 1) { - e := s.NewEvents()[0] - if sendEvent := e.GetEventSent(); assert.NotNil(t, sendEvent) { - assert.Equal(t, expectedEventName, sendEvent.Name) - assert.Equal(t, expectedInput, sendEvent.Input.GetValue()) - assert.Equal(t, expectedInstanceID, sendEvent.InstanceId) - } + changes, err := s.ProcessChanges(defaultChunkingConfig, nil, logger) + require.NoError(t, err) + + if assert.Len(t, changes.NewEvents, 1) { + e := changes.NewEvents[0] + if sendEvent := e.GetEventSent(); assert.NotNil(t, sendEvent) { + assert.Equal(t, expectedEventName, sendEvent.Name) + assert.Equal(t, expectedInput, sendEvent.Input.GetValue()) + assert.Equal(t, expectedInstanceID, sendEvent.InstanceId) } - if assert.Len(t, s.PendingMessages(), 1) { - msg := s.PendingMessages()[0] - if sendEvent := msg.HistoryEvent.GetEventSent(); assert.NotNil(t, sendEvent) { - assert.Equal(t, expectedEventName, sendEvent.Name) - assert.Equal(t, expectedInput, sendEvent.Input.GetValue()) - assert.Equal(t, expectedInstanceID, sendEvent.InstanceId) - } + } + if 
assert.Len(t, changes.NewMessages, 1) { + msg := changes.NewMessages[0] + if sendEvent := msg.HistoryEvent.GetEventSent(); assert.NotNil(t, sendEvent) { + assert.Equal(t, expectedEventName, sendEvent.Name) + assert.Equal(t, expectedInput, sendEvent.Input.GetValue()) + assert.Equal(t, expectedInstanceID, sendEvent.InstanceId) } } } @@ -390,21 +378,172 @@ func Test_StateIsValid(t *testing.T) { assert.False(t, s.IsValid()) } -func Test_DuplicateEvents(t *testing.T) { - s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) - if err := s.AddEvent(helpers.NewExecutionStartedEvent("MyOrchestration", "abc", nil, nil, nil)); assert.NoError(t, err) { +func Test_DuplicateIncomingEvents(t *testing.T) { + t.Run("ExecutionStarted", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewExecutionStartedEvent("MyOrchestration", "abc", nil, nil, nil)) + require.NoError(t, err) err = s.AddEvent(helpers.NewExecutionStartedEvent("MyOrchestration", "abc", nil, nil, nil)) + require.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) + + t.Run("ExecutionCompleted", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewExecutionCompletedEvent(-1, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, nil, nil)) + require.NoError(t, err) + err = s.AddEvent(helpers.NewExecutionCompletedEvent(-1, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, nil, nil)) assert.ErrorIs(t, err, backend.ErrDuplicateEvent) - } else { - return - } + }) + + t.Run("TaskScheduled", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewTaskScheduledEvent(1, "MyTask", nil, nil, nil)) + require.NoError(t, err) + err = s.AddEvent(helpers.NewTaskScheduledEvent(1, "MyTask", nil, nil, nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) - // TODO: Add other types of duplicate events (task completion, external events, sub-orchestration, etc.) 
+ t.Run("TaskCompleted", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewTaskCompletedEvent(1, nil)) + require.NoError(t, err) + err = s.AddEvent(helpers.NewTaskCompletedEvent(1, nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + err = s.AddEvent(helpers.NewTaskFailedEvent(1, nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) - if err := s.AddEvent(helpers.NewExecutionCompletedEvent(-1, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, nil, nil)); assert.NoError(t, err) { - err = s.AddEvent(helpers.NewExecutionCompletedEvent(-1, protos.OrchestrationStatus_ORCHESTRATION_STATUS_COMPLETED, nil, nil)) + t.Run("TaskFailed", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewTaskFailedEvent(1, nil)) + require.NoError(t, err) + err = s.AddEvent(helpers.NewTaskFailedEvent(1, nil)) assert.ErrorIs(t, err, backend.ErrDuplicateEvent) - } else { - return - } + err = s.AddEvent(helpers.NewTaskCompletedEvent(1, nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) + + t.Run("TimerCreated", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewTimerCreatedEvent(1, timestamppb.Now())) + require.NoError(t, err) + err = s.AddEvent(helpers.NewTimerCreatedEvent(1, timestamppb.Now())) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) + + t.Run("TimerFired", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewTimerFiredEvent(1, timestamppb.Now())) + require.NoError(t, err) + err = s.AddEvent(helpers.NewTimerFiredEvent(1, timestamppb.Now())) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) + + t.Run("SubOrchestrationInstanceCreated", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewSubOrchestrationCreatedEvent(1, "MyOrchestration", nil, nil, "xyz", nil)) + require.NoError(t, err) + err = s.AddEvent(helpers.NewSubOrchestrationCreatedEvent(1, "MyOrchestration", nil, nil, "xyz", nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) + + t.Run("SubOrchestrationInstanceCompleted", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewSubOrchestrationCompletedEvent(1, nil)) + require.NoError(t, err) + err = s.AddEvent(helpers.NewSubOrchestrationCompletedEvent(1, nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + err = s.AddEvent(helpers.NewSubOrchestrationFailedEvent(1, nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) + + t.Run("SubOrchestrationInstanceFailed", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + err := s.AddEvent(helpers.NewSubOrchestrationFailedEvent(1, nil)) + require.NoError(t, err) + err = s.AddEvent(helpers.NewSubOrchestrationFailedEvent(1, nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + err = s.AddEvent(helpers.NewSubOrchestrationCompletedEvent(1, nil)) + assert.ErrorIs(t, err, backend.ErrDuplicateEvent) + }) +} + +// Test_DuplicateOutgoingEvents verifies that duplicate outgoing events are ignored. +// This can happen if the orchestrator code schedules certain actions successfully as part of +// a chunk, but then fails before the final chunk is committed. 
When the orchestrator is replayed, +// it will attempt to schedule the same actions again, but the runtime state will already contain +// the outbound events in its history, so it can identify and de-dupe them. +func Test_DuplicateOutgoingEvents(t *testing.T) { + t.Run("TimerCreated", func(t *testing.T) { + now := timestamppb.Now() + timerID := int32(1) + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{helpers.NewTimerCreatedEvent(timerID, now)}) + s.AddActions([]*protos.OrchestratorAction{helpers.NewCreateTimerAction(timerID, now.AsTime())}) + changes, err := s.ProcessChanges(defaultChunkingConfig, nil, logger) + require.NoError(t, err) + assert.Empty(t, changes.NewEvents) + assert.Empty(t, changes.NewTimers) + }) + + t.Run("TaskScheduled", func(t *testing.T) { + taskID := int32(1) + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{helpers.NewTaskScheduledEvent(taskID, "MyTask", nil, nil, nil)}) + s.AddActions([]*protos.OrchestratorAction{helpers.NewScheduleTaskAction(taskID, "MyTask", nil)}) + changes, err := s.ProcessChanges(defaultChunkingConfig, nil, logger) + require.NoError(t, err) + assert.Empty(t, changes.NewEvents) + assert.Empty(t, changes.NewTasks) + }) + + t.Run("SubOrchestrationInstanceCreated", func(t *testing.T) { + taskID := int32(1) + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{helpers.NewSubOrchestrationCreatedEvent(taskID, "MyOrchestration", nil, nil, "xyz", nil)}) + s.AddActions([]*protos.OrchestratorAction{helpers.NewCreateSubOrchestrationAction(taskID, "MyOrchestration", "xyz", nil)}) + changes, err := s.ProcessChanges(defaultChunkingConfig, nil, logger) + require.NoError(t, err) + assert.Empty(t, changes.NewEvents) + assert.Empty(t, changes.NewMessages) + }) +} + +func Test_SetFailed(t *testing.T) { + errFailure := errors.New("you got terminated") + t.Run("Pending", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{}) + assert.NotEqual(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, s.RuntimeStatus()) + s.SetFailed(errFailure) + s.ProcessChanges(defaultChunkingConfig, nil, logger) + require.True(t, s.IsValid()) + assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, s.RuntimeStatus()) + failureDetails, err := s.FailureDetails() + require.NoError(t, err) + assert.Equal(t, errFailure.Error(), failureDetails.ErrorMessage) + }) + + t.Run("Running", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{ + helpers.NewExecutionStartedEvent("MyOrchestration", "abc", nil, nil, nil), + }) + assert.NotEqual(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, s.RuntimeStatus()) + s.SetFailed(errFailure) + s.ProcessChanges(defaultChunkingConfig, nil, logger) + assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, s.RuntimeStatus()) + failureDetails, err := s.FailureDetails() + require.NoError(t, err) + assert.Equal(t, errFailure.Error(), failureDetails.ErrorMessage) + }) + + t.Run("ContinuedAsNew", func(t *testing.T) { + s := backend.NewOrchestrationRuntimeState("abc", []*protos.HistoryEvent{ + helpers.NewExecutionStartedEvent("MyOrchestration", "abc", nil, nil, nil), + }) + s.AddEvent(helpers.NewExecutionCompletedEvent(-1, protos.OrchestrationStatus_ORCHESTRATION_STATUS_CONTINUED_AS_NEW, nil, nil)) + assert.NotEqual(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, s.RuntimeStatus()) + s.SetFailed(errFailure) + s.ProcessChanges(defaultChunkingConfig, nil, 
logger) + assert.Equal(t, protos.OrchestrationStatus_ORCHESTRATION_STATUS_FAILED, s.RuntimeStatus()) + failureDetails, err := s.FailureDetails() + require.NoError(t, err) + assert.Equal(t, errFailure.Error(), failureDetails.ErrorMessage) + }) } diff --git a/tests/tracing_test.go b/tests/tracing_test.go index 9209bde..5f0bf53 100644 --- a/tests/tracing_test.go +++ b/tests/tracing_test.go @@ -23,6 +23,7 @@ type ( var ( initTracingOnce sync.Once sharedTraceExporter = tracetest.NewInMemoryExporter() + sharedSpanProcessor = trace.NewSimpleSpanProcessor(sharedTraceExporter) ) func assertSpanSequence(t assert.TestingT, spans []trace.ReadOnlySpan, spanAsserts ...spanValidator) { @@ -221,8 +222,20 @@ func initTracing() *tracetest.InMemoryExporter { // The global tracer provider can only be initialized once. // Subsequent initializations will silently fail. initTracingOnce.Do(func() { - processor := trace.NewSimpleSpanProcessor(sharedTraceExporter) - provider := trace.NewTracerProvider(trace.WithSpanProcessor(processor)) + // Inspired by this sample: https://github.com/open-telemetry/opentelemetry-go/blob/main/example/zipkin/main.go + // zexp, _ := zipkin.New("http://localhost:9411/api/v2/spans") + + // NOTE: The simple span processor is not recommended for production. + // Instead, the batch span processor should be used for production. + // zprocessor := trace.NewSimpleSpanProcessor(zexp) + // processor := trace.NewSimpleSpanProcessor(sharedTraceExporter) + provider := trace.NewTracerProvider( + // trace.WithSpanProcessor(processor), + trace.WithSpanProcessor(sharedSpanProcessor), + // trace.WithSpanProcessor(zprocessor), + trace.WithSampler(trace.AlwaysSample()), + ) + otel.SetTracerProvider(provider) }) diff --git a/tests/worker_test.go b/tests/worker_test.go index fd6dfe5..a698827 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -2,7 +2,9 @@ package tests import ( "context" + "strings" "testing" + "time" "github.com/microsoft/durabletask-go/api" "github.com/microsoft/durabletask-go/backend" @@ -11,6 +13,7 @@ import ( "github.com/microsoft/durabletask-go/tests/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -29,12 +32,11 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { result := &backend.ExecutionResults{Response: &protos.OrchestratorResponse{}} be := mocks.NewBackend(t) + ex := mocks.NewExecutor(t) be.EXPECT().GetOrchestrationWorkItem(anyContext).Return(wi, nil).Once() be.EXPECT().GetOrchestrationRuntimeState(anyContext, wi).Return(state, nil).Once() - be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).Return(nil).Once() - - ex := mocks.NewExecutor(t) ex.EXPECT().ExecuteOrchestrator(anyContext, wi.InstanceID, state.OldEvents(), mock.Anything).Return(result, nil).Once() + be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi, mock.Anything).Return(nil).Once() worker := backend.NewOrchestrationWorker(be, ex, logger) ok, err := worker.ProcessNext(ctx) @@ -95,8 +97,8 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * // but there doesn't seem to be a good way to assert this. 
ex.EXPECT().ExecuteOrchestrator(anyContext, iid, []*protos.HistoryEvent{}, mock.Anything).Return(result, nil).Once() - // After execution, the Complete action should be called - be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).Return(nil).Once() + // The work item should be completed with the result of the execution + be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi, mock.Anything).Return(nil).Once() // Set up and run the test worker := backend.NewOrchestrationWorker(be, ex, logger) @@ -107,3 +109,158 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * assert.Nil(t, err) assert.True(t, ok) } + +func Test_ChunkRuntimeStateChanges(t *testing.T) { + iid := api.InstanceID("test123") + + run := func( + wi *backend.OrchestrationWorkItem, + config backend.ChunkingConfiguration, + actions []*protos.OrchestratorAction, + expectedChunks int, + continuesAsNew bool, + ) { + state := backend.NewOrchestrationRuntimeState(wi.InstanceID, nil) + + be := mocks.NewBackend(t) + be.EXPECT().GetOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().GetOrchestrationRuntimeState(anyContext, wi).Return(state, nil).Once() + be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi, mock.Anything).Return(nil).Times(expectedChunks) + + result := &backend.ExecutionResults{ + Response: &protos.OrchestratorResponse{Actions: actions}, + } + + ex := mocks.NewExecutor(t) + ex.EXPECT().ExecuteOrchestrator(anyContext, wi.InstanceID, mock.Anything, mock.Anything).Return(result, nil).Once() + if continuesAsNew { + // There will be one more execution cycle after the first one and we need to return an empty action set + // to prevent execution from going forever. + nextResult := &backend.ExecutionResults{Response: &protos.OrchestratorResponse{}} + ex.EXPECT().ExecuteOrchestrator(anyContext, wi.InstanceID, mock.Anything, mock.Anything).Return(nextResult, nil).Once() + } + + // Set up and run the test + worker := backend.NewOrchestrationWorker(be, ex, logger, backend.WithChunkingConfiguration(config)) + ok, err := worker.ProcessNext(ctx) + require.NoError(t, err) + require.True(t, ok) + worker.StopAndDrain() + } + + t.Run("OnlyNewEvents", func(t *testing.T) { + t.Run("EventCount", func(t *testing.T) { + config := backend.ChunkingConfiguration{MaxHistoryEventCount: 10} + wi := &backend.OrchestrationWorkItem{ + InstanceID: iid, + NewEvents: []*protos.HistoryEvent{helpers.NewExecutionStartedEvent("MyOrch", string(iid), nil, nil, nil)}, + } + + // Append a bunch of new events to the work item + // 2 events (orchestratorStarted and executionStarted) + 48 external events = 50 total + for i := 0; i < 48; i++ { + wi.NewEvents = append(wi.NewEvents, helpers.NewEventRaisedEvent("MyEvent", nil)) + } + + // No actions returned by the orchestrator for this test + actions := make([]*protos.OrchestratorAction, 0) + + // We expect that completion will be called 5 times, once for each chunk (50 / 10 = 5) + expectedCompletionCalls := 5 + run(wi, config, actions, expectedCompletionCalls, false) + }) + + t.Run("EventSize", func(t *testing.T) { + config := backend.ChunkingConfiguration{MaxHistoryEventSizeInKB: 1} + wi := &backend.OrchestrationWorkItem{ + InstanceID: iid, + NewEvents: []*protos.HistoryEvent{helpers.NewExecutionStartedEvent("MyOrch", string(iid), nil, nil, nil)}, + } + + // Append a bunch of new events to the work item that are ~900 bytes in size + for i := 0; i < 4; i++ { + payload1KB := strings.Repeat("a", 900) // 900 UTF-8 bytes + wi.NewEvents = append(wi.NewEvents, 
helpers.NewEventRaisedEvent("MyEvent", wrapperspb.String(payload1KB))) + } + + // No actions returned by the orchestrator for this test + actions := make([]*protos.OrchestratorAction, 0) + + // We expect that completion will be called 5 times, once for each chunk. + // Chunk 1: OrchestratorStarted (27), ExecutionStarted (86) + // Chunk 2: EventRaised (944) + // Chunk 3: EventRaised (944) + // Chunk 4: EventRaised (944) + // Chunk 5: EventRaised (944) + expectedCompletionCalls := 5 + run(wi, config, actions, expectedCompletionCalls, false) + }) + }) + + t.Run("OnlyActionEvents", func(t *testing.T) { + config := backend.ChunkingConfiguration{MaxHistoryEventCount: 10} + wi := &backend.OrchestrationWorkItem{ + InstanceID: iid, + NewEvents: []*protos.HistoryEvent{helpers.NewExecutionStartedEvent("MyOrch", string(iid), nil, nil, nil)}, + } + + // The orchestrator will return 48 actions, which will be chunked into 5 chunks + // 2 new events (orchestratorStarted and executionStarted) + 48 create timer actions = 50 total + actions := make([]*protos.OrchestratorAction, 48) + for i := 0; i < 48; i++ { + actions[i] = helpers.NewCreateTimerAction(int32(i), time.Now()) + } + + // We expect that completion will be called 5 times, once for each chunk (50 / 10 = 5) + expectedCompletionCalls := 5 + run(wi, config, actions, expectedCompletionCalls, false) + }) + + t.Run("OnlyCarryoverEvents", func(t *testing.T) { + config := backend.ChunkingConfiguration{MaxHistoryEventCount: 10} + wi := &backend.OrchestrationWorkItem{ + InstanceID: iid, + NewEvents: []*protos.HistoryEvent{helpers.NewExecutionStartedEvent("MyOrch", string(iid), nil, nil, nil)}, + } + + // The orchestrator will return a single ContinueAsNew action with 48 carryover events, which will be chunked into 5 chunks + // 2 new events (orchestratorStarted and executionStarted) + 48 external event = 50 total + carroverEvents := make([]*protos.HistoryEvent, 48) + for i := 0; i < 48; i++ { + carroverEvents[i] = helpers.NewEventRaisedEvent("MyEvent", nil) + } + continueAsNewAction := helpers.NewCompleteOrchestrationAction( + -1, + protos.OrchestrationStatus_ORCHESTRATION_STATUS_CONTINUED_AS_NEW, + nil, + carroverEvents, + nil) + actions := []*protos.OrchestratorAction{continueAsNewAction} + + // We expect that completion will be called 5 times, once for each chunk (50 / 10 = 5) + expectedCompletionCalls := 5 + run(wi, config, actions, expectedCompletionCalls, true) + }) + + t.Run("NewEventsAndActionEvents", func(t *testing.T) { + config := backend.ChunkingConfiguration{MaxHistoryEventCount: 10} + wi := &backend.OrchestrationWorkItem{ + InstanceID: iid, + NewEvents: []*protos.HistoryEvent{helpers.NewExecutionStartedEvent("MyOrch", string(iid), nil, nil, nil)}, + } + + // Append a bunch of new events to the work item + // 2 events (orchestratorStarted and executionStarted) + 24 external events + 24 task scheduled actions = 50 total + for i := 0; i < 24; i++ { + wi.NewEvents = append(wi.NewEvents, helpers.NewEventRaisedEvent("MyEvent", nil)) + } + actions := make([]*protos.OrchestratorAction, 24) + for i := 0; i < 24; i++ { + actions[i] = helpers.NewScheduleTaskAction(int32(i), "MyAct", nil) + } + + // We expect that completion will be called 5 times, once for each chunk (50 / 10 = 5) + expectedCompletionCalls := 5 + run(wi, config, actions, expectedCompletionCalls, false) + }) +}