From e56c9c93e41b0c9d121eb86aa65f0c3a05051ca9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Tue, 23 Apr 2024 00:12:36 +0800 Subject: [PATCH] feat(policy): Added `StrictRoundRobinPolicy` to solve the data race (#220) --- ent/driver/multi/policy.go | 8 ++++++++ ent/driver/multi/policy_test.go | 19 +++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/ent/driver/multi/policy.go b/ent/driver/multi/policy.go index 4e8081d..f8c0dd1 100644 --- a/ent/driver/multi/policy.go +++ b/ent/driver/multi/policy.go @@ -2,6 +2,7 @@ package multi import ( "math/rand" + "sync/atomic" "entgo.io/ent/dialect" ) @@ -24,6 +25,13 @@ func RoundRobinPolicy() Policy { }) } +func StrictRoundRobinPolicy() Policy { + var i int64 + return PolicyFunc(func(drivers []dialect.Driver) dialect.Driver { + return drivers[int(atomic.LoadInt64(&i))%len(drivers)] + }) +} + func RandomPolicy() Policy { return PolicyFunc(func(drivers []dialect.Driver) dialect.Driver { return drivers[rand.Intn(len(drivers))] diff --git a/ent/driver/multi/policy_test.go b/ent/driver/multi/policy_test.go index d41d17a..3553c8c 100644 --- a/ent/driver/multi/policy_test.go +++ b/ent/driver/multi/policy_test.go @@ -1,6 +1,7 @@ package multi import ( + "sync/atomic" "testing" "entgo.io/ent/dialect" @@ -10,11 +11,13 @@ import ( var driver1, driver2, driver3 dialect.Driver func TestPolicy_RoundRobinPolicy(t *testing.T) { - p := RoundRobinPolicy() + p1 := RoundRobinPolicy() + p2 := StrictRoundRobinPolicy() drivers := []dialect.Driver{driver1, driver2, driver3} for i := 0; i < 10; i++ { - assert.Equal(t, drivers[i%3], p.Resolve(drivers)) + assert.Equal(t, drivers[i%3], p1.Resolve(drivers)) + assert.Equal(t, drivers[i%3], p2.Resolve(drivers)) } } @@ -26,3 +29,15 @@ func TestPolicy_RandomPolicy(t *testing.T) { assert.Contains(t, drivers, p.Resolve(drivers)) } } + +func BenchmarkPolicy_StrictRoundRobinPolicy(b *testing.B) { + p := StrictRoundRobinPolicy() + drivers := []dialect.Driver{driver1, driver2, driver3} + + var i int64 + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + assert.Equal(b, drivers[int(atomic.AddInt64(&i, 1))%3], p.Resolve(drivers)) + } + }) +}