From 8a2b33c289a921f53518e795a205fba9d4bd7058 Mon Sep 17 00:00:00 2001 From: Dustin Deus Date: Sat, 15 Feb 2025 22:24:25 +0100 Subject: [PATCH] fix: deadlock when waiting on inflight events of a trigger (#1073) This PR fixes an issue where `WaitGroup.Wait` caused a deadlock with the progressing update events because waiting in the event loop starves it forever. Additionally, I removed the field `done`, which caused a race with a concurrent event from the WS providers. It also wasn't useful at all. I battle-tested the change with a load test scenario where connections are reestablished multiple times a second. --- v2/Makefile | 8 +++ v2/pkg/engine/resolve/resolve.go | 29 +++++------ v2/pkg/engine/resolve/resolve_test.go | 75 +++++++++++++-------------- 3 files changed, 55 insertions(+), 57 deletions(-) diff --git a/v2/Makefile b/v2/Makefile index 23fb63a451..4dde3fa63d 100644 --- a/v2/Makefile +++ b/v2/Makefile @@ -2,6 +2,11 @@ test: go test ./... +test-fresh: clean-testcache test + +test-stability: + @while $(MAKE) test-fresh; do :; done + .PHONY: test-quick test-quick: go test -count=1 ./... @@ -10,6 +15,9 @@ test-quick: test-race: go test -race ./... +clean-testcache: + go clean -testcache + # updateTestFixtures will update all! golden fixtures .PHONY: updateTestFixtures updateTestFixtures: diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 7ffc23b784..35a9219b12 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -279,8 +279,8 @@ type trigger struct { id uint64 cancel context.CancelFunc subscriptions map[*Context]*sub - inFlight *sync.WaitGroup - initialized bool + // initialized is set to true when the trigger is started and initialized + initialized bool } type sub struct { @@ -366,6 +366,7 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput } } +// handleEvents maintains the single threaded event loop that processes all events func (r *Resolver) handleEvents() { done := r.ctx.Done() heartbeat := time.NewTicker(r.multipartSubHeartbeatInterval) @@ -528,7 +529,6 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) id: triggerID, subscriptions: make(map[*Context]*sub), cancel: cancel, - inFlight: &sync.WaitGroup{}, } r.triggers[triggerID] = trig trig.subscriptions[add.ctx] = s @@ -676,22 +676,25 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { if skip { continue } - trig.inFlight.Add(1) fn := func() { r.executeSubscriptionUpdate(c, s, data) } - go func(fn func()) { - defer trig.inFlight.Done() + + // Needs to be executed in a separate goroutine to prevent blocking the event loop. + go func() { + + // Send the update to the executor channel to be executed on the main thread + // Only relevant for SSE/Multipart subscriptions if s.executor != nil { select { case <-r.ctx.Done(): case <-c.ctx.Done(): - case s.executor <- fn: + case s.executor <- fn: // Run the update on the main thread and close subscription } } else { fn() } - }(fn) + }() } } @@ -703,7 +706,7 @@ func (r *Resolver) shutdownTrigger(id uint64) { if !ok { return } - trig.inFlight.Wait() + count := len(trig.subscriptions) r.shutdownTriggerSubscriptions(id, nil) trig.cancel() @@ -1004,7 +1007,6 @@ func (r *Resolver) subscriptionInput(ctx *Context, subscription *GraphQLSubscrip } type subscriptionUpdater struct { - done bool debug bool triggerID uint64 ch chan subscriptionEvent @@ -1022,9 +1024,6 @@ func (s *subscriptionUpdater) Update(data []byte) { } defer s.updateSem.Release(1) - if s.done { - return - } select { case <-s.ctx.Done(): return @@ -1046,9 +1045,6 @@ func (s *subscriptionUpdater) Done() { } defer s.updateSem.Release(1) - if s.done { - return - } select { case <-s.ctx.Done(): return @@ -1057,7 +1053,6 @@ func (s *subscriptionUpdater) Done() { kind: subscriptionEventKindTriggerDone, }: } - s.done = true } type subscriptionEvent struct { diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index d1bc62852f..c4a41ebaed 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -4835,9 +4835,6 @@ func (s *SubscriptionRecorder) Write(p []byte) (n int, err error) { } func (s *SubscriptionRecorder) Flush() error { - if s.onFlush != nil { - s.onFlush(s.buf.Bytes()) - } s.mux.Lock() defer s.mux.Unlock() s.messages = append(s.messages, s.buf.String()) @@ -5168,7 +5165,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { fakeStream := createFakeStream(func(counter int) (message string, done bool) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2 - }, 0, func(input []byte) { + }, 1*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) }) @@ -5185,14 +5182,14 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { assert.NoError(t, err) recorder.AwaitComplete(t, defaultTimeout) - assert.Equal(t, 3, len(recorder.Messages())) + messages := recorder.Messages() + + assert.Greater(t, len(messages), 2) time.Sleep(2 * resolver.multipartSubHeartbeatInterval) // Validate that despite the time, we don't see any heartbeats sent - assert.ElementsMatch(t, []string{ - `{"data":{"counter":0}}`, - `{"data":{"counter":1}}`, - `{"data":{"counter":2}}`, - }, recorder.Messages()) + assert.Contains(t, messages, `{"data":{"counter":0}}`) + assert.Contains(t, messages, `{"data":{"counter":1}}`) + assert.Contains(t, messages, `{"data":{"counter":2}}`) }) t.Run("should successfully delete multiple finished subscriptions", func(t *testing.T) { @@ -5243,7 +5240,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { fakeStream := createFakeStream(func(counter int) (message string, done bool) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2 - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }","extensions":{"foo":"bar"}}}`, string(input)) }) @@ -5257,12 +5254,12 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) assert.NoError(t, err) recorder.AwaitComplete(t, defaultTimeout) - assert.Equal(t, 3, len(recorder.Messages())) - assert.ElementsMatch(t, []string{ - `{"data":{"counter":0}}`, - `{"data":{"counter":1}}`, - `{"data":{"counter":2}}`, - }, recorder.Messages()) + + messages := recorder.Messages() + assert.Len(t, messages, 3) + assert.Contains(t, messages, `{"data":{"counter":0}}`) + assert.Contains(t, messages, `{"data":{"counter":1}}`) + assert.Contains(t, messages, `{"data":{"counter":2}}`) }) t.Run("should propagate initial payload to stream", func(t *testing.T) { @@ -5271,7 +5268,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { fakeStream := createFakeStream(func(counter int) (message string, done bool) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2 - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"},"initial_payload":{"hello":"world"}}`, string(input)) }) @@ -5285,12 +5282,12 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) assert.NoError(t, err) recorder.AwaitComplete(t, defaultTimeout) - assert.Equal(t, 3, len(recorder.Messages())) - assert.ElementsMatch(t, []string{ - `{"data":{"counter":0}}`, - `{"data":{"counter":1}}`, - `{"data":{"counter":2}}`, - }, recorder.Messages()) + + messages := recorder.Messages() + assert.Len(t, messages, 3) + assert.Contains(t, messages, `{"data":{"counter":0}}`) + assert.Contains(t, messages, `{"data":{"counter":1}}`) + assert.Contains(t, messages, `{"data":{"counter":2}}`) }) t.Run("should stop stream on unsubscribe subscription", func(t *testing.T) { @@ -5349,7 +5346,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { fakeStream := createFakeStream(func(counter int) (message string, done bool) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) }) @@ -5378,7 +5375,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { fakeStream := createFakeStream(func(counter int) (message string, done bool) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { countryUpdated { name time { local } } }"}}`, string(input)) }) @@ -5407,7 +5404,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { fakeStream := createFakeStream(func(counter int) (message string, done bool) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) }) @@ -5465,7 +5462,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { fakeStream := createFakeStream(func(counter int) (message string, done bool) { defer started.Store(true) return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), true - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) }) @@ -5474,11 +5471,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { buf: &bytes.Buffer{}, messages: []string{}, complete: atomic.Bool{}, - onFlush: func(p []byte) { - for !complete.Load() { - time.Sleep(time.Millisecond * 10) - } - }, } recorder.complete.Store(false) @@ -5490,14 +5482,17 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return started.Load() - }, defaultTimeout, time.Millisecond*100) + }, defaultTimeout, time.Millisecond*10) + + assert.Len(t, resolver.triggers, 1) + var unsubscribeComplete atomic.Bool go func() { defer unsubscribeComplete.Store(true) err = resolver.AsyncUnsubscribeSubscription(id) assert.NoError(t, err) }() - assert.Len(t, resolver.triggers, 1) + complete.Store(true) assert.Eventually(t, unsubscribeComplete.Load, defaultTimeout, time.Millisecond*100) recorder.AwaitComplete(t, defaultTimeout) @@ -5573,7 +5568,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return `{"id":1}`, false } return `{"id":2}`, true - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) }) @@ -5669,7 +5664,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return `{"id":1}`, false } return `{"id":2}`, true - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) }) @@ -5763,7 +5758,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return fmt.Sprintf(`{"id":%d}`, count), false } return `{"id":4}`, true - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) }) @@ -5858,7 +5853,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return fmt.Sprintf(`{"id":"x.%d"}`, count), false } return `{"id":"x.4"}`, true - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) }) @@ -5957,7 +5952,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return `{"id":"x.1"}`, false } return `{"id":"x.2"}`, true - }, 0, func(input []byte) { + }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) })