Skip to content

Commit

Permalink
fix: deadlock when waiting on inflight events of a trigger (#1073)
Browse files Browse the repository at this point in the history
This PR fixes a deadlock between `WaitGroup.Wait` and in-flight update
events: waiting inside the event loop blocks the loop, so it is starved
forever.

Additionally, I removed the `done` field, which caused a race with
concurrent events from the WS providers and served no purpose.

I battle-tested the change with a load test scenario where connections
are reestablished multiple times a second.
  • Loading branch information
StarpTech authored Feb 15, 2025
1 parent eb45cc5 commit 8a2b33c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 57 deletions.
8 changes: 8 additions & 0 deletions v2/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
test:
go test ./...

test-fresh: clean-testcache test

test-stability:
@while $(MAKE) test-fresh; do :; done

.PHONY: test-quick
test-quick:
go test -count=1 ./...
Expand All @@ -10,6 +15,9 @@ test-quick:
test-race:
go test -race ./...

clean-testcache:
go clean -testcache

# updateTestFixtures will update all! golden fixtures
.PHONY: updateTestFixtures
updateTestFixtures:
Expand Down
29 changes: 12 additions & 17 deletions v2/pkg/engine/resolve/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ type trigger struct {
id uint64
cancel context.CancelFunc
subscriptions map[*Context]*sub
inFlight *sync.WaitGroup
initialized bool
// initialized is set to true when the trigger is started and initialized
initialized bool
}

type sub struct {
Expand Down Expand Up @@ -366,6 +366,7 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput
}
}

// handleEvents maintains the single-threaded event loop that processes all events.
func (r *Resolver) handleEvents() {
done := r.ctx.Done()
heartbeat := time.NewTicker(r.multipartSubHeartbeatInterval)
Expand Down Expand Up @@ -528,7 +529,6 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
id: triggerID,
subscriptions: make(map[*Context]*sub),
cancel: cancel,
inFlight: &sync.WaitGroup{},
}
r.triggers[triggerID] = trig
trig.subscriptions[add.ctx] = s
Expand Down Expand Up @@ -676,22 +676,25 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
if skip {
continue
}
trig.inFlight.Add(1)
fn := func() {
r.executeSubscriptionUpdate(c, s, data)
}
go func(fn func()) {
defer trig.inFlight.Done()

// Needs to be executed in a separate goroutine to prevent blocking the event loop.
go func() {

// Send the update to the executor channel to be executed on the main thread
// Only relevant for SSE/Multipart subscriptions
if s.executor != nil {
select {
case <-r.ctx.Done():
case <-c.ctx.Done():
case s.executor <- fn:
case s.executor <- fn: // Run the update on the main thread and close subscription
}
} else {
fn()
}
}(fn)
}()
}
}

Expand All @@ -703,7 +706,7 @@ func (r *Resolver) shutdownTrigger(id uint64) {
if !ok {
return
}
trig.inFlight.Wait()

count := len(trig.subscriptions)
r.shutdownTriggerSubscriptions(id, nil)
trig.cancel()
Expand Down Expand Up @@ -1004,7 +1007,6 @@ func (r *Resolver) subscriptionInput(ctx *Context, subscription *GraphQLSubscrip
}

type subscriptionUpdater struct {
done bool
debug bool
triggerID uint64
ch chan subscriptionEvent
Expand All @@ -1022,9 +1024,6 @@ func (s *subscriptionUpdater) Update(data []byte) {
}
defer s.updateSem.Release(1)

if s.done {
return
}
select {
case <-s.ctx.Done():
return
Expand All @@ -1046,9 +1045,6 @@ func (s *subscriptionUpdater) Done() {
}
defer s.updateSem.Release(1)

if s.done {
return
}
select {
case <-s.ctx.Done():
return
Expand All @@ -1057,7 +1053,6 @@ func (s *subscriptionUpdater) Done() {
kind: subscriptionEventKindTriggerDone,
}:
}
s.done = true
}

type subscriptionEvent struct {
Expand Down
75 changes: 35 additions & 40 deletions v2/pkg/engine/resolve/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4835,9 +4835,6 @@ func (s *SubscriptionRecorder) Write(p []byte) (n int, err error) {
}

func (s *SubscriptionRecorder) Flush() error {
if s.onFlush != nil {
s.onFlush(s.buf.Bytes())
}
s.mux.Lock()
defer s.mux.Unlock()
s.messages = append(s.messages, s.buf.String())
Expand Down Expand Up @@ -5168,7 +5165,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2
}, 0, func(input []byte) {
}, 1*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
})

Expand All @@ -5185,14 +5182,14 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
assert.NoError(t, err)

recorder.AwaitComplete(t, defaultTimeout)
assert.Equal(t, 3, len(recorder.Messages()))
messages := recorder.Messages()

assert.Greater(t, len(messages), 2)
time.Sleep(2 * resolver.multipartSubHeartbeatInterval)
// Validate that no heartbeats are sent, even after waiting two heartbeat intervals
assert.ElementsMatch(t, []string{
`{"data":{"counter":0}}`,
`{"data":{"counter":1}}`,
`{"data":{"counter":2}}`,
}, recorder.Messages())
assert.Contains(t, messages, `{"data":{"counter":0}}`)
assert.Contains(t, messages, `{"data":{"counter":1}}`)
assert.Contains(t, messages, `{"data":{"counter":2}}`)
})

