Skip to content

Commit

Permalink
address comments and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shijiesheng committed Nov 26, 2024
1 parent a982c04 commit b99b8ce
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 13 deletions.
20 changes: 10 additions & 10 deletions internal/internal_worker_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ type (
logger *zap.Logger
metricsScope tally.Scope

dynamic *worker.DynamicParams
concurrency *worker.Concurrency
pollerAutoScaler *pollerAutoScaler
taskQueueCh chan interface{}
sessionTokenBucket *sessionTokenBucket
Expand All @@ -168,18 +168,18 @@ func createPollRetryPolicy() backoff.RetryPolicy {
func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope tally.Scope, sessionTokenBucket *sessionTokenBucket) *baseWorker {
ctx, cancel := context.WithCancel(context.Background())

dynamic := &worker.DynamicParams{
concurrency := &worker.Concurrency{
PollerPermit: worker.NewPermit(options.pollerCount),
TaskPermit: worker.NewPermit(options.maxConcurrentTask),
}

var pollerAS *pollerAutoScaler
if pollerOptions := options.pollerAutoScaler; pollerOptions.Enabled {
dynamic.PollerPermit = worker.NewPermit(pollerOptions.InitCount)
concurrency.PollerPermit = worker.NewPermit(pollerOptions.InitCount)
pollerAS = newPollerScaler(
pollerOptions,
logger,
dynamic.PollerPermit,
concurrency.PollerPermit,
)
}

Expand All @@ -190,7 +190,7 @@ func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope t
retrier: backoff.NewConcurrentRetrier(pollOperationRetryPolicy),
logger: logger.With(zapcore.Field{Key: tagWorkerType, Type: zapcore.StringType, String: options.workerType}),
metricsScope: tagScope(metricsScope, tagWorkerType, options.workerType),
dynamic: dynamic,
concurrency: concurrency,
pollerAutoScaler: pollerAS,
taskQueueCh: make(chan interface{}), // no buffer, so poller only able to poll new task after previous is dispatched.
limiterContext: ctx,
Expand Down Expand Up @@ -252,13 +252,13 @@ func (bw *baseWorker) runPoller() {
select {
case <-bw.shutdownCh:
return
case <-bw.dynamic.TaskPermit.AcquireChan(bw.limiterContext, &bw.shutdownWG): // don't poll unless there is a task permit
case <-bw.concurrency.TaskPermit.AcquireChan(bw.limiterContext, &bw.shutdownWG): // don't poll unless there is a task permit
// TODO move to a centralized place inside the worker
// emit metrics on concurrent task permit quota and current task permit count
// NOTE task permit doesn't mean there is a task running, it still needs to poll until it gets a task to process
// thus the metrics is only an estimated value of how many tasks are running concurrently
bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.dynamic.TaskPermit.Quota()))
bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.dynamic.TaskPermit.Count()))
bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.concurrency.TaskPermit.Quota()))
bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.concurrency.TaskPermit.Count()))
if bw.sessionTokenBucket != nil {
bw.sessionTokenBucket.waitForAvailableToken()
}
Expand Down Expand Up @@ -339,7 +339,7 @@ func (bw *baseWorker) pollTask() {
case <-bw.shutdownCh:
}
} else {
bw.dynamic.TaskPermit.Release(1) // poll failed, trigger a new poll by returning a task permit
bw.concurrency.TaskPermit.Release(1) // poll failed, trigger a new poll by returning a task permit
}
}

Expand Down Expand Up @@ -374,7 +374,7 @@ func (bw *baseWorker) processTask(task interface{}) {
}

