Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix V2Typing send multiple times #214

Merged
merged 14 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions sync2/handler2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package handler2
import (
"context"
"encoding/json"
"hash/fnv"
"os"
"sync"
"time"
Expand Down Expand Up @@ -40,8 +39,9 @@ type Handler struct {
Highlight int
Notif int
}
// room_id => fnv_hash([typing user ids])
typingMap map[string]uint64
// room_id -> PollerID, stores which Poller is allowed to update typing notifications
typingHandler map[string]sync2.PollerID
typingMu *sync.Mutex
PendingTxnIDs *sync2.PendingTransactionIDs

deviceDataTicker *sync2.DeviceDataTicker
Expand All @@ -64,7 +64,8 @@ func NewHandler(
Highlight int
Notif int
}),
typingMap: make(map[string]uint64),
typingMu: &sync.Mutex{},
typingHandler: make(map[string]sync2.PollerID),
PendingTxnIDs: sync2.NewPendingTransactionIDs(pMap.DeviceIDs),
deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded
Expand Down Expand Up @@ -166,7 +167,15 @@ func (h *Handler) updateMetrics() {
h.numPollers.Set(float64(h.pMap.NumPollers()))
}

func (h *Handler) OnTerminated(ctx context.Context, userID, deviceID string) {
func (h *Handler) OnTerminated(ctx context.Context, pollerID sync2.PollerID) {
// Check if this device is handling any typing notifications; if so, remove it
h.typingMu.Lock()
defer h.typingMu.Unlock()
for roomID, devID := range h.typingHandler {
if devID == pollerID {
delete(h.typingHandler, roomID)
}
}
h.updateMetrics()
}

Expand Down Expand Up @@ -352,13 +361,20 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra
return res.PrependTimelineEvents
}

func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
next := typingHash(ephEvent)
existing := h.typingMap[roomID]
if existing == next {
func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) {
h.typingMu.Lock()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though I couldn't get the test to fail in CI, this was causing a race condition.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the bits which touch typingMap or typingDeviceHandler should be protected.

defer h.typingMu.Unlock()

existingDevice := h.typingHandler[roomID]
isPollerAssigned := existingDevice.DeviceID != "" && existingDevice.UserID != ""
if isPollerAssigned && existingDevice != pollerID {
// A different device is already handling typing notifications for this room
return
} else if !isPollerAssigned {
// We're the first to call SetTyping, assign our pollerID
h.typingHandler[roomID] = pollerID
}
h.typingMap[roomID] = next

// we don't persist this for long term storage as typing notifs are inherently ephemeral.
// So rather than maintaining them forever, they will naturally expire when we terminate.
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Typing{
Expand Down Expand Up @@ -472,6 +488,13 @@ func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to retire invite")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
}

// Remove room from the typing deviceHandler map, this ensures we always
// have a device handling typing notifications for a given room.
h.typingMu.Lock()
defer h.typingMu.Unlock()
delete(h.typingHandler, roomID)

h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{
UserID: userID,
RoomID: roomID,
Expand Down Expand Up @@ -507,11 +530,3 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
})
}()
}

func typingHash(ephEvent json.RawMessage) uint64 {
h := fnv.New64a()
for _, userID := range gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() {
h.Write([]byte(userID.Str))
}
return h.Sum64()
}
54 changes: 51 additions & 3 deletions sync2/handler2/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package handler2_test

import (
"context"
"encoding/json"
"os"
"reflect"
"sync"
Expand Down Expand Up @@ -97,12 +99,17 @@ func (p *mockPub) WaitForPayloadType(t string) chan struct{} {
return ch
}

func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}) {
func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}, wantTimeOut bool) {
select {
case <-ch:
if wantTimeOut {
t.Fatalf("expected to timeout, but received on channel")
}
return
case <-time.After(time.Second):
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
if !wantTimeOut {
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
}
}
}

Expand Down Expand Up @@ -160,7 +167,7 @@ func TestHandlerFreshEnsurePolling(t *testing.T) {
DeviceID: deviceID,
AccessTokenHash: tok.AccessTokenHash,
})
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch)
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch, false)

// make sure we polled with the token i.e it did a db hit
pMap.assertCallExists(t, pollInfo{
Expand All @@ -174,3 +181,44 @@ func TestHandlerFreshEnsurePolling(t *testing.T) {
})

}

