diff --git a/internal/internal_task_handlers.go b/internal/internal_task_handlers.go index a1b206f26..9559a3295 100644 --- a/internal/internal_task_handlers.go +++ b/internal/internal_task_handlers.go @@ -75,6 +75,7 @@ type ( // workflowTask wraps a decision task. workflowTask struct { task *s.PollForDecisionTaskResponse + autoConfigHint *s.AutoConfigHint historyIterator HistoryIterator doneCh chan struct{} laResultCh chan *localActivityResult @@ -82,8 +83,9 @@ type ( // activityTask wraps a activity task. activityTask struct { - task *s.PollForActivityTaskResponse - pollStartTime time.Time + task *s.PollForActivityTaskResponse + autoConfigHint *s.AutoConfigHint + pollStartTime time.Time } // resetStickinessTask wraps a ResetStickyTaskListRequest. diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go index a69b8765d..6e1c4790d 100644 --- a/internal/internal_task_pollers.go +++ b/internal/internal_task_pollers.go @@ -848,7 +848,9 @@ func (wtp *workflowTaskPoller) poll(ctx context.Context) (interface{}, error) { if response == nil || len(response.TaskToken) == 0 { wtp.metricsScope.Counter(metrics.DecisionPollNoTaskCounter).Inc(1) wtp.updateBacklog(request.TaskList.GetKind(), 0) - return &workflowTask{}, nil + return &workflowTask{ + autoConfigHint: response.GetAutoConfigHint(), + }, nil } wtp.updateBacklog(request.TaskList.GetKind(), response.GetBacklogCountHint()) @@ -908,6 +910,7 @@ func (wtp *workflowTaskPoller) toWorkflowTask(response *s.PollForDecisionTaskRes task := &workflowTask{ task: response, historyIterator: historyIterator, + autoConfigHint: response.GetAutoConfigHint(), } return task } @@ -1095,7 +1098,7 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (*s.PollForActivityTask } if response == nil || len(response.TaskToken) == 0 { atp.metricsScope.Counter(metrics.ActivityPollNoTaskCounter).Inc(1) - return nil, startTime, nil + return response, startTime, nil } return response, startTime, err @@ -1116,7 +1119,9 @@ func (atp *activityTaskPoller) pollWithMetrics(ctx context.Context, return nil, err } if response == nil || len(response.TaskToken) == 0 { - return &activityTask{}, nil + return &activityTask{ + autoConfigHint: response.GetAutoConfigHint(), + }, nil } workflowType := response.WorkflowType.GetName() @@ -1128,7 +1133,7 @@ func (atp *activityTaskPoller) pollWithMetrics(ctx context.Context, scheduledToStartLatency := time.Duration(response.GetStartedTimestamp() - response.GetScheduledTimestampOfThisAttempt()) metricsScope.Timer(metrics.ActivityScheduledToStartLatency).Record(scheduledToStartLatency) - return &activityTask{task: response, pollStartTime: startTime}, nil + return &activityTask{task: response, pollStartTime: startTime, autoConfigHint: response.GetAutoConfigHint()}, nil } // PollTask polls a new task diff --git a/internal/internal_task_pollers_test.go b/internal/internal_task_pollers_test.go index ed0a4e779..ee1ebe324 100644 --- a/internal/internal_task_pollers_test.go +++ b/internal/internal_task_pollers_test.go @@ -61,6 +61,110 @@ func Test_newWorkflowTaskPoller(t *testing.T) { }) } +func TestWorkflowTaskPoller(t *testing.T) { + t.Run("PollTask", func(t *testing.T) { + task := &s.PollForDecisionTaskResponse{ + TaskToken: []byte("some value"), + AutoConfigHint: &s.AutoConfigHint{ + common.PtrOf(true), + common.PtrOf(int64(1000)), + }, + } + emptyTask := &s.PollForDecisionTaskResponse{ + TaskToken: nil, + AutoConfigHint: &s.AutoConfigHint{ + common.PtrOf(true), + common.PtrOf(int64(1000)), + }, + } + for _, tt := range []struct { + name string + response *s.PollForDecisionTaskResponse + expected *workflowTask + }{ + { + "success with task", + task, + &workflowTask{ + task: task, + autoConfigHint: task.AutoConfigHint, + }, + }, + { + "success with empty task", + emptyTask, + &workflowTask{ + task: nil, + autoConfigHint: task.AutoConfigHint, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + poller, client, _, _ := buildWorkflowTaskPoller(t) + client.EXPECT().PollForDecisionTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(tt.response, nil) + result, err := poller.PollTask() + assert.NoError(t, err) + resultTask, ok := result.(*workflowTask) + assert.True(t, ok) + assert.Equal(t, tt.expected.task, resultTask.task) + assert.Equal(t, tt.expected.autoConfigHint, resultTask.autoConfigHint) + }) + } + }) +} + +func TestActivityTaskPoller(t *testing.T) { + t.Run("PollTask", func(t *testing.T) { + task := &s.PollForActivityTaskResponse{ + TaskToken: []byte("some value"), + AutoConfigHint: &s.AutoConfigHint{ + common.PtrOf(true), + common.PtrOf(int64(1000)), + }, + } + emptyTask := &s.PollForActivityTaskResponse{ + TaskToken: nil, + AutoConfigHint: &s.AutoConfigHint{ + common.PtrOf(true), + common.PtrOf(int64(1000)), + }, + } + for _, tt := range []struct { + name string + response *s.PollForActivityTaskResponse + expected *activityTask + }{ + { + "success with task", + task, + &activityTask{ + task: task, + autoConfigHint: task.AutoConfigHint, + }, + }, + { + "success with empty task", + emptyTask, + &activityTask{ + task: nil, + autoConfigHint: task.AutoConfigHint, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + poller, client := buildActivityTaskPoller(t) + client.EXPECT().PollForActivityTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(tt.response, nil) + result, err := poller.PollTask() + assert.NoError(t, err) + resultTask, ok := result.(*activityTask) + assert.True(t, ok) + assert.Equal(t, tt.expected.task, resultTask.task) + assert.Equal(t, tt.expected.autoConfigHint, resultTask.autoConfigHint) + }) + } + }) +} + func TestLocalActivityPanic(t *testing.T) { // regression: panics in local activities should not terminate the process s := WorkflowTestSuite{logger: testlogger.NewZap(t)} @@ -213,3 +317,20 @@ func buildWorkflowTaskPoller(t *testing.T) (*workflowTaskPoller, *workflowservic featureFlags: FeatureFlags{}, }, mockService, taskHandler, lda } + +func buildActivityTaskPoller(t *testing.T) (*activityTaskPoller, *workflowservicetest.MockClient) { + ctrl := gomock.NewController(t) + mockService := workflowservicetest.NewMockClient(ctrl) + return &activityTaskPoller{ + basePoller: basePoller{ + shutdownC: make(<-chan struct{}), + }, + domain: _testDomainName, + taskListName: _testTaskList, + identity: _testIdentity, + service: mockService, + metricsScope: &metrics.TaggedScope{Scope: tally.NewTestScope("test", nil)}, + logger: testlogger.NewZap(t), + featureFlags: FeatureFlags{}, + }, mockService +}