From b99b8ce6eb9d4e7078f645de4c5061dcff0710ab Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Tue, 26 Nov 2024 09:38:32 -0800 Subject: [PATCH] address comments and add unit tests --- internal/internal_worker_base.go | 20 +-- .../{dynamic_params.go => concurrency.go} | 7 +- internal/worker/concurrency_test.go | 139 ++++++++++++++++++ 3 files changed, 153 insertions(+), 13 deletions(-) rename internal/worker/{dynamic_params.go => concurrency.go} (92%) create mode 100644 internal/worker/concurrency_test.go diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index 66f053112..73ca9461b 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -142,7 +142,7 @@ type ( logger *zap.Logger metricsScope tally.Scope - dynamic *worker.DynamicParams + concurrency *worker.Concurrency pollerAutoScaler *pollerAutoScaler taskQueueCh chan interface{} sessionTokenBucket *sessionTokenBucket @@ -168,18 +168,18 @@ func createPollRetryPolicy() backoff.RetryPolicy { func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope tally.Scope, sessionTokenBucket *sessionTokenBucket) *baseWorker { ctx, cancel := context.WithCancel(context.Background()) - dynamic := &worker.DynamicParams{ + concurrency := &worker.Concurrency{ PollerPermit: worker.NewPermit(options.pollerCount), TaskPermit: worker.NewPermit(options.maxConcurrentTask), } var pollerAS *pollerAutoScaler if pollerOptions := options.pollerAutoScaler; pollerOptions.Enabled { - dynamic.PollerPermit = worker.NewPermit(pollerOptions.InitCount) + concurrency.PollerPermit = worker.NewPermit(pollerOptions.InitCount) pollerAS = newPollerScaler( pollerOptions, logger, - dynamic.PollerPermit, + concurrency.PollerPermit, ) } @@ -190,7 +190,7 @@ func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope t retrier: backoff.NewConcurrentRetrier(pollOperationRetryPolicy), logger: logger.With(zapcore.Field{Key: tagWorkerType, Type: zapcore.StringType, 
String: options.workerType}), metricsScope: tagScope(metricsScope, tagWorkerType, options.workerType), - dynamic: dynamic, + concurrency: concurrency, pollerAutoScaler: pollerAS, taskQueueCh: make(chan interface{}), // no buffer, so poller only able to poll new task after previous is dispatched. limiterContext: ctx, @@ -252,13 +252,13 @@ func (bw *baseWorker) runPoller() { select { case <-bw.shutdownCh: return - case <-bw.dynamic.TaskPermit.AcquireChan(bw.limiterContext, &bw.shutdownWG): // don't poll unless there is a task permit + case <-bw.concurrency.TaskPermit.AcquireChan(bw.limiterContext, &bw.shutdownWG): // don't poll unless there is a task permit // TODO move to a centralized place inside the worker // emit metrics on concurrent task permit quota and current task permit count // NOTE task permit doesn't mean there is a task running, it still needs to poll until it gets a task to process // thus the metrics is only an estimated value of how many tasks are running concurrently - bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.dynamic.TaskPermit.Quota())) - bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.dynamic.TaskPermit.Count())) + bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.concurrency.TaskPermit.Quota())) + bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.concurrency.TaskPermit.Count())) if bw.sessionTokenBucket != nil { bw.sessionTokenBucket.waitForAvailableToken() } @@ -339,7 +339,7 @@ func (bw *baseWorker) pollTask() { case <-bw.shutdownCh: } } else { - bw.dynamic.TaskPermit.Release(1) // poll failed, trigger a new poll by returning a task permit + bw.concurrency.TaskPermit.Release(1) // poll failed, trigger a new poll by returning a task permit } } @@ -374,7 +374,7 @@ func (bw *baseWorker) processTask(task interface{}) { } if isPolledTask { - bw.dynamic.TaskPermit.Release(1) // task processed, trigger a new poll by returning a task permit + 
bw.concurrency.TaskPermit.Release(1) // task processed, trigger a new poll by returning a task permit } }() err := bw.options.taskWorker.ProcessTask(task) diff --git a/internal/worker/dynamic_params.go b/internal/worker/concurrency.go similarity index 92% rename from internal/worker/dynamic_params.go rename to internal/worker/concurrency.go index 34babb3b2..45d0dadc8 100644 --- a/internal/worker/dynamic_params.go +++ b/internal/worker/concurrency.go @@ -30,8 +30,8 @@ import ( var _ Permit = (*permit)(nil) -// Synchronization contains synchronization primitives for dynamic configuration. -type DynamicParams struct { +// Concurrency contains synchronization primitives for dynamically controlling the concurrencies in workers +type Concurrency struct { PollerPermit Permit // controls concurrency of pollers TaskPermit Permit // controlls concurrency of task processings } @@ -69,13 +69,14 @@ func (p *permit) AcquireChan(ctx context.Context, wg *sync.WaitGroup) <-chan str wg.Add(1) go func() { defer wg.Done() + defer close(ch) // close channel when permit is acquired or expired if err := p.sem.Acquire(ctx, 1); err != nil { - close(ch) return } select { // try to send to channel, but don't block if listener is gone case ch <- struct{}{}: default: + p.sem.Release(1) } }() return ch diff --git a/internal/worker/concurrency_test.go b/internal/worker/concurrency_test.go new file mode 100644 index 000000000..87049957b --- /dev/null +++ b/internal/worker/concurrency_test.go @@ -0,0 +1,139 @@ +// Copyright (c) 2017-2021 Uber Technologies Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+
+package worker
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"math/rand"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/atomic"
+)
+
+func TestPermit_Simulation(t *testing.T) {
+	tests := []struct {
+		name                  string
+		capacity              []int // quota updates applied every 50ms
+		goroutines            int   // each acquires 2-3 tokens and holds them for up to 100ms
+		goroutinesAcquireChan int   // each acquires 1 token via AcquireChan and holds it for up to 100ms
+		maxTestDuration       time.Duration
+		expectFailures        int
+		expectFailuresAtLeast int
+	}{
+		{
+			name:                  "enough permit, no blocking",
+			maxTestDuration:       200 * time.Millisecond,
+			capacity:              []int{1000},
+			goroutines:            100,
+			goroutinesAcquireChan: 100,
+			expectFailures:        0,
+		},
+		{
+			name:                  "not enough permit, blocking but all acquire",
+			maxTestDuration:       1 * time.Second,
+			capacity:              []int{200},
+			goroutines:            500,
+			goroutinesAcquireChan: 500,
+			expectFailures:        0,
+		},
+		{
+			name:                  "not enough permit for some to acquire, fail some",
+			maxTestDuration:       100 * time.Millisecond,
+			capacity:              []int{100},
+			goroutines:            500,
+			goroutinesAcquireChan: 500,
+			expectFailuresAtLeast: 1,
+		},
+		{
+			name:                  "not enough permit at beginning but due to capacity change, blocking but all acquire",
+			maxTestDuration:       100 * time.Second,
+			capacity:              []int{100, 200, 300},
+			goroutines:            500,
+			goroutinesAcquireChan: 500,
+			expectFailures:        0,
+		},
+		{
+			name:            "not enough permit for any acquire, fail all",
+			maxTestDuration: 1 * time.Second,
+			capacity:        []int{0},
+			goroutines:      1000,
+			expectFailures:  1000,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			wg := &sync.WaitGroup{}
+			permit := NewPermit(tt.capacity[0])
+			wg.Add(1)
+			go func() { // update quota every 50ms
+				defer wg.Done()
+				for i := 1; i < len(tt.capacity); i++ {
+					time.Sleep(50 * time.Millisecond)
+					permit.SetQuota(tt.capacity[i])
+				}
+			}()
+			failures := atomic.NewInt32(0)
+			ctx, cancel := context.WithTimeout(context.Background(), tt.maxTestDuration)
+			defer cancel()
+			for i := 0; i < tt.goroutines; i++ {
+
wg.Add(1)
+				go func() {
+					defer wg.Done()
+					num := rand.Intn(2) + 2
+					// each goroutine holds its 2-3 permits for up to 100ms before releasing
+					if err := permit.Acquire(ctx, num); err != nil {
+						failures.Add(1)
+						return
+					}
+					time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
+					permit.Release(num)
+				}()
+			}
+			for i := 0; i < tt.goroutinesAcquireChan; i++ {
+				wg.Add(1)
+				go func() {
+					defer wg.Done()
+					select {
+					case <-permit.AcquireChan(ctx, wg):
+						time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
+						permit.Release(1)
+					case <-ctx.Done():
+						failures.Add(1)
+					}
+				}()
+			}
+
+			wg.Wait()
+			assert.Equal(t, 0, permit.Count())
+			if tt.expectFailuresAtLeast > 0 {
+				assert.LessOrEqual(t, tt.expectFailuresAtLeast, int(failures.Load()))
+			} else {
+				assert.Equal(t, tt.expectFailures, int(failures.Load()))
+			}
+			assert.Equal(t, tt.capacity[len(tt.capacity)-1], permit.Quota())
+		})
+	}
+}