func TestSetTypingConcurrently(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't what I'm interested in testing. The case here should be passing without your change.

The case I'm interested in is when you have 2 pollers receiving delayed typing notifs. For example, if alice starts typing then stops typing (so [A] then []), the problem is that 1 poller may be "behind" the other, such that it has yet to see [A] whilst the other "ahead" poller has already seen [A] and []. In this scenario, we flicker with 4 operations instead of 2, as we go [A], [], [A], [], which this test is not testing, nor does the code fix.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't what I'm interested in testing. The case here should be passing without your change.

It doesn't always pass. With luck it does, yea, if the machine is slow enough to execute both calls.

fatal error: concurrent map writes

goroutine 246 [running]:
github.com/matrix-org/sliding-sync/sync2/handler2.(*Handler).SetTyping(0xc0003d2080, {0x0?, 0x0?}, {0xa24ad3, 0x11}, {0xc0004e8090, 0x2d, 0x2d})
        github.com/sliding-sync/sync2/handler2/handler.go:344 +0x96
github.com/matrix-org/sliding-sync/sync2/handler2_test.TestSetTypingConcurrently.func2()
       github.com/sliding-sync/sync2/handler2/handler_test.go:203 +0xd9
created by github.com/matrix-org/sliding-sync/sync2/handler2_test.TestSetTypingConcurrently
       github.com/sliding-sync/sync2/handler2/handler_test.go:201 +0x2b0

which also means that h.typingMap[roomID] returned 0 as the existing value, resulting in duplicate notifications (if the machine is, again, slow enough, that the map writes aren't concurrent :D)

store := state.NewStorage(postgresURI)
v2Store := sync2.NewStore(postgresURI, "secret")
pMap := &mockPollerMap{}
pub := newMockPub()
sub := &mockSub{}
h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute)
assertNoError(t, err)
ctx := context.Background()

roomID := "!typing:localhost"

typingType := pubsub.V2Typing{}

// startSignal is used to synchronize calling SetTyping
startSignal := make(chan struct{})
// Call SetTyping twice, this may happen with pollers for the same user
go func() {
<-startSignal
h.SetTyping(ctx, sync2.PollerID{UserID: "@alice", DeviceID: "aliceDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
}()
go func() {
<-startSignal
h.SetTyping(ctx, sync2.PollerID{UserID: "@bob", DeviceID: "bobDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
}()

close(startSignal)

// Wait for the event to be published
ch := pub.WaitForPayloadType(typingType.Type())
pub.DoWait(t, "didn't see V2Typing", ch, false)
ch = pub.WaitForPayloadType(typingType.Type())
// Wait again, but this time we expect to timeout.
pub.DoWait(t, "saw unexpected V2Typing", ch, true)

// We expect only one call to Notify, as only the first poller to call SetTyping for a room is allowed to publish typing updates
if gotCalls := len(pub.calls); gotCalls != 1 {
t.Fatalf("expected only one call to notify, got %d", gotCalls)
}
}
19 changes: 11 additions & 8 deletions sync2/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type V2DataReceiver interface {
// If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB.
Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID?
// SetTyping indicates which users are typing.
SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage)
SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage)
// Sent when there is a new receipt
OnReceipt(ctx context.Context, userID, roomID, ephEventType string, ephEvent json.RawMessage)
// AddToDeviceMessages adds this chunk of to_device messages. Preserve the ordering.
Expand All @@ -55,7 +55,7 @@ type V2DataReceiver interface {
// Sent when there is a _change_ in E2EE data, not all the time
OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
// Sent when the poll loop terminates
OnTerminated(ctx context.Context, userID, deviceID string)
OnTerminated(ctx context.Context, pollerID PollerID)
// Sent when the token gets a 401 response
OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string)
}
Expand Down Expand Up @@ -297,11 +297,11 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.
wg.Wait()
return
}
func (h *PollerMap) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
func (h *PollerMap) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.SetTyping(ctx, roomID, ephEvent)
h.callbacks.SetTyping(ctx, pollerID, roomID, ephEvent)
wg.Done()
}
wg.Wait()
Expand Down Expand Up @@ -332,8 +332,8 @@ func (h *PollerMap) AddToDeviceMessages(ctx context.Context, userID, deviceID st
h.callbacks.AddToDeviceMessages(ctx, userID, deviceID, msgs)
}

func (h *PollerMap) OnTerminated(ctx context.Context, userID, deviceID string) {
h.callbacks.OnTerminated(ctx, userID, deviceID)
func (h *PollerMap) OnTerminated(ctx context.Context, pollerID PollerID) {
h.callbacks.OnTerminated(ctx, pollerID)
}

func (h *PollerMap) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
Expand Down Expand Up @@ -473,7 +473,10 @@ func (p *poller) Poll(since string) {
logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msgf("%s. Traceback:\n%s", panicErr, debug.Stack())
internal.GetSentryHubFromContextOrDefault(ctx).RecoverWithContext(ctx, panicErr)
}
p.receiver.OnTerminated(ctx, p.userID, p.deviceID)
p.receiver.OnTerminated(ctx, PollerID{
UserID: p.userID,
DeviceID: p.deviceID,
})
}()

state := pollLoopState{
Expand Down Expand Up @@ -706,7 +709,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
switch ephEventType {
case "m.typing":
typingCalls++
p.receiver.SetTyping(ctx, roomID, ephEvent)
p.receiver.SetTyping(ctx, PollerID{UserID: p.userID, DeviceID: p.deviceID}, roomID, ephEvent)
case "m.receipt":
receiptCalls++
p.receiver.OnReceipt(ctx, p.userID, roomID, ephEventType, ephEvent)
Expand Down
4 changes: 2 additions & 2 deletions sync2/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state
// timeline. Untested here---return nil for now.
return nil
}
func (a *mockDataReceiver) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
func (a *mockDataReceiver) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
}
func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) {
s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since
Expand All @@ -620,7 +620,7 @@ func (s *mockDataReceiver) OnInvite(ctx context.Context, userID, roomID string,
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string) {}
func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, pollerID PollerID) {}
func (s *mockDataReceiver) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
}

Expand Down
Loading
Loading