diff --git a/service/history/task/cross_cluster_task_processor.go b/service/history/task/cross_cluster_task_processor.go index d1b2b0d4b11..44f9f2c4498 100644 --- a/service/history/task/cross_cluster_task_processor.go +++ b/service/history/task/cross_cluster_task_processor.go @@ -56,6 +56,8 @@ type ( crossClusterTaskProcessors []*crossClusterTaskProcessor crossClusterTaskProcessor struct { + ctx context.Context + ctxCancel context.CancelFunc shard shard.Context taskProcessor Processor taskExecutor Executor @@ -117,6 +119,7 @@ func newCrossClusterTaskProcessor( taskFetcher Fetcher, options *CrossClusterTaskProcessorOptions, ) *crossClusterTaskProcessor { + ctx, cancel := context.WithCancel(context.Background()) sourceCluster := taskFetcher.GetSourceCluster() logger := shard.GetLogger().WithTags( tag.ComponentCrossClusterTaskProcessor, @@ -130,6 +133,8 @@ func newCrossClusterTaskProcessor( retryPolicy.SetMaximumInterval(time.Second) retryPolicy.SetExpirationInterval(options.TaskWaitInterval()) return &crossClusterTaskProcessor{ + ctx: ctx, + ctxCancel: cancel, shard: shard, taskProcessor: taskProcessor, taskExecutor: NewCrossClusterTargetTaskExecutor( @@ -190,6 +195,7 @@ func (p *crossClusterTaskProcessor) Stop() { } close(p.shutdownCh) + p.ctxCancel() p.redispatcher.Stop() if success := common.AwaitWaitGroup(&p.shutdownWG, time.Minute); !success { @@ -219,9 +225,13 @@ func (p *crossClusterTaskProcessor) processLoop() { sw := p.metricsScope.StartTimer(metrics.CrossClusterFetchLatency) var taskRequests []*types.CrossClusterTaskRequest - err := p.taskFetcher.Fetch(p.shard.GetShardID()).Get(context.Background(), &taskRequests) + err := p.taskFetcher.Fetch(p.shard.GetShardID()).Get(p.ctx, &taskRequests) sw.Stop() if err != nil { + if err == errTaskFetcherShutdown { + return + } + p.logger.Error("Unable to fetch cross cluster tasks", tag.Error(err)) if common.IsServiceBusyError(err) { p.metricsScope.IncCounter(metrics.CrossClusterFetchServiceBusyFailures) @@ -231,6 +241,7 
@@ func (p *crossClusterTaskProcessor) processLoop() { )) } else { p.metricsScope.IncCounter(metrics.CrossClusterFetchFailures) + // note we rely on the aggregation interval in task fetcher as the backoff } continue } @@ -273,7 +284,7 @@ func (p *crossClusterTaskProcessor) processTaskRequests( TargetCluster: p.shard.GetClusterMetadata().GetCurrentClusterName(), FetchNewTasks: p.numPendingTasks() < p.options.MaxPendingTasks(), } - taskWaitContext, cancel := context.WithTimeout(context.Background(), p.options.TaskWaitInterval()) + taskWaitContext, cancel := context.WithTimeout(p.ctx, p.options.TaskWaitInterval()) deadlineExceeded := false for taskID, taskFuture := range taskFutures { if deadlineExceeded && !taskFuture.IsReady() { @@ -282,6 +293,13 @@ func (p *crossClusterTaskProcessor) processTaskRequests( var taskResponse types.CrossClusterTaskResponse if err := taskFuture.Get(taskWaitContext, &taskResponse); err != nil { + if p.ctx.Err() != nil { + // root context is no longer valid, the component is being shut down, + // we can return directly + cancel() + return + } + if err == context.DeadlineExceeded { // switch to a valid context here, otherwise Get() will always return an error. 
// using context.Background() is fine since we will only be calling Get() with it @@ -361,7 +379,13 @@ func (p *crossClusterTaskProcessor) respondPendingTaskLoop() { for taskID, taskFuture := range p.pendingTasks { if taskFuture.IsReady() { var taskResponse types.CrossClusterTaskResponse - if err := taskFuture.Get(context.Background(), &taskResponse); err != nil { + if err := taskFuture.Get(p.ctx, &taskResponse); err != nil { + if p.ctx.Err() != nil { + // we are in shutdown logic + p.taskLock.Unlock() + return + } + // this case should not happen, // task failure should be converted to FailCause in the response by the processing logic taskResponse = types.CrossClusterTaskResponse{ @@ -433,7 +457,7 @@ func (p *crossClusterTaskProcessor) respondTaskCompletedWithRetry( var response *types.RespondCrossClusterTasksCompletedResponse op := func() error { - ctx, cancel := context.WithTimeout(context.Background(), respondCrossClusterTaskTimeout) + ctx, cancel := context.WithTimeout(p.ctx, respondCrossClusterTaskTimeout) defer cancel() var err error response, err = p.sourceAdminClient.RespondCrossClusterTasksCompleted(ctx, request) @@ -443,7 +467,7 @@ func (p *crossClusterTaskProcessor) respondTaskCompletedWithRetry( } return err } - err := p.throttleRetry.Do(context.Background(), op) + err := p.throttleRetry.Do(p.ctx, op) return response, err } diff --git a/service/history/task/fetcher.go b/service/history/task/fetcher.go index d29e199b85d..914da3d052b 100644 --- a/service/history/task/fetcher.go +++ b/service/history/task/fetcher.go @@ -56,6 +56,7 @@ type ( } fetchTaskFunc func( + ctx context.Context, clientBean client.Bean, sourceCluster string, currentCluster string, @@ -76,7 +77,9 @@ type ( shutdownCh chan struct{} requestChan chan fetchRequest - fetchTaskFunc fetchTaskFunc + fetchCtx context.Context + fetchCtxCancel context.CancelFunc + fetchTaskFunc fetchTaskFunc } ) @@ -111,6 +114,7 @@ func NewCrossClusterTaskFetchers( } func crossClusterTaskFetchFn( + ctx 
context.Context, clientBean client.Bean, sourceCluster string, currentCluster string, @@ -128,7 +132,7 @@ func crossClusterTaskFetchFn( ShardIDs: shardIDs, TargetCluster: currentCluster, } - ctx, cancel := context.WithTimeout(context.Background(), defaultFetchTimeout) + ctx, cancel := context.WithTimeout(ctx, defaultFetchTimeout) defer cancel() resp, err := adminClient.GetCrossClusterTasks(ctx, request) if err != nil { @@ -199,6 +203,7 @@ func newTaskFetcher( metricsClient metrics.Client, logger log.Logger, ) *fetcherImpl { + fetchCtx, fetchCtxCancel := context.WithCancel(context.Background()) return &fetcherImpl{ status: common.DaemonStatusInitialized, currentCluster: currentCluster, @@ -213,9 +218,11 @@ func newTaskFetcher( tag.ComponentCrossClusterTaskFetcher, tag.SourceCluster(sourceCluster), ), - shutdownCh: make(chan struct{}), - requestChan: make(chan fetchRequest, defaultRequestChanBufferSize), - fetchTaskFunc: fetchTaskFunc, + shutdownCh: make(chan struct{}), + requestChan: make(chan fetchRequest, defaultRequestChanBufferSize), + fetchCtx: fetchCtx, + fetchCtxCancel: fetchCtxCancel, + fetchTaskFunc: fetchTaskFunc, } } @@ -239,6 +246,7 @@ func (f *fetcherImpl) Stop() { } close(f.shutdownCh) + f.fetchCtxCancel() if success := common.AwaitWaitGroup(&f.shutdownWG, time.Minute); !success { f.logger.Warn("Task fetcher timedout on shutdown.", tag.LifeCycleStopTimedout) } @@ -338,7 +346,13 @@ func (f *fetcherImpl) fetch( sw := f.metricsScope.StartTimer(metrics.CrossClusterFetchLatency) defer sw.Stop() - responseByShard, err := f.fetchTaskFunc(f.clientBean, f.sourceCluster, f.currentCluster, outstandingRequests) + responseByShard, err := f.fetchTaskFunc( + f.fetchCtx, + f.clientBean, + f.sourceCluster, + f.currentCluster, + outstandingRequests, + ) if err != nil { return err } diff --git a/service/history/task/fetcher_test.go b/service/history/task/fetcher_test.go index 2b3c4f4edb9..1258d15d074 100644 --- a/service/history/task/fetcher_test.go +++ 
b/service/history/task/fetcher_test.go @@ -224,6 +224,7 @@ func (s *fetcherSuite) TestAggregator() { } func (s *fetcherSuite) testFetchTaskFn( + ctx context.Context, clientBean client.Bean, sourceCluster string, currentCluster string,