From b59aeac9942db3b93a0575cbd368c5630cc7a682 Mon Sep 17 00:00:00 2001 From: Sergey Melekhin Date: Fri, 15 Sep 2023 13:00:42 +0700 Subject: [PATCH] Protect map and queue variable access with mutex, add test --- client.go | 16 +++++------ client_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 4fd595e..e2a4dd8 100644 --- a/client.go +++ b/client.go @@ -631,10 +631,12 @@ func (c *Client) reader(t transport, disconnectCh chan struct{}) { func (c *Client) runHandlerSync(fn func()) { waitCh := make(chan struct{}) + c.mu.RLock() c.cbQueue.push(func(delay time.Duration) { defer close(waitCh) fn() }) + c.mu.RUnlock() <-waitCh } @@ -696,44 +698,38 @@ func (c *Client) handleMessage(msg *protocol.Message) error { } func (c *Client) handlePush(push *protocol.Push) { + channel := push.Channel + c.mu.RLock() + sub, ok := c.subs[channel] + c.mu.RUnlock() switch { case push.Message != nil: _ = c.handleMessage(push.Message) case push.Unsubscribe != nil: - channel := push.Channel - sub, ok := c.subs[channel] if !ok { c.handleServerUnsub(channel, push.Unsubscribe) return } sub.handleUnsubscribe(push.Unsubscribe) case push.Pub != nil: - channel := push.Channel - sub, ok := c.subs[channel] if !ok { c.handleServerPublication(channel, push.Pub) return } sub.handlePublication(push.Pub) case push.Join != nil: - channel := push.Channel - sub, ok := c.subs[channel] if !ok { c.handleServerJoin(channel, push.Join) return } sub.handleJoin(push.Join.Info) case push.Leave != nil: - channel := push.Channel - sub, ok := c.subs[channel] if !ok { c.handleServerLeave(channel, push.Leave) return } sub.handleLeave(push.Leave.Info) case push.Subscribe != nil: - channel := push.Channel - _, ok := c.subs[channel] if ok { // Client-side subscription exists. return diff --git a/client_test.go b/client_test.go index 4129c36..94df53b 100644 --- a/client_test.go +++ b/client_test.go @@ -392,3 +392,76 @@ func TestClient_History(t *testing.T) { t.Fatal("expected not available error, got " + strconv.FormatUint(uint64(e.Code), 10)) } } + +func TestConcurrentPublishSubscribe(t *testing.T) { + const ( + numMessages = 1000 + numResubscritpions = 100 + ) + + producer := NewJsonClient("ws://localhost:8000/connection/websocket?cf_protocol_version=v2", Config{}) + defer producer.Close() + + if err := producer.Connect(); err != nil { + t.Fatalf("error on connect: %v", err) + } + + errChan := make(chan error) + defer close(errChan) + go func() { + for i := 0; i < numMessages; i++ { + msg := []byte(`{"unique":"` + randString(6) + strconv.FormatInt(time.Now().UnixNano(), 10) + `"}`) + _, err := producer.Publish(context.Background(), "test_concurrent", msg) + if err != nil { + errChan <- fmt.Errorf("error on publish: %v", err) + return + } + } + errChan <- nil + }() + + go func() { + for i := 0; i < numResubscritpions; i++ { + consumer := NewJsonClient("ws://localhost:8000/connection/websocket?cf_protocol_version=v2", Config{}) + if err := consumer.Connect(); err != nil { + errChan <- fmt.Errorf("error on connect: %v", err) + return + } + + handler := &testSubscriptionHandler{} + sub, err := consumer.NewSubscription("test_concurrent") + if err != nil { + errChan <- fmt.Errorf("error on new subscription: %v (%d)", err, i) + return + } + sub.OnSubscribed(handler.OnSubscribe) + sub.OnPublication(handler.OnPublication) + if err := sub.Subscribe(); err != nil { + errChan <- fmt.Errorf("error on subscribe: %v (%d)", err, i) + return + } + sub2, err := consumer.NewSubscription("something_else") + if err != nil { + errChan <- fmt.Errorf("error on new subscription: %v (%d)", err, i) + return + } + sub2.OnSubscribed(handler.OnSubscribe) + sub2.OnPublication(handler.OnPublication) + if err := sub2.Subscribe(); err != nil { + errChan <- fmt.Errorf("error on subscribe: %v (%d)", err, i) + return + } + } + errChan <- nil + }() + + var err error + for i := 0; i < 2; i++ { + if e := <-errChan; e != nil { + err = e + } + } + if err != nil { + t.Fatal(err) + } +}