diff --git a/ants.go b/ants.go index d67ab5c..eae6a14 100644 --- a/ants.go +++ b/ants.go @@ -84,6 +84,9 @@ var ( // ErrInvalidLoadBalancingStrategy will be returned when trying to create a MultiPool with an invalid load-balancing strategy. ErrInvalidLoadBalancingStrategy = errors.New("invalid load-balancing strategy") + // ErrInvalidMultiPoolSize will be returned when trying to create a MultiPool with an invalid size. + ErrInvalidMultiPoolSize = errors.New("invalid size for multiple pool") + // workerChanCap determines whether the channel of a worker should be a buffered channel // to get the best performance. Inspired by fasthttp at // https://github.com/valyala/fasthttp/blob/master/workerpool.go#L139 @@ -387,6 +390,7 @@ func (p *poolCommon) Release() { p.lock.Lock() p.workers.reset() p.lock.Unlock() + // There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent // those callers blocking infinitely. p.cond.Broadcast() diff --git a/ants_benchmark_test.go b/ants_benchmark_test.go index 1dcc8dd..33b4c1e 100644 --- a/ants_benchmark_test.go +++ b/ants_benchmark_test.go @@ -48,6 +48,10 @@ func demoPoolFunc(args any) { time.Sleep(time.Duration(n) * time.Millisecond) } +func demoPoolFuncInt(n int) { + time.Sleep(time.Duration(n) * time.Millisecond) +} + var stopLongRunningFunc int32 func longRunningFunc() { @@ -56,16 +60,12 @@ func longRunningFunc() { } } -var stopLongRunningPoolFunc int32 - func longRunningPoolFunc(arg any) { - if ch, ok := arg.(chan struct{}); ok { - <-ch - return - } - for atomic.LoadInt32(&stopLongRunningPoolFunc) == 0 { - runtime.Gosched() - } + <-arg.(chan struct{}) +} + +func longRunningPoolFuncCh(ch chan struct{}) { + <-ch } func BenchmarkGoroutines(b *testing.B) { diff --git a/ants_test.go b/ants_test.go index 7909ea2..316497d 100644 --- a/ants_test.go +++ b/ants_test.go @@ -31,7 +31,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -111,6 +111,27 @@ func TestAntsPoolWithFuncWaitToGetWorker(t *testing.T) { t.Logf("memory usage:%d MB", curMem) } +// TestAntsPoolWithFuncGenericWaitToGetWorker is used to test waiting to get worker. +func TestAntsPoolWithFuncGenericWaitToGetWorker(t *testing.T) { + var wg sync.WaitGroup + p, _ := NewPoolWithFuncGeneric(AntsSize, func(i int) { + demoPoolFuncInt(i) + wg.Done() + }) + defer p.Release() + + for i := 0; i < n; i++ { + wg.Add(1) + _ = p.Invoke(Param) + } + wg.Wait() + t.Logf("pool with func, running workers number:%d", p.Running()) + mem := runtime.MemStats{} + runtime.ReadMemStats(&mem) + curMem = mem.TotalAlloc/MiB - curMem + t.Logf("memory usage:%d MB", curMem) +} + func TestAntsPoolWithFuncWaitToGetWorkerPreMalloc(t *testing.T) { var wg sync.WaitGroup p, _ := NewPoolWithFunc(AntsSize, func(i any) { @@ -131,6 +152,26 @@ func TestAntsPoolWithFuncWaitToGetWorkerPreMalloc(t *testing.T) { t.Logf("memory usage:%d MB", curMem) } +func TestAntsPoolWithFuncGenericWaitToGetWorkerPreMalloc(t *testing.T) { + var wg sync.WaitGroup + p, _ := NewPoolWithFuncGeneric(AntsSize, func(i int) { + demoPoolFuncInt(i) + wg.Done() + }, WithPreAlloc(true)) + defer p.Release() + + for i := 0; i < n; i++ { + wg.Add(1) + _ = p.Invoke(Param) + } + wg.Wait() + t.Logf("pool with func, running workers number:%d", p.Running()) + mem := runtime.MemStats{} + runtime.ReadMemStats(&mem) + curMem = mem.TotalAlloc/MiB - curMem + t.Logf("memory usage:%d MB", curMem) +} + // TestAntsPoolGetWorkerFromCache is used to test getting worker from sync.Pool. 
func TestAntsPoolGetWorkerFromCache(t *testing.T) { p, _ := NewPool(TestSize) @@ -166,6 +207,24 @@ func TestAntsPoolWithFuncGetWorkerFromCache(t *testing.T) { t.Logf("memory usage:%d MB", curMem) } +// TestAntsPoolWithFuncGenericGetWorkerFromCache is used to test getting worker from sync.Pool. +func TestAntsPoolWithFuncGenericGetWorkerFromCache(t *testing.T) { + dur := 10 + p, _ := NewPoolWithFuncGeneric(TestSize, demoPoolFuncInt) + defer p.Release() + + for i := 0; i < AntsSize; i++ { + _ = p.Invoke(dur) + } + time.Sleep(2 * DefaultCleanIntervalTime) + _ = p.Invoke(dur) + t.Logf("pool with func, running workers number:%d", p.Running()) + mem := runtime.MemStats{} + runtime.ReadMemStats(&mem) + curMem = mem.TotalAlloc/MiB - curMem + t.Logf("memory usage:%d MB", curMem) +} + func TestAntsPoolWithFuncGetWorkerFromCachePreMalloc(t *testing.T) { dur := 10 p, _ := NewPoolWithFunc(TestSize, demoPoolFunc, WithPreAlloc(true)) @@ -183,6 +242,23 @@ func TestAntsPoolWithFuncGetWorkerFromCachePreMalloc(t *testing.T) { t.Logf("memory usage:%d MB", curMem) } +func TestAntsPoolWithFuncGenericGetWorkerFromCachePreMalloc(t *testing.T) { + dur := 10 + p, _ := NewPoolWithFuncGeneric(TestSize, demoPoolFuncInt, WithPreAlloc(true)) + defer p.Release() + + for i := 0; i < AntsSize; i++ { + _ = p.Invoke(dur) + } + time.Sleep(2 * DefaultCleanIntervalTime) + _ = p.Invoke(dur) + t.Logf("pool with func, running workers number:%d", p.Running()) + mem := runtime.MemStats{} + runtime.ReadMemStats(&mem) + curMem = mem.TotalAlloc/MiB - curMem + t.Logf("memory usage:%d MB", curMem) +} + // Contrast between goroutines without a pool and goroutines with ants pool. func TestNoPool(t *testing.T) { @@ -232,7 +308,7 @@ func TestPanicHandler(t *testing.T) { atomic.AddInt64(&panicCounter, 1) t.Logf("catch panic with PanicHandler: %v", p) })) - assert.NoErrorf(t, err, "create new pool failed: %v", err) + require.NoErrorf(t, err, "create new pool failed: %v", err) defer p0.Release() wg.Add(1) _ = p0.Submit(func() { @@ -240,20 +316,34 @@ func TestPanicHandler(t *testing.T) { }) wg.Wait() c := atomic.LoadInt64(&panicCounter) - assert.EqualValuesf(t, 1, c, "panic handler didn't work, panicCounter: %d", c) - assert.EqualValues(t, 0, p0.Running(), "pool should be empty after panic") + require.EqualValuesf(t, 1, c, "panic handler didn't work, panicCounter: %d", c) + require.EqualValues(t, 0, p0.Running(), "pool should be empty after panic") + p1, err := NewPoolWithFunc(10, func(p any) { panic(p) }, WithPanicHandler(func(_ any) { defer wg.Done() atomic.AddInt64(&panicCounter, 1) })) - assert.NoErrorf(t, err, "create new pool with func failed: %v", err) + require.NoErrorf(t, err, "create new pool with func failed: %v", err) defer p1.Release() wg.Add(1) _ = p1.Invoke("Oops!") wg.Wait() c = atomic.LoadInt64(&panicCounter) - assert.EqualValuesf(t, 2, c, "panic handler didn't work, panicCounter: %d", c) - assert.EqualValues(t, 0, p1.Running(), "pool should be empty after panic") + require.EqualValuesf(t, 2, c, "panic handler didn't work, panicCounter: %d", c) + require.EqualValues(t, 0, p1.Running(), "pool should be empty after panic") + + p2, err := NewPoolWithFuncGeneric(10, func(s string) { panic(s) }, WithPanicHandler(func(_ any) { + defer wg.Done() + atomic.AddInt64(&panicCounter, 1) + })) + require.NoErrorf(t, err, "create new pool with func failed: %v", err) + defer p2.Release() + wg.Add(1) + _ = p2.Invoke("Oops!") + wg.Wait() + c = atomic.LoadInt64(&panicCounter) + require.EqualValuesf(t, 3, c, "panic handler didn't work, panicCounter: 
%d", c) + require.EqualValues(t, 0, p2.Running(), "pool should be empty after panic") } func TestPanicHandlerPreMalloc(t *testing.T) { @@ -264,7 +354,7 @@ func TestPanicHandlerPreMalloc(t *testing.T) { atomic.AddInt64(&panicCounter, 1) t.Logf("catch panic with PanicHandler: %v", p) })) - assert.NoErrorf(t, err, "create new pool failed: %v", err) + require.NoErrorf(t, err, "create new pool failed: %v", err) defer p0.Release() wg.Add(1) _ = p0.Submit(func() { @@ -272,41 +362,58 @@ func TestPanicHandlerPreMalloc(t *testing.T) { }) wg.Wait() c := atomic.LoadInt64(&panicCounter) - assert.EqualValuesf(t, 1, c, "panic handler didn't work, panicCounter: %d", c) - assert.EqualValues(t, 0, p0.Running(), "pool should be empty after panic") - p1, err := NewPoolWithFunc(10, func(p any) { panic(p) }, WithPanicHandler(func(_ any) { + require.EqualValuesf(t, 1, c, "panic handler didn't work, panicCounter: %d", c) + require.EqualValues(t, 0, p0.Running(), "pool should be empty after panic") + + p1, err := NewPoolWithFunc(10, func(p any) { panic(p) }, WithPreAlloc(true), WithPanicHandler(func(_ any) { defer wg.Done() atomic.AddInt64(&panicCounter, 1) })) - assert.NoErrorf(t, err, "create new pool with func failed: %v", err) + require.NoErrorf(t, err, "create new pool with func failed: %v", err) defer p1.Release() wg.Add(1) _ = p1.Invoke("Oops!") wg.Wait() c = atomic.LoadInt64(&panicCounter) - assert.EqualValuesf(t, 2, c, "panic handler didn't work, panicCounter: %d", c) - assert.EqualValues(t, 0, p1.Running(), "pool should be empty after panic") + require.EqualValuesf(t, 2, c, "panic handler didn't work, panicCounter: %d", c) + require.EqualValues(t, 0, p1.Running(), "pool should be empty after panic") + + p2, err := NewPoolWithFuncGeneric(10, func(p string) { panic(p) }, WithPreAlloc(true), WithPanicHandler(func(_ any) { + defer wg.Done() + atomic.AddInt64(&panicCounter, 1) + })) + require.NoErrorf(t, err, "create new pool with func failed: %v", err) + defer p2.Release() + wg.Add(1) + _ = p2.Invoke("Oops!") + wg.Wait() + c = atomic.LoadInt64(&panicCounter) + require.EqualValuesf(t, 3, c, "panic handler didn't work, panicCounter: %d", c) + require.EqualValues(t, 0, p1.Running(), "pool should be empty after panic") } func TestPoolPanicWithoutHandler(t *testing.T) { p0, err := NewPool(10) - assert.NoErrorf(t, err, "create new pool failed: %v", err) + require.NoErrorf(t, err, "create new pool failed: %v", err) defer p0.Release() _ = p0.Submit(func() { panic("Oops!") }) - p1, err := NewPoolWithFunc(10, func(p any) { - panic(p) - }) - assert.NoErrorf(t, err, "create new pool with func failed: %v", err) + p1, err := NewPoolWithFunc(10, func(p any) { panic(p) }) + require.NoErrorf(t, err, "create new pool with func failed: %v", err) defer p1.Release() _ = p1.Invoke("Oops!") + + p2, err := NewPoolWithFuncGeneric(10, func(p string) { panic(p) }) + require.NoErrorf(t, err, "create new pool with func failed: %v", err) + defer p2.Release() + _ = p2.Invoke("Oops!") } func TestPoolPanicWithoutHandlerPreMalloc(t *testing.T) { p0, err := NewPool(10, WithPreAlloc(true)) - assert.NoErrorf(t, err, "create new pool failed: %v", err) + require.NoErrorf(t, err, "create new pool failed: %v", err) defer p0.Release() _ = p0.Submit(func() { panic("Oops!") @@ -315,11 +422,16 @@ func TestPoolPanicWithoutHandlerPreMalloc(t *testing.T) { p1, err := NewPoolWithFunc(10, func(p any) { panic(p) }) - - assert.NoErrorf(t, err, "create new pool with func failed: %v", err) - + require.NoErrorf(t, err, "create new pool with func failed: %v", 
err) defer p1.Release() _ = p1.Invoke("Oops!") + + p2, err := NewPoolWithFuncGeneric(10, func(p any) { + panic(p) + }) + require.NoErrorf(t, err, "create new pool with func failed: %v", err) + defer p2.Release() + _ = p2.Invoke("Oops!") } func TestPurgePool(t *testing.T) { @@ -327,7 +439,7 @@ func TestPurgePool(t *testing.T) { ch := make(chan struct{}) p, err := NewPool(size) - assert.NoErrorf(t, err, "create TimingPool failed: %v", err) + require.NoErrorf(t, err, "create TimingPool failed: %v", err) defer p.Release() for i := 0; i < size; i++ { @@ -338,11 +450,11 @@ func TestPurgePool(t *testing.T) { time.Sleep(time.Duration(d) * time.Millisecond) }) } - assert.Equalf(t, size, p.Running(), "pool should be full, expected: %d, but got: %d", size, p.Running()) + require.Equalf(t, size, p.Running(), "pool should be full, expected: %d, but got: %d", size, p.Running()) close(ch) time.Sleep(5 * DefaultCleanIntervalTime) - assert.Equalf(t, 0, p.Running(), "pool should be empty after purge, but got %d", p.Running()) + require.Equalf(t, 0, p.Running(), "pool should be empty after purge, but got %d", p.Running()) ch = make(chan struct{}) f := func(i any) { @@ -352,41 +464,69 @@ func TestPurgePool(t *testing.T) { } p1, err := NewPoolWithFunc(size, f) - assert.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) + require.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) defer p1.Release() for i := 0; i < size; i++ { _ = p1.Invoke(i) } - assert.Equalf(t, size, p1.Running(), "pool should be full, expected: %d, but got: %d", size, p1.Running()) + require.Equalf(t, size, p1.Running(), "pool should be full, expected: %d, but got: %d", size, p1.Running()) + + close(ch) + time.Sleep(5 * DefaultCleanIntervalTime) + require.Equalf(t, 0, p1.Running(), "pool should be empty after purge, but got %d", p1.Running()) + + ch = make(chan struct{}) + f1 := func(i int) { + <-ch + d := i % 100 + time.Sleep(time.Duration(d) * time.Millisecond) + } + + p2, err := NewPoolWithFuncGeneric(size, f1) + require.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) + defer p2.Release() + + for i := 0; i < size; i++ { + _ = p2.Invoke(i) + } + require.Equalf(t, size, p2.Running(), "pool should be full, expected: %d, but got: %d", size, p2.Running()) close(ch) time.Sleep(5 * DefaultCleanIntervalTime) - assert.Equalf(t, 0, p1.Running(), "pool should be empty after purge, but got %d", p1.Running()) + require.Equalf(t, 0, p2.Running(), "pool should be empty after purge, but got %d", p2.Running()) } func TestPurgePreMallocPool(t *testing.T) { p, err := NewPool(10, WithPreAlloc(true)) - assert.NoErrorf(t, err, "create TimingPool failed: %v", err) + require.NoErrorf(t, err, "create TimingPool failed: %v", err) defer p.Release() _ = p.Submit(demoFunc) time.Sleep(3 * DefaultCleanIntervalTime) - assert.EqualValues(t, 0, p.Running(), "all p should be purged") + require.EqualValues(t, 0, p.Running(), "all p should be purged") + p1, err := NewPoolWithFunc(10, demoPoolFunc) - assert.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) + require.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) defer p1.Release() _ = p1.Invoke(1) time.Sleep(3 * DefaultCleanIntervalTime) - assert.EqualValues(t, 0, p.Running(), "all p should be purged") + require.EqualValues(t, 0, p1.Running(), "all p should be purged") + + p2, err := NewPoolWithFuncGeneric(10, demoPoolFuncInt) + require.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) + defer p2.Release() + _ = p2.Invoke(1) + time.Sleep(3 * 
DefaultCleanIntervalTime) + require.EqualValues(t, 0, p2.Running(), "all p should be purged") } func TestNonblockingSubmit(t *testing.T) { poolSize := 10 p, err := NewPool(poolSize, WithNonblocking(true)) - assert.NoErrorf(t, err, "create TimingPool failed: %v", err) + require.NoErrorf(t, err, "create TimingPool failed: %v", err) defer p.Release() for i := 0; i < poolSize-1; i++ { - assert.NoError(t, p.Submit(longRunningFunc), "nonblocking submit when pool is not full shouldn't return error") + require.NoError(t, p.Submit(longRunningFunc), "nonblocking submit when pool is not full shouldn't return error") } ch := make(chan struct{}) ch1 := make(chan struct{}) @@ -395,29 +535,29 @@ func TestNonblockingSubmit(t *testing.T) { close(ch1) } // p is full now. - assert.NoError(t, p.Submit(f), "nonblocking submit when pool is not full shouldn't return error") - assert.EqualError(t, p.Submit(demoFunc), ErrPoolOverload.Error(), + require.NoError(t, p.Submit(f), "nonblocking submit when pool is not full shouldn't return error") + require.ErrorIsf(t, p.Submit(demoFunc), ErrPoolOverload, "nonblocking submit when pool is full should get an ErrPoolOverload") // interrupt f to get an available worker close(ch) <-ch1 - assert.NoError(t, p.Submit(demoFunc), "nonblocking submit when pool is not full shouldn't return error") + require.NoError(t, p.Submit(demoFunc), "nonblocking submit when pool is not full shouldn't return error") } func TestMaxBlockingSubmit(t *testing.T) { poolSize := 10 p, err := NewPool(poolSize, WithMaxBlockingTasks(1)) - assert.NoErrorf(t, err, "create TimingPool failed: %v", err) + require.NoErrorf(t, err, "create TimingPool failed: %v", err) defer p.Release() for i := 0; i < poolSize-1; i++ { - assert.NoError(t, p.Submit(longRunningFunc), "submit when pool is not full shouldn't return error") + require.NoError(t, p.Submit(longRunningFunc), "submit when pool is not full shouldn't return error") } ch := make(chan struct{}) f := func() { <-ch } // p is full now. - assert.NoError(t, p.Submit(f), "submit when pool is not full shouldn't return error") + require.NoError(t, p.Submit(f), "submit when pool is not full shouldn't return error") var wg sync.WaitGroup wg.Add(1) errCh := make(chan error, 1) @@ -430,7 +570,7 @@ func TestMaxBlockingSubmit(t *testing.T) { }() time.Sleep(1 * time.Second) // already reached max blocking limit - assert.EqualError(t, p.Submit(demoFunc), ErrPoolOverload.Error(), + require.ErrorIsf(t, p.Submit(demoFunc), ErrPoolOverload, "blocking submit when pool reach max blocking submit should return ErrPoolOverload") // interrupt f to make blocking submit successful. close(ch) @@ -444,52 +584,115 @@ func TestMaxBlockingSubmit(t *testing.T) { func TestNonblockingSubmitWithFunc(t *testing.T) { poolSize := 10 + ch := make(chan struct{}) var wg sync.WaitGroup p, err := NewPoolWithFunc(poolSize, func(i any) { longRunningPoolFunc(i) wg.Done() }, WithNonblocking(true)) - assert.NoError(t, err, "create TimingPool failed: %v", err) + require.NoError(t, err, "create TimingPool failed: %v", err) + defer p.Release() + wg.Add(poolSize) + for i := 0; i < poolSize-1; i++ { + require.NoError(t, p.Invoke(ch), "nonblocking submit when pool is not full shouldn't return error") + } + // p is full now. 
+ require.NoError(t, p.Invoke(ch), "nonblocking submit when pool is not full shouldn't return error") + require.ErrorIsf(t, p.Invoke(nil), ErrPoolOverload, + "nonblocking submit when pool is full should get an ErrPoolOverload") + // interrupt f to get an available worker + close(ch) + wg.Wait() + wg.Add(1) + require.NoError(t, p.Invoke(ch), "nonblocking submit when pool is not full shouldn't return error") + wg.Wait() +} + +func TestNonblockingSubmitWithFuncGeneric(t *testing.T) { + poolSize := 10 + var wg sync.WaitGroup + p, err := NewPoolWithFuncGeneric(poolSize, func(ch chan struct{}) { + longRunningPoolFuncCh(ch) + wg.Done() + }, WithNonblocking(true)) + require.NoError(t, err, "create TimingPool failed: %v", err) defer p.Release() ch := make(chan struct{}) wg.Add(poolSize) for i := 0; i < poolSize-1; i++ { - assert.NoError(t, p.Invoke(ch), "nonblocking submit when pool is not full shouldn't return error") + require.NoError(t, p.Invoke(ch), "nonblocking submit when pool is not full shouldn't return error") } // p is full now. - assert.NoError(t, p.Invoke(ch), "nonblocking submit when pool is not full shouldn't return error") - assert.EqualError(t, p.Invoke(nil), ErrPoolOverload.Error(), + require.NoError(t, p.Invoke(ch), "nonblocking submit when pool is not full shouldn't return error") + require.ErrorIsf(t, p.Invoke(nil), ErrPoolOverload, "nonblocking submit when pool is full should get an ErrPoolOverload") // interrupt f to get an available worker close(ch) wg.Wait() - assert.NoError(t, p.Invoke(nil), "nonblocking submit when pool is not full shouldn't return error") + wg.Add(1) + require.NoError(t, p.Invoke(ch), "nonblocking submit when pool is not full shouldn't return error") + wg.Wait() } func TestMaxBlockingSubmitWithFunc(t *testing.T) { + ch := make(chan struct{}) poolSize := 10 p, err := NewPoolWithFunc(poolSize, longRunningPoolFunc, WithMaxBlockingTasks(1)) - assert.NoError(t, err, "create TimingPool failed: %v", err) + require.NoError(t, err, "create TimingPool failed: %v", err) defer p.Release() for i := 0; i < poolSize-1; i++ { - assert.NoError(t, p.Invoke(Param), "submit when pool is not full shouldn't return error") + require.NoError(t, p.Invoke(ch), "submit when pool is not full shouldn't return error") } + // p is full now. + require.NoError(t, p.Invoke(ch), "submit when pool is not full shouldn't return error") + var wg sync.WaitGroup + wg.Add(1) + errCh := make(chan error, 1) + go func() { + // should be blocked. blocking num == 1 + if err := p.Invoke(ch); err != nil { + errCh <- err + } + wg.Done() + }() + time.Sleep(1 * time.Second) + // already reached max blocking limit + require.ErrorIsf(t, p.Invoke(ch), ErrPoolOverload, + "blocking submit when pool reach max blocking submit should return ErrPoolOverload: %v", err) + // interrupt one func to make blocking submit successful. + close(ch) + wg.Wait() + select { + case <-errCh: + t.Fatalf("blocking submit when pool is full should not return error") + default: + } +} + +func TestMaxBlockingSubmitWithFuncGeneric(t *testing.T) { + poolSize := 10 + p, err := NewPoolWithFuncGeneric(poolSize, longRunningPoolFuncCh, WithMaxBlockingTasks(1)) + require.NoError(t, err, "create TimingPool failed: %v", err) + defer p.Release() ch := make(chan struct{}) + for i := 0; i < poolSize-1; i++ { + require.NoError(t, p.Invoke(ch), "submit when pool is not full shouldn't return error") + } // p is full now. 
- assert.NoError(t, p.Invoke(ch), "submit when pool is not full shouldn't return error") + require.NoError(t, p.Invoke(ch), "submit when pool is not full shouldn't return error") var wg sync.WaitGroup wg.Add(1) errCh := make(chan error, 1) go func() { // should be blocked. blocking num == 1 - if err := p.Invoke(Param); err != nil { + if err := p.Invoke(ch); err != nil { errCh <- err } wg.Done() }() time.Sleep(1 * time.Second) // already reached max blocking limit - assert.EqualErrorf(t, p.Invoke(Param), ErrPoolOverload.Error(), + require.ErrorIsf(t, p.Invoke(ch), ErrPoolOverload, "blocking submit when pool reach max blocking submit should return ErrPoolOverload: %v", err) // interrupt one func to make blocking submit successful. close(ch) @@ -511,18 +714,18 @@ func TestRebootDefaultPool(t *testing.T) { wg.Done() }) wg.Wait() - assert.NoError(t, ReleaseTimeout(time.Second)) - assert.EqualError(t, Submit(nil), ErrPoolClosed.Error(), "pool should be closed") + require.NoError(t, ReleaseTimeout(time.Second)) + require.ErrorIsf(t, Submit(nil), ErrPoolClosed, "pool should be closed") Reboot() wg.Add(1) - assert.NoError(t, Submit(func() { wg.Done() }), "pool should be rebooted") + require.NoError(t, Submit(func() { wg.Done() }), "pool should be rebooted") wg.Wait() } func TestRebootNewPool(t *testing.T) { var wg sync.WaitGroup p, err := NewPool(10) - assert.NoErrorf(t, err, "create Pool failed: %v", err) + require.NoErrorf(t, err, "create Pool failed: %v", err) defer p.Release() wg.Add(1) _ = p.Submit(func() { @@ -530,27 +733,43 @@ func TestRebootNewPool(t *testing.T) { wg.Done() }) wg.Wait() - assert.NoError(t, p.ReleaseTimeout(time.Second)) - assert.EqualError(t, p.Submit(nil), ErrPoolClosed.Error(), "pool should be closed") + require.NoError(t, p.ReleaseTimeout(time.Second)) + require.ErrorIsf(t, p.Submit(nil), ErrPoolClosed, "pool should be closed") p.Reboot() wg.Add(1) - assert.NoError(t, p.Submit(func() { wg.Done() }), "pool should be rebooted") + require.NoError(t, p.Submit(func() { wg.Done() }), "pool should be rebooted") wg.Wait() p1, err := NewPoolWithFunc(10, func(i any) { demoPoolFunc(i) wg.Done() }) - assert.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) + require.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) defer p1.Release() wg.Add(1) _ = p1.Invoke(1) wg.Wait() - assert.NoError(t, p1.ReleaseTimeout(time.Second)) - assert.EqualError(t, p1.Invoke(nil), ErrPoolClosed.Error(), "pool should be closed") + require.NoError(t, p1.ReleaseTimeout(time.Second)) + require.ErrorIsf(t, p1.Invoke(nil), ErrPoolClosed, "pool should be closed") p1.Reboot() wg.Add(1) - assert.NoError(t, p1.Invoke(1), "pool should be rebooted") + require.NoError(t, p1.Invoke(1), "pool should be rebooted") + wg.Wait() + + p2, err := NewPoolWithFuncGeneric(10, func(i int) { + demoPoolFuncInt(i) + wg.Done() + }) + require.NoErrorf(t, err, "create TimingPoolWithFunc failed: %v", err) + defer p2.Release() + wg.Add(1) + _ = p2.Invoke(1) + wg.Wait() + require.NoError(t, p2.ReleaseTimeout(time.Second)) + require.ErrorIsf(t, p2.Invoke(1), ErrPoolClosed, "pool should be closed") + p2.Reboot() + wg.Add(1) + require.NoError(t, p2.Invoke(1), "pool should be rebooted") wg.Wait() } @@ -575,7 +794,7 @@ func TestInfinitePool(t *testing.T) { } var err error _, err = NewPool(-1, WithPreAlloc(true)) - assert.EqualErrorf(t, err, ErrInvalidPreAllocSize.Error(), "") + require.EqualErrorf(t, err, ErrInvalidPreAllocSize.Error(), "") } func testPoolWithDisablePurge(t *testing.T, p *Pool, numWorker int, 
waitForPurge time.Duration) { @@ -593,9 +812,9 @@ func testPoolWithDisablePurge(t *testing.T, p *Pool, numWorker int, waitForPurge wg1.Wait() runningCnt := p.Running() - assert.EqualValuesf(t, numWorker, runningCnt, "expect %d workers running, but got %d", numWorker, runningCnt) + require.EqualValuesf(t, numWorker, runningCnt, "expect %d workers running, but got %d", numWorker, runningCnt) freeCnt := p.Free() - assert.EqualValuesf(t, 0, freeCnt, "expect %d free workers, but got %d", 0, freeCnt) + require.EqualValuesf(t, 0, freeCnt, "expect %d free workers, but got %d", 0, freeCnt) // Finish all tasks and sleep for a while to wait for purging, since we've disabled purge mechanism, // we should see that all workers are still running after the sleep. @@ -604,17 +823,17 @@ func testPoolWithDisablePurge(t *testing.T, p *Pool, numWorker int, waitForPurge time.Sleep(waitForPurge + waitForPurge/2) runningCnt = p.Running() - assert.EqualValuesf(t, numWorker, runningCnt, "expect %d workers running, but got %d", numWorker, runningCnt) + require.EqualValuesf(t, numWorker, runningCnt, "expect %d workers running, but got %d", numWorker, runningCnt) freeCnt = p.Free() - assert.EqualValuesf(t, 0, freeCnt, "expect %d free workers, but got %d", 0, freeCnt) + require.EqualValuesf(t, 0, freeCnt, "expect %d free workers, but got %d", 0, freeCnt) err := p.ReleaseTimeout(waitForPurge + waitForPurge/2) - assert.NoErrorf(t, err, "release pool failed: %v", err) + require.NoErrorf(t, err, "release pool failed: %v", err) runningCnt = p.Running() - assert.EqualValuesf(t, 0, runningCnt, "expect %d workers running, but got %d", 0, runningCnt) + require.EqualValuesf(t, 0, runningCnt, "expect %d workers running, but got %d", 0, runningCnt) freeCnt = p.Free() - assert.EqualValuesf(t, numWorker, freeCnt, "expect %d free workers, but got %d", numWorker, freeCnt) + require.EqualValuesf(t, numWorker, freeCnt, "expect %d free workers, but got %d", numWorker, freeCnt) } func TestWithDisablePurgePool(t *testing.T) { @@ -637,9 +856,9 @@ func testPoolFuncWithDisablePurge(t *testing.T, p *PoolWithFunc, numWorker int, wg1.Wait() runningCnt := p.Running() - assert.EqualValuesf(t, numWorker, runningCnt, "expect %d workers running, but got %d", numWorker, runningCnt) + require.EqualValuesf(t, numWorker, runningCnt, "expect %d workers running, but got %d", numWorker, runningCnt) freeCnt := p.Free() - assert.EqualValuesf(t, 0, freeCnt, "expect %d free workers, but got %d", 0, freeCnt) + require.EqualValuesf(t, 0, freeCnt, "expect %d free workers, but got %d", 0, freeCnt) // Finish all tasks and sleep for a while to wait for purging, since we've disabled purge mechanism, // we should see that all workers are still running after the sleep. 
@@ -648,17 +867,17 @@ func testPoolFuncWithDisablePurge(t *testing.T, p *PoolWithFunc, numWorker int, time.Sleep(waitForPurge + waitForPurge/2) runningCnt = p.Running() - assert.EqualValuesf(t, numWorker, runningCnt, "expect %d workers running, but got %d", numWorker, runningCnt) + require.EqualValuesf(t, numWorker, runningCnt, "expect %d workers running, but got %d", numWorker, runningCnt) freeCnt = p.Free() - assert.EqualValuesf(t, 0, freeCnt, "expect %d free workers, but got %d", 0, freeCnt) + require.EqualValuesf(t, 0, freeCnt, "expect %d free workers, but got %d", 0, freeCnt) err := p.ReleaseTimeout(waitForPurge + waitForPurge/2) - assert.NoErrorf(t, err, "release pool failed: %v", err) + require.NoErrorf(t, err, "release pool failed: %v", err) runningCnt = p.Running() - assert.EqualValuesf(t, 0, runningCnt, "expect %d workers running, but got %d", 0, runningCnt) + require.EqualValuesf(t, 0, runningCnt, "expect %d workers running, but got %d", 0, runningCnt) freeCnt = p.Free() - assert.EqualValuesf(t, numWorker, freeCnt, "expect %d free workers, but got %d", numWorker, freeCnt) + require.EqualValuesf(t, numWorker, freeCnt, "expect %d free workers, but got %d", numWorker, freeCnt) } func TestWithDisablePurgePoolFunc(t *testing.T) { @@ -692,10 +911,12 @@ func TestWithDisablePurgeAndWithExpirationPoolFunc(t *testing.T) { func TestInfinitePoolWithFunc(t *testing.T) { c := make(chan struct{}) - p, _ := NewPoolWithFunc(-1, func(i any) { + p, err := NewPoolWithFunc(-1, func(i any) { demoPoolFunc(i) <-c }) + require.NoErrorf(t, err, "create pool with func failed: %v", err) + defer p.Release() _ = p.Invoke(10) _ = p.Invoke(10) c <- struct{}{} @@ -710,16 +931,40 @@ func TestInfinitePoolWithFunc(t *testing.T) { if capacity := p.Cap(); capacity != -1 { t.Fatalf("expect capacity: -1 but got %d", capacity) } - var err error _, err = NewPoolWithFunc(-1, demoPoolFunc, WithPreAlloc(true)) - if err != ErrInvalidPreAllocSize { - t.Errorf("expect ErrInvalidPreAllocSize but got %v", err) + require.ErrorIsf(t, err, ErrInvalidPreAllocSize, "expect ErrInvalidPreAllocSize but got %v", err) +} + +func TestInfinitePoolWithFuncGeneric(t *testing.T) { + c := make(chan struct{}) + p, err := NewPoolWithFuncGeneric(-1, func(i int) { + demoPoolFuncInt(i) + <-c + }) + require.NoErrorf(t, err, "create pool with func failed: %v", err) + defer p.Release() + _ = p.Invoke(10) + _ = p.Invoke(10) + c <- struct{}{} + c <- struct{}{} + if n := p.Running(); n != 2 { + t.Errorf("expect 2 workers running, but got %d", n) } + if n := p.Free(); n != -1 { + t.Errorf("expect -1 of free workers by unlimited pool, but got %d", n) + } + p.Tune(10) + if capacity := p.Cap(); capacity != -1 { + t.Fatalf("expect capacity: -1 but got %d", capacity) + } + _, err = NewPoolWithFuncGeneric(-1, demoPoolFuncInt, WithPreAlloc(true)) + require.ErrorIsf(t, err, ErrInvalidPreAllocSize, "expect ErrInvalidPreAllocSize but got %v", err) } func TestReleaseWhenRunningPool(t *testing.T) { var wg sync.WaitGroup - p, _ := NewPool(1) + p, err := NewPool(1) + require.NoErrorf(t, err, "create pool failed: %v", err) wg.Add(2) go func() { t.Log("start aaa") @@ -759,10 +1004,12 @@ func TestReleaseWhenRunningPool(t *testing.T) { func TestReleaseWhenRunningPoolWithFunc(t *testing.T) { var wg sync.WaitGroup - p, _ := NewPoolWithFunc(1, func(i any) { + p, err := NewPoolWithFunc(1, func(i any) { t.Log("do task", i) time.Sleep(1 * time.Second) }) + require.NoErrorf(t, err, "create pool with func failed: %v", err) + wg.Add(2) go func() { t.Log("start aaa") @@ -792,15 
+1039,61 @@ func TestReleaseWhenRunningPoolWithFunc(t *testing.T) { wg.Wait() } +func TestReleaseWhenRunningPoolWithFuncGeneric(t *testing.T) { + var wg sync.WaitGroup + p, err := NewPoolWithFuncGeneric(1, func(i int) { + t.Log("do task", i) + time.Sleep(1 * time.Second) + }) + require.NoErrorf(t, err, "create pool with func failed: %v", err) + wg.Add(2) + + go func() { + t.Log("start aaa") + defer func() { + wg.Done() + t.Log("stop aaa") + }() + for i := 0; i < 30; i++ { + _ = p.Invoke(i) + } + }() + + go func() { + t.Log("start bbb") + defer func() { + wg.Done() + t.Log("stop bbb") + }() + for i := 100; i < 130; i++ { + _ = p.Invoke(i) + } + }() + + time.Sleep(3 * time.Second) + p.Release() + t.Log("wait for all goroutines to exit...") + wg.Wait() +} + func TestRestCodeCoverage(t *testing.T) { _, err := NewPool(-1, WithExpiryDuration(-1)) - t.Log(err) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) _, err = NewPool(1, WithExpiryDuration(-1)) - t.Log(err) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) _, err = NewPoolWithFunc(-1, demoPoolFunc, WithExpiryDuration(-1)) - t.Log(err) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) _, err = NewPoolWithFunc(1, demoPoolFunc, WithExpiryDuration(-1)) - t.Log(err) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) + _, err = NewPoolWithFunc(1, nil, WithExpiryDuration(-1)) + require.ErrorIs(t, err, ErrLackPoolFunc) + _, err = NewPoolWithFuncGeneric(-1, demoPoolFuncInt, WithExpiryDuration(-1)) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) + _, err = NewPoolWithFuncGeneric(1, demoPoolFuncInt, WithExpiryDuration(-1)) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) + var fn func(i int) + _, err = NewPoolWithFuncGeneric(1, fn, WithExpiryDuration(-1)) + require.ErrorIs(t, err, ErrLackPoolFunc) options := Options{} options.ExpiryDuration = time.Duration(10) * time.Second @@ -824,74 +1117,106 @@ func TestRestCodeCoverage(t *testing.T) { p0.Tune(TestSize / 10) t.Logf("pool, after tuning capacity, capacity:%d, running:%d", p0.Cap(), p0.Running()) - pprem, _ := NewPool(TestSize, WithPreAlloc(true)) + p1, _ := NewPool(TestSize, WithPreAlloc(true)) defer func() { - _ = pprem.Submit(demoFunc) + _ = p1.Submit(demoFunc) }() - defer pprem.Release() + defer p1.Release() for i := 0; i < n; i++ { - _ = pprem.Submit(demoFunc) + _ = p1.Submit(demoFunc) } - t.Logf("pre-malloc pool, capacity:%d", pprem.Cap()) - t.Logf("pre-malloc pool, running workers number:%d", pprem.Running()) - t.Logf("pre-malloc pool, free workers number:%d", pprem.Free()) - pprem.Tune(TestSize) - pprem.Tune(TestSize / 10) - t.Logf("pre-malloc pool, after tuning capacity, capacity:%d, running:%d", pprem.Cap(), pprem.Running()) - - p, _ := NewPoolWithFunc(TestSize, demoPoolFunc) + t.Logf("pre-malloc pool, capacity:%d", p1.Cap()) + t.Logf("pre-malloc pool, running workers number:%d", p1.Running()) + t.Logf("pre-malloc pool, free workers number:%d", p1.Free()) + p1.Tune(TestSize) + p1.Tune(TestSize / 10) + t.Logf("pre-malloc pool, after tuning capacity, capacity:%d, running:%d", p1.Cap(), p1.Running()) + + p2, _ := NewPoolWithFunc(TestSize, demoPoolFunc) defer func() { - _ = p.Invoke(Param) + _ = p2.Invoke(Param) }() - defer p.Release() + defer p2.Release() for i := 0; i < n; i++ { - _ = p.Invoke(Param) + _ = p2.Invoke(Param) } time.Sleep(DefaultCleanIntervalTime) - t.Logf("pool with func, capacity:%d", p.Cap()) - t.Logf("pool with func, running workers number:%d", p.Running()) - t.Logf("pool with func, free workers number:%d", p.Free()) - p.Tune(TestSize) - p.Tune(TestSize / 10) - t.Logf("pool with 
func, after tuning capacity, capacity:%d, running:%d", p.Cap(), p.Running()) - - ppremWithFunc, _ := NewPoolWithFunc(TestSize, demoPoolFunc, WithPreAlloc(true)) + t.Logf("pool with func, capacity:%d", p2.Cap()) + t.Logf("pool with func, running workers number:%d", p2.Running()) + t.Logf("pool with func, free workers number:%d", p2.Free()) + p2.Tune(TestSize) + p2.Tune(TestSize / 10) + t.Logf("pool with func, after tuning capacity, capacity:%d, running:%d", p2.Cap(), p2.Running()) + + p3, _ := NewPoolWithFuncGeneric(TestSize, demoPoolFuncInt) defer func() { - _ = ppremWithFunc.Invoke(Param) + _ = p3.Invoke(Param) }() - defer ppremWithFunc.Release() + defer p3.Release() for i := 0; i < n; i++ { - _ = ppremWithFunc.Invoke(Param) + _ = p3.Invoke(Param) } time.Sleep(DefaultCleanIntervalTime) - t.Logf("pre-malloc pool with func, capacity:%d", ppremWithFunc.Cap()) - t.Logf("pre-malloc pool with func, running workers number:%d", ppremWithFunc.Running()) - t.Logf("pre-malloc pool with func, free workers number:%d", ppremWithFunc.Free()) - ppremWithFunc.Tune(TestSize) - ppremWithFunc.Tune(TestSize / 10) - t.Logf("pre-malloc pool with func, after tuning capacity, capacity:%d, running:%d", ppremWithFunc.Cap(), - ppremWithFunc.Running()) + t.Logf("pool with func, capacity:%d", p3.Cap()) + t.Logf("pool with func, running workers number:%d", p3.Running()) + t.Logf("pool with func, free workers number:%d", p3.Free()) + p3.Tune(TestSize) + p3.Tune(TestSize / 10) + t.Logf("pool with func, after tuning capacity, capacity:%d, running:%d", p3.Cap(), p3.Running()) + + p4, _ := NewPoolWithFunc(TestSize, demoPoolFunc, WithPreAlloc(true)) + defer func() { + _ = p4.Invoke(Param) + }() + defer p4.Release() + for i := 0; i < n; i++ { + _ = p4.Invoke(Param) + } + time.Sleep(DefaultCleanIntervalTime) + t.Logf("pre-malloc pool with func, capacity:%d", p4.Cap()) + t.Logf("pre-malloc pool with func, running workers number:%d", p4.Running()) + t.Logf("pre-malloc pool with func, free workers number:%d", p4.Free()) + p4.Tune(TestSize) + p4.Tune(TestSize / 10) + t.Logf("pre-malloc pool with func, after tuning capacity, capacity:%d, running:%d", p4.Cap(), + p4.Running()) + + p5, _ := NewPoolWithFuncGeneric(TestSize, demoPoolFuncInt, WithPreAlloc(true)) + defer func() { + _ = p5.Invoke(Param) + }() + defer p5.Release() + for i := 0; i < n; i++ { + _ = p5.Invoke(Param) + } + time.Sleep(DefaultCleanIntervalTime) + t.Logf("pre-malloc pool with func, capacity:%d", p5.Cap()) + t.Logf("pre-malloc pool with func, running workers number:%d", p5.Running()) + t.Logf("pre-malloc pool with func, free workers number:%d", p5.Free()) + p5.Tune(TestSize) + p5.Tune(TestSize / 10) + t.Logf("pre-malloc pool with func, after tuning capacity, capacity:%d, running:%d", p5.Cap(), + p5.Running()) } func TestPoolTuneScaleUp(t *testing.T) { c := make(chan struct{}) + // Test Pool p, _ := NewPool(2) for i := 0; i < 2; i++ { _ = p.Submit(func() { <-c }) } - if n := p.Running(); n != 2 { - t.Errorf("expect 2 workers running, but got %d", n) - } + n := p.Running() + require.EqualValuesf(t, 2, n, "expect 2 workers running, but got %d", p.Running()) // test pool tune scale up one p.Tune(3) _ = p.Submit(func() { <-c }) - if n := p.Running(); n != 3 { - t.Errorf("expect 3 workers running, but got %d", n) - } + n = p.Running() + require.EqualValuesf(t, 3, n, "expect 3 workers running, but got %d", n) // test pool tune scale up multiple var wg sync.WaitGroup for i := 0; i < 5; i++ { @@ -905,73 +1230,111 @@ func TestPoolTuneScaleUp(t *testing.T) { } p.Tune(8) 
wg.Wait() - if n := p.Running(); n != 8 { - t.Errorf("expect 8 workers running, but got %d", n) - } + n = p.Running() + require.EqualValuesf(t, 8, n, "expect 8 workers running, but got %d", n) for i := 0; i < 8; i++ { c <- struct{}{} } p.Release() - // test PoolWithFunc + // Test PoolWithFunc pf, _ := NewPoolWithFunc(2, func(_ any) { <-c }) for i := 0; i < 2; i++ { _ = pf.Invoke(1) } - if n := pf.Running(); n != 2 { - t.Errorf("expect 2 workers running, but got %d", n) - } + n = pf.Running() + require.EqualValuesf(t, 2, n, "expect 2 workers running, but got %d", n) // test pool tune scale up one pf.Tune(3) _ = pf.Invoke(1) - if n := pf.Running(); n != 3 { - t.Errorf("expect 3 workers running, but got %d", n) - } + n = pf.Running() + require.EqualValuesf(t, 3, n, "expect 3 workers running, but got %d", n) // test pool tune scale up multiple - var pfwg sync.WaitGroup for i := 0; i < 5; i++ { - pfwg.Add(1) + wg.Add(1) go func() { - defer pfwg.Done() + defer wg.Done() _ = pf.Invoke(1) }() } pf.Tune(8) - pfwg.Wait() - if n := pf.Running(); n != 8 { - t.Errorf("expect 8 workers running, but got %d", n) + wg.Wait() + n = pf.Running() + require.EqualValuesf(t, 8, n, "expect 8 workers running, but got %d", n) + for i := 0; i < 8; i++ { + c <- struct{}{} + } + pf.Release() + + // Test PoolWithFuncGeneric + pfg, _ := NewPoolWithFuncGeneric(2, func(_ int) { + <-c + }) + for i := 0; i < 2; i++ { + _ = pfg.Invoke(1) } + n = pfg.Running() + require.EqualValuesf(t, 2, n, "expect 2 workers running, but got %d", n) + // test pool tune scale up one + pfg.Tune(3) + _ = pfg.Invoke(1) + n = pfg.Running() + require.EqualValuesf(t, 3, n, "expect 3 workers running, but got %d", n) + // test pool tune scale up multiple + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = pfg.Invoke(1) + }() + } + pfg.Tune(8) + wg.Wait() + n = pfg.Running() + require.EqualValuesf(t, 8, n, "expect 8 workers running, but got %d", n) for i := 0; i < 8; i++ { c <- struct{}{} } close(c) - pf.Release() + pfg.Release() } func TestReleaseTimeout(t *testing.T) { - p, _ := NewPool(10) + p, err := NewPool(10) + require.NoError(t, err) for i := 0; i < 5; i++ { _ = p.Submit(func() { time.Sleep(time.Second) }) } - assert.NotZero(t, p.Running()) - err := p.ReleaseTimeout(2 * time.Second) - assert.NoError(t, err) + require.NotZero(t, p.Running()) + err = p.ReleaseTimeout(2 * time.Second) + require.NoError(t, err) - var pf *PoolWithFunc - pf, _ = NewPoolWithFunc(10, func(i any) { + pf, err := NewPoolWithFunc(10, func(i any) { dur := i.(time.Duration) time.Sleep(dur) }) + require.NoError(t, err) for i := 0; i < 5; i++ { _ = pf.Invoke(time.Second) } - assert.NotZero(t, pf.Running()) + require.NotZero(t, pf.Running()) err = pf.ReleaseTimeout(2 * time.Second) - assert.NoError(t, err) + require.NoError(t, err) + + pfg, err := NewPoolWithFuncGeneric(10, func(d time.Duration) { + time.Sleep(d) + }) + require.NoError(t, err) + for i := 0; i < 5; i++ { + _ = pfg.Invoke(time.Second) + } + require.NotZero(t, pfg.Running()) + err = pfg.ReleaseTimeout(2 * time.Second) + require.NoError(t, err) } func TestDefaultPoolReleaseTimeout(t *testing.T) { @@ -981,50 +1344,56 @@ func TestDefaultPoolReleaseTimeout(t *testing.T) { time.Sleep(time.Second) }) } - assert.NotZero(t, Running()) + require.NotZero(t, Running()) err := ReleaseTimeout(2 * time.Second) - assert.NoError(t, err) + require.NoError(t, err) } func TestMultiPool(t *testing.T) { - _, err := NewMultiPool(10, -1, 8) - assert.ErrorIs(t, err, ErrInvalidLoadBalancingStrategy) + _, err := 
NewMultiPool(-1, 10, 8) + require.ErrorIs(t, err, ErrInvalidMultiPoolSize) + _, err = NewMultiPool(10, -1, 8) + require.ErrorIs(t, err, ErrInvalidLoadBalancingStrategy) + _, err = NewMultiPool(10, 10, RoundRobin, WithExpiryDuration(-1)) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) mp, err := NewMultiPool(10, 5, RoundRobin) testFn := func() { for i := 0; i < 50; i++ { err = mp.Submit(longRunningFunc) - assert.NoError(t, err) + require.NoError(t, err) } - assert.EqualValues(t, mp.Waiting(), 0) + require.EqualValues(t, mp.Waiting(), 0) _, err = mp.WaitingByIndex(-1) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) + require.ErrorIs(t, err, ErrInvalidPoolIndex) _, err = mp.WaitingByIndex(11) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) - assert.EqualValues(t, 50, mp.Running()) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + require.EqualValues(t, 50, mp.Running()) _, err = mp.RunningByIndex(-1) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) + require.ErrorIs(t, err, ErrInvalidPoolIndex) _, err = mp.RunningByIndex(11) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) - assert.EqualValues(t, 0, mp.Free()) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + require.EqualValues(t, 0, mp.Free()) _, err = mp.FreeByIndex(-1) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) + require.ErrorIs(t, err, ErrInvalidPoolIndex) _, err = mp.FreeByIndex(11) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) - assert.EqualValues(t, 50, mp.Cap()) - assert.False(t, mp.IsClosed()) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + require.EqualValues(t, 50, mp.Cap()) + require.False(t, mp.IsClosed()) for i := 0; i < 10; i++ { n, _ := mp.WaitingByIndex(i) - assert.EqualValues(t, 0, n) + require.EqualValues(t, 0, n) n, _ = mp.RunningByIndex(i) - assert.EqualValues(t, 5, n) + require.EqualValues(t, 5, n) n, _ = mp.FreeByIndex(i) - assert.EqualValues(t, 0, n) + require.EqualValues(t, 0, n) } atomic.StoreInt32(&stopLongRunningFunc, 1) - assert.NoError(t, mp.ReleaseTimeout(3*time.Second)) - assert.Zero(t, mp.Running()) - assert.True(t, mp.IsClosed()) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.ReleaseTimeout(3*time.Second), ErrPoolClosed) + require.ErrorIs(t, mp.Submit(nil), ErrPoolClosed) + require.Zero(t, mp.Running()) + require.True(t, mp.IsClosed()) atomic.StoreInt32(&stopLongRunningFunc, 0) } testFn() @@ -1042,45 +1411,52 @@ func TestMultiPool(t *testing.T) { } func TestMultiPoolWithFunc(t *testing.T) { - _, err := NewMultiPoolWithFunc(10, -1, longRunningPoolFunc, 8) - assert.ErrorIs(t, err, ErrInvalidLoadBalancingStrategy) + _, err := NewMultiPoolWithFunc(-1, 10, longRunningPoolFunc, 8) + require.ErrorIs(t, err, ErrInvalidMultiPoolSize) + _, err = NewMultiPoolWithFunc(10, -1, longRunningPoolFunc, 8) + require.ErrorIs(t, err, ErrInvalidLoadBalancingStrategy) + _, err = NewMultiPoolWithFunc(10, 10, longRunningPoolFunc, RoundRobin, WithExpiryDuration(-1)) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) + ch := make(chan struct{}) mp, err := NewMultiPoolWithFunc(10, 5, longRunningPoolFunc, RoundRobin) testFn := func() { for i := 0; i < 50; i++ { - err = mp.Invoke(i) - assert.NoError(t, err) + err = mp.Invoke(ch) + require.NoError(t, err) } - assert.EqualValues(t, mp.Waiting(), 0) + require.EqualValues(t, mp.Waiting(), 0) _, err = mp.WaitingByIndex(-1) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) + require.ErrorIs(t, err, ErrInvalidPoolIndex) _, err = mp.WaitingByIndex(11) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) - assert.EqualValues(t, 50, mp.Running()) + require.ErrorIs(t, err, 
ErrInvalidPoolIndex) + require.EqualValues(t, 50, mp.Running()) _, err = mp.RunningByIndex(-1) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) + require.ErrorIs(t, err, ErrInvalidPoolIndex) _, err = mp.RunningByIndex(11) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) - assert.EqualValues(t, 0, mp.Free()) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + require.EqualValues(t, 0, mp.Free()) _, err = mp.FreeByIndex(-1) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) + require.ErrorIs(t, err, ErrInvalidPoolIndex) _, err = mp.FreeByIndex(11) - assert.ErrorIs(t, err, ErrInvalidPoolIndex) - assert.EqualValues(t, 50, mp.Cap()) - assert.False(t, mp.IsClosed()) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + require.EqualValues(t, 50, mp.Cap()) + require.False(t, mp.IsClosed()) for i := 0; i < 10; i++ { n, _ := mp.WaitingByIndex(i) - assert.EqualValues(t, 0, n) + require.EqualValues(t, 0, n) n, _ = mp.RunningByIndex(i) - assert.EqualValues(t, 5, n) + require.EqualValues(t, 5, n) n, _ = mp.FreeByIndex(i) - assert.EqualValues(t, 0, n) + require.EqualValues(t, 0, n) } - atomic.StoreInt32(&stopLongRunningPoolFunc, 1) - assert.NoError(t, mp.ReleaseTimeout(3*time.Second)) - assert.Zero(t, mp.Running()) - assert.True(t, mp.IsClosed()) - atomic.StoreInt32(&stopLongRunningPoolFunc, 0) + close(ch) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.ReleaseTimeout(3*time.Second), ErrPoolClosed) + require.ErrorIs(t, mp.Invoke(nil), ErrPoolClosed) + require.Zero(t, mp.Running()) + require.True(t, mp.IsClosed()) + ch = make(chan struct{}) } testFn() @@ -1095,3 +1471,65 @@ func TestMultiPoolWithFunc(t *testing.T) { mp.Tune(10) } + +func TestMultiPoolWithFuncGeneric(t *testing.T) { + _, err := NewMultiPoolWithFuncGeneric(-1, 10, longRunningPoolFuncCh, 8) + require.ErrorIs(t, err, ErrInvalidMultiPoolSize) + _, err = NewMultiPoolWithFuncGeneric(10, -1, longRunningPoolFuncCh, 8) + require.ErrorIs(t, err, ErrInvalidLoadBalancingStrategy) + _, err = NewMultiPoolWithFuncGeneric(10, 10, longRunningPoolFuncCh, RoundRobin, WithExpiryDuration(-1)) + require.ErrorIs(t, err, ErrInvalidPoolExpiry) + + ch := make(chan struct{}) + mp, err := NewMultiPoolWithFuncGeneric(10, 5, longRunningPoolFuncCh, RoundRobin) + testFn := func() { + for i := 0; i < 50; i++ { + err = mp.Invoke(ch) + require.NoError(t, err) + } + require.EqualValues(t, mp.Waiting(), 0) + _, err = mp.WaitingByIndex(-1) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + _, err = mp.WaitingByIndex(11) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + require.EqualValues(t, 50, mp.Running()) + _, err = mp.RunningByIndex(-1) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + _, err = mp.RunningByIndex(11) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + require.EqualValues(t, 0, mp.Free()) + _, err = mp.FreeByIndex(-1) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + _, err = mp.FreeByIndex(11) + require.ErrorIs(t, err, ErrInvalidPoolIndex) + require.EqualValues(t, 50, mp.Cap()) + require.False(t, mp.IsClosed()) + for i := 0; i < 10; i++ { + n, _ := mp.WaitingByIndex(i) + require.EqualValues(t, 0, n) + n, _ = mp.RunningByIndex(i) + require.EqualValues(t, 5, n) + n, _ = mp.FreeByIndex(i) + require.EqualValues(t, 0, n) + } + close(ch) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.ReleaseTimeout(3*time.Second), ErrPoolClosed) + require.ErrorIs(t, mp.Invoke(nil), ErrPoolClosed) + require.Zero(t, mp.Running()) + require.True(t, mp.IsClosed()) + ch = make(chan struct{}) + } + testFn() + + mp.Reboot() + testFn() + + mp, err = 
NewMultiPoolWithFuncGeneric(10, 5, longRunningPoolFuncCh, LeastTasks) + testFn() + + mp.Reboot() + testFn() + + mp.Tune(10) +} diff --git a/multipool.go b/multipool.go index 3f78ce2..342b038 100644 --- a/multipool.go +++ b/multipool.go @@ -25,6 +25,7 @@ package ants import ( "errors" "fmt" + "math" "strings" "sync/atomic" "time" @@ -58,6 +59,10 @@ type MultiPool struct { // NewMultiPool instantiates a MultiPool with a size of the pool list and a size // per pool, and the load-balancing strategy. func NewMultiPool(size, sizePerPool int, lbs LoadBalancingStrategy, options ...Option) (*MultiPool, error) { + if size <= 0 { + return nil, ErrInvalidMultiPoolSize + } + if lbs != RoundRobin && lbs != LeastTasks { return nil, ErrInvalidLoadBalancingStrategy } @@ -69,16 +74,13 @@ func NewMultiPool(size, sizePerPool int, lbs LoadBalancingStrategy, options ...O } pools[i] = pool } - return &MultiPool{pools: pools, lbs: lbs}, nil + return &MultiPool{pools: pools, index: math.MaxUint32, lbs: lbs}, nil } func (mp *MultiPool) next(lbs LoadBalancingStrategy) (idx int) { switch lbs { case RoundRobin: - if idx = int((atomic.AddUint32(&mp.index, 1) - 1) % uint32(len(mp.pools))); idx == -1 { - idx = 0 - } - return + return int(atomic.AddUint32(&mp.index, 1) % uint32(len(mp.pools))) case LeastTasks: leastTasks := 1<<31 - 1 for i, pool := range mp.pools { diff --git a/multipool_func.go b/multipool_func.go index ed7e1dc..7b4b6e5 100644 --- a/multipool_func.go +++ b/multipool_func.go @@ -25,6 +25,7 @@ package ants import ( "errors" "fmt" + "math" "strings" "sync/atomic" "time" @@ -47,6 +48,10 @@ type MultiPoolWithFunc struct { // NewMultiPoolWithFunc instantiates a MultiPoolWithFunc with a size of the pool list and a size // per pool, and the load-balancing strategy. func NewMultiPoolWithFunc(size, sizePerPool int, fn func(any), lbs LoadBalancingStrategy, options ...Option) (*MultiPoolWithFunc, error) { + if size <= 0 { + return nil, ErrInvalidMultiPoolSize + } + if lbs != RoundRobin && lbs != LeastTasks { return nil, ErrInvalidLoadBalancingStrategy } @@ -58,16 +63,13 @@ func NewMultiPoolWithFunc(size, sizePerPool int, fn func(any), lbs LoadBalancing } pools[i] = pool } - return &MultiPoolWithFunc{pools: pools, lbs: lbs}, nil + return &MultiPoolWithFunc{pools: pools, index: math.MaxUint32, lbs: lbs}, nil } func (mp *MultiPoolWithFunc) next(lbs LoadBalancingStrategy) (idx int) { switch lbs { case RoundRobin: - if idx = int((atomic.AddUint32(&mp.index, 1) - 1) % uint32(len(mp.pools))); idx == -1 { - idx = 0 - } - return + return int(atomic.AddUint32(&mp.index, 1) % uint32(len(mp.pools))) case LeastTasks: leastTasks := 1<<31 - 1 for i, pool := range mp.pools {
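Reviewer note on the round-robin change in both MultiPool and MultiPoolWithFunc above: the old code subtracted 1 after atomic.AddUint32 and then guarded an idx == -1 case that a uint32 modulo can never produce. The new code instead seeds index with math.MaxUint32, so the very first AddUint32 wraps to 0 and the pools are visited 0, 1, 2, ... with no branch. A minimal standalone sketch of the wraparound (the pool count here is hypothetical, not from the patch):

```go
package main

import (
	"fmt"
	"math"
	"sync/atomic"
)

func main() {
	const pools = 3 // pretend we have 3 pools
	idx := uint32(math.MaxUint32)

	// The first AddUint32 overflows MaxUint32 to 0, so the first task
	// lands on pool 0; subsequent calls cycle 1, 2, 0, 1, ...
	for i := 0; i < 5; i++ {
		next := int(atomic.AddUint32(&idx, 1) % uint32(pools))
		fmt.Println(next) // prints 0 1 2 0 1
	}
}
```

Unsigned overflow is well-defined wraparound in Go, so this is safe; the only quirk is a one-off hiccup in the rotation each time the 32-bit counter wraps mid-run when len(pools) doesn't divide 2^32, which is immaterial for load balancing.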
diff --git a/multipool_func_generic.go b/multipool_func_generic.go new file mode 100644 index 0000000..f5931e5 --- /dev/null +++ b/multipool_func_generic.go @@ -0,0 +1,215 @@ +// MIT License + +// Copyright (c) 2025 Andy Pan + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package ants + +import ( + "errors" + "fmt" + "math" + "strings" + "sync/atomic" + "time" + + "golang.org/x/sync/errgroup" +) + +// MultiPoolWithFuncGeneric is the generic version of MultiPoolWithFunc. +type MultiPoolWithFuncGeneric[T any] struct { + pools []*PoolWithFuncGeneric[T] + index uint32 + state int32 + lbs LoadBalancingStrategy +} + +// NewMultiPoolWithFuncGeneric instantiates a MultiPoolWithFuncGeneric with a size of the pool list and a size +// per pool, and the load-balancing strategy. +func NewMultiPoolWithFuncGeneric[T any](size, sizePerPool int, fn func(T), lbs LoadBalancingStrategy, options ...Option) (*MultiPoolWithFuncGeneric[T], error) { + if size <= 0 { + return nil, ErrInvalidMultiPoolSize + } + + if lbs != RoundRobin && lbs != LeastTasks { + return nil, ErrInvalidLoadBalancingStrategy + } + pools := make([]*PoolWithFuncGeneric[T], size) + for i := 0; i < size; i++ { + pool, err := NewPoolWithFuncGeneric(sizePerPool, fn, options...) + if err != nil { + return nil, err + } + pools[i] = pool + } + return &MultiPoolWithFuncGeneric[T]{pools: pools, index: math.MaxUint32, lbs: lbs}, nil +} + +func (mp *MultiPoolWithFuncGeneric[T]) next(lbs LoadBalancingStrategy) (idx int) { + switch lbs { + case RoundRobin: + return int(atomic.AddUint32(&mp.index, 1) % uint32(len(mp.pools))) + case LeastTasks: + leastTasks := 1<<31 - 1 + for i, pool := range mp.pools { + if n := pool.Running(); n < leastTasks { + leastTasks = n + idx = i + } + } + return + } + return -1 +} + +// Invoke submits a task to a pool selected by the load-balancing strategy. +func (mp *MultiPoolWithFuncGeneric[T]) Invoke(args T) (err error) { + if mp.IsClosed() { + return ErrPoolClosed + } + + if err = mp.pools[mp.next(mp.lbs)].Invoke(args); err == nil { + return + } + if err == ErrPoolOverload && mp.lbs == RoundRobin { + return mp.pools[mp.next(LeastTasks)].Invoke(args) + } + return +} + +// Running returns the number of the currently running workers across all pools. +func (mp *MultiPoolWithFuncGeneric[T]) Running() (n int) { + for _, pool := range mp.pools { + n += pool.Running() + } + return +} + +// RunningByIndex returns the number of the currently running workers in the specific pool. +func (mp *MultiPoolWithFuncGeneric[T]) RunningByIndex(idx int) (int, error) { + if idx < 0 || idx >= len(mp.pools) { + return -1, ErrInvalidPoolIndex + } + return mp.pools[idx].Running(), nil +} + +// Free returns the number of available workers across all pools. +func (mp *MultiPoolWithFuncGeneric[T]) Free() (n int) { + for _, pool := range mp.pools { + n += pool.Free() + } + return +} + +// FreeByIndex returns the number of available workers in the specific pool. +func (mp *MultiPoolWithFuncGeneric[T]) FreeByIndex(idx int) (int, error) { + if idx < 0 || idx >= len(mp.pools) { + return -1, ErrInvalidPoolIndex + } + return mp.pools[idx].Free(), nil +} + +// Waiting returns the number of the currently waiting tasks across all pools.
+func (mp *MultiPoolWithFuncGeneric[T]) Waiting() (n int) { + for _, pool := range mp.pools { + n += pool.Waiting() + } + return +} + +// WaitingByIndex returns the number of the currently waiting tasks in the specific pool. +func (mp *MultiPoolWithFuncGeneric[T]) WaitingByIndex(idx int) (int, error) { + if idx < 0 || idx >= len(mp.pools) { + return -1, ErrInvalidPoolIndex + } + return mp.pools[idx].Waiting(), nil +} + +// Cap returns the capacity of this multi-pool. +func (mp *MultiPoolWithFuncGeneric[T]) Cap() (n int) { + for _, pool := range mp.pools { + n += pool.Cap() + } + return +} + +// Tune resizes each pool in multi-pool. +// +// Note that this method doesn't resize the overall +// capacity of multi-pool. +func (mp *MultiPoolWithFuncGeneric[T]) Tune(size int) { + for _, pool := range mp.pools { + pool.Tune(size) + } +} + +// IsClosed indicates whether the multi-pool is closed. +func (mp *MultiPoolWithFuncGeneric[T]) IsClosed() bool { + return atomic.LoadInt32(&mp.state) == CLOSED +} + +// ReleaseTimeout closes the multi-pool with a timeout, +// it waits for all pools to be closed before timing out. +func (mp *MultiPoolWithFuncGeneric[T]) ReleaseTimeout(timeout time.Duration) error { + if !atomic.CompareAndSwapInt32(&mp.state, OPENED, CLOSED) { + return ErrPoolClosed + } + + errCh := make(chan error, len(mp.pools)) + var wg errgroup.Group + for i, pool := range mp.pools { + func(p *PoolWithFuncGeneric[T], idx int) { + wg.Go(func() error { + err := p.ReleaseTimeout(timeout) + if err != nil { + err = fmt.Errorf("pool %d: %v", idx, err) + } + errCh <- err + return err + }) + }(pool, i) + } + + _ = wg.Wait() + + var errStr strings.Builder + for i := 0; i < len(mp.pools); i++ { + if err := <-errCh; err != nil { + errStr.WriteString(err.Error()) + errStr.WriteString(" | ") + } + } + + if errStr.Len() == 0 { + return nil + } + + return errors.New(strings.TrimSuffix(errStr.String(), " | ")) +} + +// Reboot reboots a released multi-pool. +func (mp *MultiPoolWithFuncGeneric[T]) Reboot() { + if atomic.CompareAndSwapInt32(&mp.state, CLOSED, OPENED) { + atomic.StoreUint32(&mp.index, 0) + for _, pool := range mp.pools { + pool.Reboot() + } + } +}
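For readers skimming the new file above, here is a minimal usage sketch of MultiPoolWithFuncGeneric; the task body, sizes, and the github.com/panjf2000/ants/v2 import path are assumptions for illustration, not part of the patch:

```go
package main

import (
	"fmt"
	"sync"
	"time"

	"github.com/panjf2000/ants/v2"
)

func main() {
	var wg sync.WaitGroup

	// 4 pools of 10 workers each; tasks are spread round-robin across pools.
	// T is inferred as time.Duration from the task function's parameter.
	mp, err := ants.NewMultiPoolWithFuncGeneric(4, 10, func(d time.Duration) {
		defer wg.Done()
		time.Sleep(d)
	}, ants.RoundRobin)
	if err != nil {
		panic(err)
	}
	defer func() { _ = mp.ReleaseTimeout(time.Second) }()

	for i := 0; i < 40; i++ {
		wg.Add(1)
		_ = mp.Invoke(10 * time.Millisecond) // typed argument, no any boxing
	}
	wg.Wait()
	fmt.Println("running:", mp.Running())
}
```

One small asymmetry visible above: Reboot stores 0 into index rather than re-seeding math.MaxUint32, so the first post-reboot Invoke appears to land on pool 1 instead of pool 0; this is harmless for balancing but worth knowing if strict ordering matters.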
-func (p *PoolWithFunc) Invoke(args any) error { +func (p *PoolWithFunc) Invoke(arg any) error { if p.IsClosed() { return ErrPoolClosed } w, err := p.retrieveWorker() if w != nil { - w.inputParam(args) + w.inputArg(arg) } return err } @@ -61,13 +61,13 @@ func NewPoolWithFunc(size int, pf func(any), options ...Option) (*PoolWithFunc, pool := &PoolWithFunc{ poolCommon: pc, - poolFunc: pf, + fn: pf, } pool.workerCache.New = func() any { return &goWorkerWithFunc{ pool: pool, - args: make(chan any, workerChanCap), + arg: make(chan any, workerChanCap), } } diff --git a/pool_func_generic.go b/pool_func_generic.go new file mode 100644 index 0000000..06ed3ca --- /dev/null +++ b/pool_func_generic.go @@ -0,0 +1,71 @@ +// MIT License + +// Copyright (c) 2025 Andy Pan + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package ants + +// PoolWithFuncGeneric is the generic version of PoolWithFunc. +type PoolWithFuncGeneric[T any] struct { + *poolCommon + + // fn is the unified function for processing tasks. + fn func(T) +} + +// Invoke passes the argument to the pool to start a new task. +func (p *PoolWithFuncGeneric[T]) Invoke(arg T) error { + if p.IsClosed() { + return ErrPoolClosed + } + + w, err := p.retrieveWorker() + if w != nil { + w.(*goWorkerWithFuncGeneric[T]).arg <- arg + } + return err +} + +// NewPoolWithFuncGeneric instantiates a PoolWithFuncGeneric[T] with customized options. +func NewPoolWithFuncGeneric[T any](size int, pf func(T), options ...Option) (*PoolWithFuncGeneric[T], error) { + if pf == nil { + return nil, ErrLackPoolFunc + } + + pc, err := newPool(size, options...) + if err != nil { + return nil, err + } + + pool := &PoolWithFuncGeneric[T]{ + poolCommon: pc, + fn: pf, + } + + pool.workerCache.New = func() any { + return &goWorkerWithFuncGeneric[T]{ + pool: pool, + arg: make(chan T, workerChanCap), + exit: make(chan struct{}, 1), + } + } + + return pool, nil +} diff --git a/worker.go b/worker.go index f8dd650..03b4bd7 100644 --- a/worker.go +++ b/worker.go @@ -31,6 +31,8 @@ import ( // it starts a goroutine that accepts tasks and // performs function calls. type goWorker struct { + worker + // pool who owns this worker. 
 	pool *Pool
@@ -64,11 +66,11 @@ func (w *goWorker) run() {
 			w.pool.cond.Signal()
 		}()
 
-		for f := range w.task {
-			if f == nil {
+		for fn := range w.task {
+			if fn == nil {
 				return
 			}
-			f()
+			fn()
 			if ok := w.pool.revertWorker(w); !ok {
 				return
 			}
@@ -91,7 +93,3 @@ func (w *goWorker) setLastUsedTime(t time.Time) {
 func (w *goWorker) inputFunc(fn func()) {
 	w.task <- fn
 }
-
-func (w *goWorker) inputParam(any) {
-	panic("unreachable")
-}
diff --git a/worker_func.go b/worker_func.go
index 76c697a..8437e40 100644
--- a/worker_func.go
+++ b/worker_func.go
@@ -31,11 +31,13 @@ import (
 // it starts a goroutine that accepts tasks and
 // performs function calls.
 type goWorkerWithFunc struct {
+	worker
+
 	// pool who owns this worker.
 	pool *PoolWithFunc
 
-	// args is a job should be done.
-	args chan any
+	// arg is the argument for the function.
+	arg chan any
 
 	// lastUsed will be updated when putting a worker back into queue.
 	lastUsed time.Time
@@ -64,11 +66,11 @@ func (w *goWorkerWithFunc) run() {
 			w.pool.cond.Signal()
 		}()
 
-		for args := range w.args {
-			if args == nil {
+		for arg := range w.arg {
+			if arg == nil {
 				return
 			}
-			w.pool.poolFunc(args)
+			w.pool.fn(arg)
 			if ok := w.pool.revertWorker(w); !ok {
 				return
 			}
@@ -77,7 +79,7 @@ func (w *goWorkerWithFunc) run() {
 }
 
 func (w *goWorkerWithFunc) finish() {
-	w.args <- nil
+	w.arg <- nil
 }
 
 func (w *goWorkerWithFunc) lastUsedTime() time.Time {
@@ -88,10 +90,6 @@ func (w *goWorkerWithFunc) setLastUsedTime(t time.Time) {
 	w.lastUsed = t
 }
 
-func (w *goWorkerWithFunc) inputFunc(func()) {
-	panic("unreachable")
-}
-
-func (w *goWorkerWithFunc) inputParam(arg any) {
-	w.args <- arg
+func (w *goWorkerWithFunc) inputArg(arg any) {
+	w.arg <- arg
 }
diff --git a/worker_func_generic.go b/worker_func_generic.go
new file mode 100644
index 0000000..a76d109
--- /dev/null
+++ b/worker_func_generic.go
@@ -0,0 +1,96 @@
+// MIT License

+// Copyright (c) 2025 Andy Pan

+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+package ants
+
+import (
+	"runtime/debug"
+	"time"
+)
+
+// goWorkerWithFuncGeneric is the actual executor who runs the tasks,
+// it starts a goroutine that accepts tasks and
+// performs function calls.
+type goWorkerWithFuncGeneric[T any] struct {
+	worker
+
+	// pool who owns this worker.
+	pool *PoolWithFuncGeneric[T]
+
+	// arg is the argument for the function.
+	arg chan T
+
+	// exit signals the goroutine to exit.
+	exit chan struct{}
+
+	// lastUsed will be updated when putting a worker back into queue.
+ lastUsed time.Time +} + +// run starts a goroutine to repeat the process +// that performs the function calls. +func (w *goWorkerWithFuncGeneric[T]) run() { + w.pool.addRunning(1) + go func() { + defer func() { + if w.pool.addRunning(-1) == 0 && w.pool.IsClosed() { + w.pool.once.Do(func() { + close(w.pool.allDone) + }) + } + w.pool.workerCache.Put(w) + if p := recover(); p != nil { + if ph := w.pool.options.PanicHandler; ph != nil { + ph(p) + } else { + w.pool.options.Logger.Printf("worker exits from panic: %v\n%s\n", p, debug.Stack()) + } + } + // Call Signal() here in case there are goroutines waiting for available workers. + w.pool.cond.Signal() + }() + + for { + select { + case <-w.exit: + return + case arg := <-w.arg: + w.pool.fn(arg) + if ok := w.pool.revertWorker(w); !ok { + return + } + } + } + }() +} + +func (w *goWorkerWithFuncGeneric[T]) finish() { + w.exit <- struct{}{} +} + +func (w *goWorkerWithFuncGeneric[T]) lastUsedTime() time.Time { + return w.lastUsed +} + +func (w *goWorkerWithFuncGeneric[T]) setLastUsedTime(t time.Time) { + w.lastUsed = t +} diff --git a/worker_loop_queue.go b/worker_loop_queue.go index a5451ab..52091f3 100644 --- a/worker_loop_queue.go +++ b/worker_loop_queue.go @@ -12,6 +12,9 @@ type loopQueue struct { } func newWorkerLoopQueue(size int) *loopQueue { + if size <= 0 { + return nil + } return &loopQueue{ items: make([]worker, size), size: size, @@ -39,10 +42,6 @@ func (wq *loopQueue) isEmpty() bool { } func (wq *loopQueue) insert(w worker) error { - if wq.size == 0 { - return errQueueIsReleased - } - if wq.isFull { return errQueueIsFull } diff --git a/worker_loop_queue_test.go b/worker_loop_queue_test.go index 755cf15..8e04394 100644 --- a/worker_loop_queue_test.go +++ b/worker_loop_queue_test.go @@ -6,15 +6,17 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewLoopQueue(t *testing.T) { size := 100 q := newWorkerLoopQueue(size) - assert.EqualValues(t, 0, q.len(), "Len error") - assert.Equal(t, true, q.isEmpty(), "IsEmpty error") - assert.Nil(t, q.detach(), "Dequeue error") + require.EqualValues(t, 0, q.len(), "Len error") + require.Equal(t, true, q.isEmpty(), "IsEmpty error") + require.Nil(t, q.detach(), "Dequeue error") + + require.Nil(t, newWorkerLoopQueue(0)) } func TestLoopQueue(t *testing.T) { @@ -27,9 +29,9 @@ func TestLoopQueue(t *testing.T) { break } } - assert.EqualValues(t, 5, q.len(), "Len error") + require.EqualValues(t, 5, q.len(), "Len error") _ = q.detach() - assert.EqualValues(t, 4, q.len(), "Len error") + require.EqualValues(t, 4, q.len(), "Len error") time.Sleep(time.Second) @@ -39,13 +41,13 @@ func TestLoopQueue(t *testing.T) { break } } - assert.EqualValues(t, 10, q.len(), "Len error") + require.EqualValues(t, 10, q.len(), "Len error") err := q.insert(&goWorker{lastUsed: time.Now()}) - assert.Error(t, err, "Enqueue, error") + require.Error(t, err, "Enqueue, error") q.refresh(time.Second) - assert.EqualValuesf(t, 6, q.len(), "Len error: %d", q.len()) + require.EqualValuesf(t, 6, q.len(), "Len error: %d", q.len()) } func TestRotatedQueueSearch(t *testing.T) { @@ -57,18 +59,18 @@ func TestRotatedQueueSearch(t *testing.T) { _ = q.insert(&goWorker{lastUsed: time.Now()}) - assert.EqualValues(t, 0, q.binarySearch(time.Now()), "index should be 0") - assert.EqualValues(t, -1, q.binarySearch(expiry1), "index should be -1") + require.EqualValues(t, 0, q.binarySearch(time.Now()), "index should be 0") + require.EqualValues(t, -1, q.binarySearch(expiry1), "index should be 
-1") // 2 expiry2 := time.Now() _ = q.insert(&goWorker{lastUsed: time.Now()}) - assert.EqualValues(t, -1, q.binarySearch(expiry1), "index should be -1") + require.EqualValues(t, -1, q.binarySearch(expiry1), "index should be -1") - assert.EqualValues(t, 0, q.binarySearch(expiry2), "index should be 0") + require.EqualValues(t, 0, q.binarySearch(expiry2), "index should be 0") - assert.EqualValues(t, 1, q.binarySearch(time.Now()), "index should be 1") + require.EqualValues(t, 1, q.binarySearch(time.Now()), "index should be 1") // more for i := 0; i < 5; i++ { @@ -83,7 +85,7 @@ func TestRotatedQueueSearch(t *testing.T) { err = q.insert(&goWorker{lastUsed: time.Now()}) } - assert.EqualValues(t, 7, q.binarySearch(expiry3), "index should be 7") + require.EqualValues(t, 7, q.binarySearch(expiry3), "index should be 7") // rotate for i := 0; i < 6; i++ { @@ -98,7 +100,7 @@ func TestRotatedQueueSearch(t *testing.T) { } // head = 6, tail = 5, insert direction -> // [expiry4, time, time, time, time, nil/tail, time/head, time, time, time] - assert.EqualValues(t, 0, q.binarySearch(expiry4), "index should be 0") + require.EqualValues(t, 0, q.binarySearch(expiry4), "index should be 0") for i := 0; i < 3; i++ { _ = q.detach() @@ -108,17 +110,17 @@ func TestRotatedQueueSearch(t *testing.T) { // head = 6, tail = 5, insert direction -> // [expiry4, time, time, time, time, expiry5, nil/tail, nil, nil, time/head] - assert.EqualValues(t, 5, q.binarySearch(expiry5), "index should be 5") + require.EqualValues(t, 5, q.binarySearch(expiry5), "index should be 5") for i := 0; i < 3; i++ { _ = q.insert(&goWorker{lastUsed: time.Now()}) } // head = 9, tail = 9, insert direction -> // [expiry4, time, time, time, time, expiry5, time, time, time, time/head/tail] - assert.EqualValues(t, -1, q.binarySearch(expiry2), "index should be -1") + require.EqualValues(t, -1, q.binarySearch(expiry2), "index should be -1") - assert.EqualValues(t, 9, q.binarySearch(q.items[9].lastUsedTime()), "index should be 9") - assert.EqualValues(t, 8, q.binarySearch(time.Now()), "index should be 8") + require.EqualValues(t, 9, q.binarySearch(q.items[9].lastUsedTime()), "index should be 9") + require.EqualValues(t, 8, q.binarySearch(time.Now()), "index should be 8") } func TestRetrieveExpiry(t *testing.T) { @@ -139,7 +141,7 @@ func TestRetrieveExpiry(t *testing.T) { } workers := q.refresh(u) - assert.EqualValues(t, expirew, workers, "expired workers aren't right") + require.EqualValues(t, expirew, workers, "expired workers aren't right") // test [ time, time, time, time, time, time+1s, time+1s, time+1s, time+1s, time+1s] time.Sleep(u) @@ -152,7 +154,7 @@ func TestRetrieveExpiry(t *testing.T) { workers2 := q.refresh(u) - assert.EqualValues(t, expirew, workers2, "expired workers aren't right") + require.EqualValues(t, expirew, workers2, "expired workers aren't right") // test [ time+1s, time+1s, time+1s, nil, nil, time+1s, time+1s, time+1s, time+1s, time+1s] for i := 0; i < size/2; i++ { @@ -172,5 +174,5 @@ func TestRetrieveExpiry(t *testing.T) { workers3 := q.refresh(u) - assert.EqualValues(t, expirew, workers3, "expired workers aren't right") + require.EqualValues(t, expirew, workers3, "expired workers aren't right") } diff --git a/worker_queue.go b/worker_queue.go index 1c44ee6..4131972 100644 --- a/worker_queue.go +++ b/worker_queue.go @@ -5,13 +5,8 @@ import ( "time" ) -var ( - // errQueueIsFull will be returned when the worker queue is full. 
- errQueueIsFull = errors.New("the queue is full") - - // errQueueIsReleased will be returned when trying to insert item to a released worker queue. - errQueueIsReleased = errors.New("the queue length is zero") -) +// errQueueIsFull will be returned when the worker queue is full. +var errQueueIsFull = errors.New("the queue is full") type worker interface { run() @@ -19,7 +14,7 @@ type worker interface { lastUsedTime() time.Time setLastUsedTime(t time.Time) inputFunc(func()) - inputParam(any) + inputArg(any) } type workerQueue interface { diff --git a/worker_stack.go b/worker_stack.go index 6b01abc..8eb12ab 100644 --- a/worker_stack.go +++ b/worker_stack.go @@ -13,57 +13,57 @@ func newWorkerStack(size int) *workerStack { } } -func (wq *workerStack) len() int { - return len(wq.items) +func (ws *workerStack) len() int { + return len(ws.items) } -func (wq *workerStack) isEmpty() bool { - return len(wq.items) == 0 +func (ws *workerStack) isEmpty() bool { + return len(ws.items) == 0 } -func (wq *workerStack) insert(w worker) error { - wq.items = append(wq.items, w) +func (ws *workerStack) insert(w worker) error { + ws.items = append(ws.items, w) return nil } -func (wq *workerStack) detach() worker { - l := wq.len() +func (ws *workerStack) detach() worker { + l := ws.len() if l == 0 { return nil } - w := wq.items[l-1] - wq.items[l-1] = nil // avoid memory leaks - wq.items = wq.items[:l-1] + w := ws.items[l-1] + ws.items[l-1] = nil // avoid memory leaks + ws.items = ws.items[:l-1] return w } -func (wq *workerStack) refresh(duration time.Duration) []worker { - n := wq.len() +func (ws *workerStack) refresh(duration time.Duration) []worker { + n := ws.len() if n == 0 { return nil } expiryTime := time.Now().Add(-duration) - index := wq.binarySearch(0, n-1, expiryTime) + index := ws.binarySearch(0, n-1, expiryTime) - wq.expiry = wq.expiry[:0] + ws.expiry = ws.expiry[:0] if index != -1 { - wq.expiry = append(wq.expiry, wq.items[:index+1]...) - m := copy(wq.items, wq.items[index+1:]) + ws.expiry = append(ws.expiry, ws.items[:index+1]...) 
+ m := copy(ws.items, ws.items[index+1:]) for i := m; i < n; i++ { - wq.items[i] = nil + ws.items[i] = nil } - wq.items = wq.items[:m] + ws.items = ws.items[:m] } - return wq.expiry + return ws.expiry } -func (wq *workerStack) binarySearch(l, r int, expiryTime time.Time) int { +func (ws *workerStack) binarySearch(l, r int, expiryTime time.Time) int { for l <= r { mid := l + ((r - l) >> 1) // avoid overflow when computing mid - if expiryTime.Before(wq.items[mid].lastUsedTime()) { + if expiryTime.Before(ws.items[mid].lastUsedTime()) { r = mid - 1 } else { l = mid + 1 @@ -72,10 +72,10 @@ func (wq *workerStack) binarySearch(l, r int, expiryTime time.Time) int { return r } -func (wq *workerStack) reset() { - for i := 0; i < wq.len(); i++ { - wq.items[i].finish() - wq.items[i] = nil +func (ws *workerStack) reset() { + for i := 0; i < ws.len(); i++ { + ws.items[i].finish() + ws.items[i] = nil } - wq.items = wq.items[:0] + ws.items = ws.items[:0] } diff --git a/worker_stack_test.go b/worker_stack_test.go index 453d6e3..87fca0d 100644 --- a/worker_stack_test.go +++ b/worker_stack_test.go @@ -6,15 +6,15 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewWorkerStack(t *testing.T) { size := 100 q := newWorkerStack(size) - assert.EqualValues(t, 0, q.len(), "Len error") - assert.Equal(t, true, q.isEmpty(), "IsEmpty error") - assert.Nil(t, q.detach(), "Dequeue error") + require.EqualValues(t, 0, q.len(), "Len error") + require.Equal(t, true, q.isEmpty(), "IsEmpty error") + require.Nil(t, q.detach(), "Dequeue error") } func TestWorkerStack(t *testing.T) { @@ -26,7 +26,7 @@ func TestWorkerStack(t *testing.T) { break } } - assert.EqualValues(t, 5, q.len(), "Len error") + require.EqualValues(t, 5, q.len(), "Len error") expired := time.Now() @@ -43,9 +43,9 @@ func TestWorkerStack(t *testing.T) { t.Fatal("Enqueue error") } } - assert.EqualValues(t, 12, q.len(), "Len error") + require.EqualValues(t, 12, q.len(), "Len error") q.refresh(time.Second) - assert.EqualValues(t, 6, q.len(), "Len error") + require.EqualValues(t, 6, q.len(), "Len error") } // It seems that something wrong with time.Now() on Windows, not sure whether it is a bug on Windows, @@ -58,18 +58,18 @@ func TestSearch(t *testing.T) { _ = q.insert(&goWorker{lastUsed: time.Now()}) - assert.EqualValues(t, 0, q.binarySearch(0, q.len()-1, time.Now()), "index should be 0") - assert.EqualValues(t, -1, q.binarySearch(0, q.len()-1, expiry1), "index should be -1") + require.EqualValues(t, 0, q.binarySearch(0, q.len()-1, time.Now()), "index should be 0") + require.EqualValues(t, -1, q.binarySearch(0, q.len()-1, expiry1), "index should be -1") // 2 expiry2 := time.Now() _ = q.insert(&goWorker{lastUsed: time.Now()}) - assert.EqualValues(t, -1, q.binarySearch(0, q.len()-1, expiry1), "index should be -1") + require.EqualValues(t, -1, q.binarySearch(0, q.len()-1, expiry1), "index should be -1") - assert.EqualValues(t, 0, q.binarySearch(0, q.len()-1, expiry2), "index should be 0") + require.EqualValues(t, 0, q.binarySearch(0, q.len()-1, expiry2), "index should be 0") - assert.EqualValues(t, 1, q.binarySearch(0, q.len()-1, time.Now()), "index should be 1") + require.EqualValues(t, 1, q.binarySearch(0, q.len()-1, time.Now()), "index should be 1") // more for i := 0; i < 5; i++ { @@ -84,5 +84,5 @@ func TestSearch(t *testing.T) { _ = q.insert(&goWorker{lastUsed: time.Now()}) } - assert.EqualValues(t, 7, q.binarySearch(0, q.len()-1, expiry3), "index should be 7") + require.EqualValues(t, 7, 
q.binarySearch(0, q.len()-1, expiry3), "index should be 7") }
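To close out, a minimal usage sketch of the new generic single pool (not part of this patch; it assumes the github.com/panjf2000/ants/v2 module path). It also highlights the design point visible in worker_func_generic.go above: goWorkerWithFunc retires a worker by sending nil on its arg channel, but a generic chan T has no universal nil sentinel, so the generic worker's finish() signals a dedicated exit channel and run() selects over exit and arg:

package main

import (
	"fmt"
	"sync"
	"time"

	"github.com/panjf2000/ants/v2"
)

func main() {
	var wg sync.WaitGroup
	// T is inferred as time.Duration, so Invoke is checked at compile time:
	// no any-boxing on submit, no type assertion inside the task function.
	p, err := ants.NewPoolWithFuncGeneric(10, func(d time.Duration) {
		defer wg.Done()
		time.Sleep(d)
	})
	if err != nil {
		panic(err)
	}
	defer p.Release()

	for i := 0; i < 50; i++ {
		wg.Add(1)
		_ = p.Invoke(10 * time.Millisecond)
	}
	wg.Wait()
	fmt.Println("running workers after drain:", p.Running())
}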