diff --git a/internal/internal_poller_autoscaler.go b/internal/internal_poller_autoscaler.go index 2dc81e7ba..c8dfb299a 100644 --- a/internal/internal_poller_autoscaler.go +++ b/internal/internal_poller_autoscaler.go @@ -174,9 +174,9 @@ func (m *pollerUsageEstimator) CollectUsage(data interface{}) error { func isTaskEmpty(task interface{}) (bool, error) { switch t := task.(type) { case *workflowTask: - return t == nil || t.task == nil, nil + return t == nil || t.task == nil || len(t.task.TaskToken) == 0, nil case *activityTask: - return t == nil || t.task == nil, nil + return t == nil || t.task == nil || len(t.task.TaskToken) == 0, nil case *localActivityTask: return t == nil || t.workflowTask == nil, nil default: diff --git a/internal/internal_poller_autoscaler_test.go b/internal/internal_poller_autoscaler_test.go index 4a441b642..53cfb790d 100644 --- a/internal/internal_poller_autoscaler_test.go +++ b/internal/internal_poller_autoscaler_test.go @@ -278,10 +278,10 @@ type unrelatedPolledTask struct{} func generateRandomPollResults(noTaskPoll, taskPoll, unrelated int) <-chan interface{} { var result []interface{} for i := 0; i < noTaskPoll; i++ { - result = append(result, &activityTask{}) + result = append(result, &activityTask{task: &s.PollForActivityTaskResponse{}}) } for i := 0; i < taskPoll; i++ { - result = append(result, &activityTask{task: &s.PollForActivityTaskResponse{}}) + result = append(result, &activityTask{task: &s.PollForActivityTaskResponse{TaskToken: []byte("some value")}}) } for i := 0; i < unrelated; i++ { result = append(result, &unrelatedPolledTask{}) diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go index a69b8765d..e2e32e49a 100644 --- a/internal/internal_task_pollers.go +++ b/internal/internal_task_pollers.go @@ -848,7 +848,7 @@ func (wtp *workflowTaskPoller) poll(ctx context.Context) (interface{}, error) { if response == nil || len(response.TaskToken) == 0 { wtp.metricsScope.Counter(metrics.DecisionPollNoTaskCounter).Inc(1) wtp.updateBacklog(request.TaskList.GetKind(), 0) - return &workflowTask{}, nil + return &workflowTask{task: response}, nil } wtp.updateBacklog(request.TaskList.GetKind(), response.GetBacklogCountHint()) @@ -1095,7 +1095,7 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (*s.PollForActivityTask } if response == nil || len(response.TaskToken) == 0 { atp.metricsScope.Counter(metrics.ActivityPollNoTaskCounter).Inc(1) - return nil, startTime, nil + return response, startTime, nil } return response, startTime, err @@ -1116,7 +1116,7 @@ func (atp *activityTaskPoller) pollWithMetrics(ctx context.Context, return nil, err } if response == nil || len(response.TaskToken) == 0 { - return &activityTask{}, nil + return &activityTask{task: response}, nil } workflowType := response.WorkflowType.GetName() diff --git a/internal/internal_task_pollers_test.go b/internal/internal_task_pollers_test.go index ed0a4e779..041553e12 100644 --- a/internal/internal_task_pollers_test.go +++ b/internal/internal_task_pollers_test.go @@ -61,6 +61,104 @@ func Test_newWorkflowTaskPoller(t *testing.T) { }) } +func TestWorkflowTaskPoller(t *testing.T) { + t.Run("PollTask", func(t *testing.T) { + task := &s.PollForDecisionTaskResponse{ + TaskToken: []byte("some value"), + AutoConfigHint: &s.AutoConfigHint{ + common.PtrOf(true), + common.PtrOf(int64(1000)), + }, + } + emptyTask := &s.PollForDecisionTaskResponse{ + TaskToken: nil, + AutoConfigHint: &s.AutoConfigHint{ + common.PtrOf(true), + common.PtrOf(int64(1000)), + }, + } + for _, tt := range []struct { + name string + response *s.PollForDecisionTaskResponse + expected *workflowTask + }{ + { + "success with task", + task, + &workflowTask{ + task: task, + }, + }, + { + "success with empty task", + emptyTask, + &workflowTask{ + task: emptyTask, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + poller, client, _, _ := buildWorkflowTaskPoller(t) + client.EXPECT().PollForDecisionTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(tt.response, nil) + result, err := poller.PollTask() + assert.NoError(t, err) + resultTask, ok := result.(*workflowTask) + assert.True(t, ok) + assert.Equal(t, tt.expected.task, resultTask.task) + }) + } + }) +} + +func TestActivityTaskPoller(t *testing.T) { + t.Run("PollTask", func(t *testing.T) { + task := &s.PollForActivityTaskResponse{ + TaskToken: []byte("some value"), + AutoConfigHint: &s.AutoConfigHint{ + common.PtrOf(true), + common.PtrOf(int64(1000)), + }, + } + emptyTask := &s.PollForActivityTaskResponse{ + TaskToken: nil, + AutoConfigHint: &s.AutoConfigHint{ + common.PtrOf(true), + common.PtrOf(int64(1000)), + }, + } + for _, tt := range []struct { + name string + response *s.PollForActivityTaskResponse + expected *activityTask + }{ + { + "success with task", + task, + &activityTask{ + task: task, + }, + }, + { + "success with empty task", + emptyTask, + &activityTask{ + task: emptyTask, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + poller, client := buildActivityTaskPoller(t) + client.EXPECT().PollForActivityTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(tt.response, nil) + result, err := poller.PollTask() + assert.NoError(t, err) + resultTask, ok := result.(*activityTask) + assert.True(t, ok) + assert.Equal(t, tt.expected.task, resultTask.task) + }) + } + }) +} + func TestLocalActivityPanic(t *testing.T) { // regression: panics in local activities should not terminate the process s := WorkflowTestSuite{logger: testlogger.NewZap(t)} @@ -213,3 +311,20 @@ func buildWorkflowTaskPoller(t *testing.T) (*workflowTaskPoller, *workflowservic featureFlags: FeatureFlags{}, }, mockService, taskHandler, lda } + +func buildActivityTaskPoller(t *testing.T) (*activityTaskPoller, *workflowservicetest.MockClient) { + ctrl := gomock.NewController(t) + mockService := workflowservicetest.NewMockClient(ctrl) + return &activityTaskPoller{ + basePoller: basePoller{ + shutdownC: make(<-chan struct{}), + }, + domain: _testDomainName, + taskListName: _testTaskList, + identity: _testIdentity, + service: mockService, + metricsScope: &metrics.TaggedScope{Scope: tally.NewTestScope("test", nil)}, + logger: testlogger.NewZap(t), + featureFlags: FeatureFlags{}, + }, mockService +}