if isPolledTask {
bw.dynamic.TaskPermit.Release(1) // task processed, trigger a new poll by returning a task permit
bw.concurrency.TaskPermit.Release(1) // task processed, trigger a new poll by returning a task permit
}
}()
err := bw.options.taskWorker.ProcessTask(task)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ import (

var _ Permit = (*permit)(nil)

// Synchronization contains synchronization primitives for dynamic configuration.
type DynamicParams struct {
// Concurrency contains synchronization primitives for dynamically controlling the concurrencies in workers
type Concurrency struct {
PollerPermit Permit // controls concurrency of pollers
TaskPermit Permit // controlls concurrency of task processings
}
Expand Down Expand Up @@ -69,13 +69,14 @@ func (p *permit) AcquireChan(ctx context.Context, wg *sync.WaitGroup) <-chan str
wg.Add(1)
go func() {
defer wg.Done()
defer close(ch) // close channel when permit is acquired or expired
if err := p.sem.Acquire(ctx, 1); err != nil {
close(ch)
return
}
select { // try to send to channel, but don't block if listener is gone
case ch <- struct{}{}:
default:
p.sem.Release(1)
}
}()
return ch
Expand Down
139 changes: 139 additions & 0 deletions internal/worker/concurrency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright (c) 2017-2021 Uber Technologies Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package worker

import (
"context"
"sync"
"testing"
"time"

"math/rand"

"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
)

func TestPermit_Simulation(t *testing.T) {
tests := []struct{
name string
capacity []int // update every 50ms
goroutines int // each would block on acquiring 2-6 tokens for 100ms
goroutinesAcquireChan int // each would block using AcquireChan for 100ms
maxTestDuration time.Duration
expectFailures int
expectFailuresAtLeast int
} {
{
name: "enough permit, no blocking",
maxTestDuration: 200*time.Millisecond,
capacity: []int{1000},
goroutines: 100,
goroutinesAcquireChan: 100,
expectFailures: 0,
},
{
name: "not enough permit, blocking but all acquire",
maxTestDuration: 1*time.Second,
capacity: []int{200},
goroutines: 500,
goroutinesAcquireChan: 500,
expectFailures: 0,
},
{
name: "not enough permit for some to acquire, fail some",
maxTestDuration: 100*time.Millisecond,
capacity: []int{100},
goroutines: 500,
goroutinesAcquireChan: 500,
expectFailuresAtLeast: 1,
},
{
name: "not enough permit at beginning but due to capacity change, blocking but all acquire",
maxTestDuration: 100*time.Second,
capacity: []int{100, 200, 300},
goroutines: 500,
goroutinesAcquireChan: 500,
expectFailures: 0,
},
{
name: "not enough permit for any acquire, fail all",
maxTestDuration: 1*time.Second,
capacity: []int{0},
goroutines: 1000,
expectFailures: 1000,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wg := &sync.WaitGroup{}
permit := NewPermit(tt.capacity[0])
wg.Add(1)
go func() { // update quota every 50ms
defer wg.Done()
for i := 1; i < len(tt.capacity); i++ {
time.Sleep(50*time.Millisecond)
permit.SetQuota(tt.capacity[i])
}
}()
failures := atomic.NewInt32(0)
ctx, cancel := context.WithTimeout(context.Background(), tt.maxTestDuration)
defer cancel()
for i := 0; i < tt.goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
num := rand.Intn(2)+2
// num := 1
if err := permit.Acquire(ctx, num); err != nil {
failures.Add(1)
return
}
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
permit.Release(num)
}()
}
for i := 0; i < tt.goroutinesAcquireChan; i++ {
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-permit.AcquireChan(ctx, wg):
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
permit.Release(1)
case <-ctx.Done():
failures.Add(1)
}
}()
}

wg.Wait()
assert.Equal(t, 0, permit.Count())
if tt.expectFailuresAtLeast >0 {
assert.LessOrEqual(t, tt.expectFailuresAtLeast, int(failures.Load()))
} else {
assert.Equal(t, tt.expectFailures, int(failures.Load()))
}
assert.Equal(t, tt.capacity[len(tt.capacity)-1], permit.Quota())
})
}
}

0 comments on commit b99b8ce

Please sign in to comment.