t.Run("should successfully delete multiple finished subscriptions", func(t *testing.T) {
Expand Down Expand Up @@ -5243,7 +5240,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }","extensions":{"foo":"bar"}}}`, string(input))
})

Expand All @@ -5257,12 +5254,12 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id)
assert.NoError(t, err)
recorder.AwaitComplete(t, defaultTimeout)
assert.Equal(t, 3, len(recorder.Messages()))
assert.ElementsMatch(t, []string{
`{"data":{"counter":0}}`,
`{"data":{"counter":1}}`,
`{"data":{"counter":2}}`,
}, recorder.Messages())

messages := recorder.Messages()
assert.Len(t, messages, 3)
assert.Contains(t, messages, `{"data":{"counter":0}}`)
assert.Contains(t, messages, `{"data":{"counter":1}}`)
assert.Contains(t, messages, `{"data":{"counter":2}}`)
})

t.Run("should propagate initial payload to stream", func(t *testing.T) {
Expand All @@ -5271,7 +5268,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"},"initial_payload":{"hello":"world"}}`, string(input))
})

Expand All @@ -5285,12 +5282,12 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id)
assert.NoError(t, err)
recorder.AwaitComplete(t, defaultTimeout)
assert.Equal(t, 3, len(recorder.Messages()))
assert.ElementsMatch(t, []string{
`{"data":{"counter":0}}`,
`{"data":{"counter":1}}`,
`{"data":{"counter":2}}`,
}, recorder.Messages())

messages := recorder.Messages()
assert.Len(t, messages, 3)
assert.Contains(t, messages, `{"data":{"counter":0}}`)
assert.Contains(t, messages, `{"data":{"counter":1}}`)
assert.Contains(t, messages, `{"data":{"counter":2}}`)
})

t.Run("should stop stream on unsubscribe subscription", func(t *testing.T) {
Expand Down Expand Up @@ -5349,7 +5346,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
})

Expand Down Expand Up @@ -5378,7 +5375,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { countryUpdated { name time { local } } }"}}`, string(input))
})

Expand Down Expand Up @@ -5407,7 +5404,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
})

Expand Down Expand Up @@ -5465,7 +5462,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
fakeStream := createFakeStream(func(counter int) (message string, done bool) {
defer started.Store(true)
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), true
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
})

Expand All @@ -5474,11 +5471,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
buf: &bytes.Buffer{},
messages: []string{},
complete: atomic.Bool{},
onFlush: func(p []byte) {
for !complete.Load() {
time.Sleep(time.Millisecond * 10)
}
},
}
recorder.complete.Store(false)

Expand All @@ -5490,14 +5482,17 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool {
return started.Load()
}, defaultTimeout, time.Millisecond*100)
}, defaultTimeout, time.Millisecond*10)

assert.Len(t, resolver.triggers, 1)

var unsubscribeComplete atomic.Bool
go func() {
defer unsubscribeComplete.Store(true)
err = resolver.AsyncUnsubscribeSubscription(id)
assert.NoError(t, err)
}()
assert.Len(t, resolver.triggers, 1)

complete.Store(true)
assert.Eventually(t, unsubscribeComplete.Load, defaultTimeout, time.Millisecond*100)
recorder.AwaitComplete(t, defaultTimeout)
Expand Down Expand Up @@ -5573,7 +5568,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
return `{"id":1}`, false
}
return `{"id":2}`, true
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input))
})

Expand Down Expand Up @@ -5669,7 +5664,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
return `{"id":1}`, false
}
return `{"id":2}`, true
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input))
})

Expand Down Expand Up @@ -5763,7 +5758,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
return fmt.Sprintf(`{"id":%d}`, count), false
}
return `{"id":4}`, true
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input))
})

Expand Down Expand Up @@ -5858,7 +5853,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
return fmt.Sprintf(`{"id":"x.%d"}`, count), false
}
return `{"id":"x.4"}`, true
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input))
})

Expand Down Expand Up @@ -5957,7 +5952,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
return `{"id":"x.1"}`, false
}
return `{"id":"x.2"}`, true
}, 0, func(input []byte) {
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input))
})

Expand Down

0 comments on commit 8a2b33c

Please sign in to comment.