From f5d7ba79d0df895b1a5c3853382221397d5864b3 Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Wed, 22 Jan 2025 13:05:38 +0300 Subject: [PATCH] feat: add conn id to connect and disconnect events Signed-off-by: Felix Gateru --- cmd/http/main.go | 15 ++++++++--- cmd/ws/main.go | 11 +++++++- coap/events/doc.go | 2 +- coap/events/events.go | 53 +++++++++------------------------------ coap/events/stream.go | 29 ++++++++++++--------- docker/docker-compose.yml | 2 ++ http/events/doc.go | 2 +- http/events/streams.go | 2 -- mqtt/events/events.go | 4 ++- mqtt/events/streams.go | 20 +++++++++------ mqtt/handler.go | 31 +++++++++++------------ mqtt/handler_test.go | 14 +++++------ mqtt/mocks/events.go | 40 ++++++++++++++--------------- ws/handler.go | 4 +-- 14 files changed, 112 insertions(+), 117 deletions(-) diff --git a/cmd/http/main.go b/cmd/http/main.go index 992024dd40..343b138ca1 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -23,6 +23,7 @@ import ( grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" adapter "github.com/absmach/supermq/http" httpapi "github.com/absmach/supermq/http/api" + "github.com/absmach/supermq/http/events" smqlog "github.com/absmach/supermq/logger" smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/authn/authsvc" @@ -59,6 +60,7 @@ type config struct { SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` InstanceID string `env:"SMQ_HTTP_ADAPTER_INSTANCE_ID" envDefault:""` TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` + ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` } func main() { @@ -141,6 +143,13 @@ func main() { defer authnHandler.Close() logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure()) + eventStore, err := events.NewEventStore(ctx, cfg.ESURL, cfg.InstanceID) + if err != nil { + logger.Error(fmt.Sprintf("failed to create %s event store : %s", svcName, err)) + exitCode = 1 + return + } + tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) if err != nil { logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) @@ -163,7 +172,7 @@ func main() { defer pub.Close() pub = brokerstracing.NewPublisher(httpServerConfig, tracer, pub) - svc := newService(pub, authn, clientsClient, channelsClient, logger, tracer) + svc := newService(pub, eventStore, authn, clientsClient, channelsClient, logger, tracer) targetServerCfg := server.Config{Port: targetHTTPPort} hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(logger, cfg.InstanceID), logger) @@ -190,8 +199,8 @@ func main() { } } -func newService(pub messaging.Publisher, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, logger *slog.Logger, tracer trace.Tracer) session.Handler { - svc := adapter.NewHandler(pub, authn, clients, channels, logger) +func newService(pub messaging.Publisher, es events.EventStore, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, logger *slog.Logger, tracer trace.Tracer) session.Handler { + svc := adapter.NewHandler(pub, es, authn, clients, channels, logger) svc = handler.NewTracing(tracer, svc) svc = handler.LoggingMiddleware(svc, logger) counter, latency := prometheus.MakeMetrics(svcName, "api") diff --git a/cmd/ws/main.go b/cmd/ws/main.go index 9f54fd1eb9..d2b7d89c91 100644 --- a/cmd/ws/main.go +++ b/cmd/ws/main.go @@ -31,6 +31,7 @@ import ( "github.com/absmach/supermq/pkg/uuid" "github.com/absmach/supermq/ws" httpapi "github.com/absmach/supermq/ws/api" + "github.com/absmach/supermq/ws/events" "github.com/absmach/supermq/ws/tracing" "github.com/caarlos0/env/v11" "go.opentelemetry.io/otel/trace" @@ -55,6 +56,7 @@ type config struct { SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` InstanceID string `env:"SMQ_WS_ADAPTER_INSTANCE_ID" envDefault:""` TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` + ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` } func main() { @@ -143,6 +145,13 @@ func main() { defer authnHandler.Close() logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure()) + eventStore, err := events.NewEventStore(ctx, cfg.ESURL, cfg.InstanceID) + if err != nil { + logger.Error(fmt.Sprintf("failed to create %s event store : %s", svcName, err)) + exitCode = 1 + return + } + tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) if err != nil { logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err)) @@ -178,7 +187,7 @@ func main() { g.Go(func() error { return hs.Start() }) - handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient) + handler := ws.NewHandler(nps, eventStore, logger, authn, clientsClient, channelsClient) return proxyWS(ctx, httpServerConfig, targetServerConfig, logger, handler) }) diff --git a/coap/events/doc.go b/coap/events/doc.go index 38ecdeb597..62541a0b4b 100644 --- a/coap/events/doc.go +++ b/coap/events/doc.go @@ -3,4 +3,4 @@ // Package events provides the domain concept definitions needed to support // coap events functionality. -package events \ No newline at end of file +package events diff --git a/coap/events/events.go b/coap/events/events.go index fa392c8baa..660778b20c 100644 --- a/coap/events/events.go +++ b/coap/events/events.go @@ -10,50 +10,21 @@ const ( clientUnsubscribe = coapPrefix + ".client_unsubscribe" ) -type clientPublishEvent struct { - ChannelID string - ClientID string - Topic string +type coapEvent struct { + operation string + channelID string + clientID string + connID string + topic string } -func (cpe clientPublishEvent) Encode() (map[string]interface{}, error) { +func (ce coapEvent) Encode() (map[string]interface{}, error) { val := map[string]interface{}{ - "operation": clientPublish, - "channel_id": cpe.ChannelID, - "client_id": cpe.ClientID, - "topic": cpe.Topic, - } - return val, nil -} - -type clientSubscribeEvent struct { - ChannelID string - ClientID string - Topic string -} - -func (cse clientSubscribeEvent) Encode() (map[string]interface{}, error) { - val := map[string]interface{}{ - "operation": clientSubscribe, - "channel_id": cse.ChannelID, - "client_id": cse.ClientID, - "topic": cse.Topic, - } - return val, nil -} - -type clientUnsubscribeEvent struct { - ChannelID string - ClientID string - Topic string -} - -func (cse clientUnsubscribeEvent) Encode() (map[string]interface{}, error) { - val := map[string]interface{}{ - "operation": clientUnsubscribe, - "channel_id": cse.ChannelID, - "client_id": cse.ClientID, - "topic": cse.Topic, + "operation": ce.operation, + "channel_id": ce.channelID, + "client_id": ce.clientID, + "conn_id": ce.connID, + "topic": ce.topic, } return val, nil } diff --git a/coap/events/stream.go b/coap/events/stream.go index f723b5bee3..5a46a26c64 100644 --- a/coap/events/stream.go +++ b/coap/events/stream.go @@ -39,10 +39,11 @@ func (es *eventStore) Publish(ctx context.Context, clientID string, msg *messagi return err } - event := clientPublishEvent{ - ClientID: clientID, - ChannelID: msg.GetChannel(), - Topic: msg.GetSubtopic(), + event := coapEvent{ + operation: clientPublish, + clientID: clientID, + channelID: msg.GetChannel(), + topic: msg.GetSubtopic(), } if err := es.events.Publish(ctx, event); err != nil { return err @@ -57,10 +58,12 @@ func (es *eventStore) Subscribe(ctx context.Context, clientID, channelID, subtop return err } - event := clientSubscribeEvent{ - ClientID: clientID, - ChannelID: channelID, - Topic: subtopic, + event := coapEvent{ + operation: clientSubscribe, + clientID: clientID, + channelID: channelID, + connID: c.Token(), + topic: subtopic, } if err := es.events.Publish(ctx, event); err != nil { return err @@ -75,10 +78,12 @@ func (es *eventStore) Unsubscribe(ctx context.Context, clientID, channelID, subt return err } - event := clientUnsubscribeEvent{ - ClientID: clientID, - ChannelID: channelID, - Topic: subtopic, + event := coapEvent{ + operation: clientUnsubscribe, + clientID: clientID, + channelID: channelID, + connID: token, + topic: subtopic, } if err := es.events.Publish(ctx, event); err != nil { return err diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 20622fc14f..897d5e771d 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -970,6 +970,7 @@ services: SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} SMQ_HTTP_ADAPTER_INSTANCE_ID: ${SMQ_HTTP_ADAPTER_INSTANCE_ID} + SMQ_ES_URL: ${SMQ_ES_URL} ports: - ${SMQ_HTTP_ADAPTER_PORT}:${SMQ_HTTP_ADAPTER_PORT} networks: @@ -1129,6 +1130,7 @@ services: SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} SMQ_WS_ADAPTER_INSTANCE_ID: ${SMQ_WS_ADAPTER_INSTANCE_ID} + SMQ_ES_URL: ${SMQ_ES_URL} ports: - ${SMQ_WS_ADAPTER_HTTP_PORT}:${SMQ_WS_ADAPTER_HTTP_PORT} networks: diff --git a/http/events/doc.go b/http/events/doc.go index 69932795ce..77e9503cd8 100644 --- a/http/events/doc.go +++ b/http/events/doc.go @@ -3,4 +3,4 @@ // Package events provides the domain concept definitions needed to support // http events functionality. -package events \ No newline at end of file +package events diff --git a/http/events/streams.go b/http/events/streams.go index 5d848c04ac..744d7d5ffa 100644 --- a/http/events/streams.go +++ b/http/events/streams.go @@ -12,8 +12,6 @@ import ( const streamID = "supermq.http" - - //go:generate mockery --name EventStore --output=../mocks --filename events.go --quiet --note "Copyright (c) Abstract Machines" type EventStore interface { Connect(ctx context.Context, clientID string) error diff --git a/mqtt/events/events.go b/mqtt/events/events.go index 5d6ee94b7e..ad00916f08 100644 --- a/mqtt/events/events.go +++ b/mqtt/events/events.go @@ -6,7 +6,7 @@ package events import "github.com/absmach/supermq/pkg/events" const ( - mqttPrefix = "http" + mqttPrefix = "mqtt" clientPublish = mqttPrefix + ".client_publish" clientSubscribe = mqttPrefix + ".client_subscribe" clientUnsubscribe = mqttPrefix + ".client_unsubscribe" @@ -20,6 +20,7 @@ type mqttEvent struct { operation string channelID string clientID string + connID string topic string instance string } @@ -29,6 +30,7 @@ func (me mqttEvent) Encode() (map[string]interface{}, error) { "operation": me.operation, "channel_id": me.channelID, "client_id": me.clientID, + "conn_id": me.connID, "topic": me.topic, }, nil } diff --git a/mqtt/events/streams.go b/mqtt/events/streams.go index b7aab6caa1..1f6804daa8 100644 --- a/mqtt/events/streams.go +++ b/mqtt/events/streams.go @@ -14,11 +14,11 @@ const streamID = "supermq.mqtt" //go:generate mockery --name EventStore --output=../mocks --filename events.go --quiet --note "Copyright (c) Abstract Machines" type EventStore interface { - Connect(ctx context.Context, clientID string) error - Disconnect(ctx context.Context, clientID string) error + Connect(ctx context.Context, clientID, connID string) error + Disconnect(ctx context.Context, clientID, connID string) error Publish(ctx context.Context, clientID, channelID, topic string) error - Subscribe(ctx context.Context, clientID, channelID, subtopic string) error - Unsubscribe(ctx context.Context, clientID, channelID, subtopic string) error + Subscribe(ctx context.Context, clientID, channelID, connID, subtopic string) error + Unsubscribe(ctx context.Context, clientID, channelID, connID, subtopic string) error } // EventStore is a struct used to store event streams in Redis. @@ -42,10 +42,11 @@ func NewEventStore(ctx context.Context, url, instance string) (EventStore, error } // Connect issues event on MQTT CONNECT. -func (es *eventStore) Connect(ctx context.Context, clientID string) error { +func (es *eventStore) Connect(ctx context.Context, clientID, connID string) error { ev := mqttEvent{ clientID: clientID, operation: clientConnect, + connID: connID, instance: es.instance, } @@ -53,10 +54,11 @@ func (es *eventStore) Connect(ctx context.Context, clientID string) error { } // Disconnect issues event on MQTT CONNECT. -func (es *eventStore) Disconnect(ctx context.Context, clientID string) error { +func (es *eventStore) Disconnect(ctx context.Context, clientID, connID string) error { ev := mqttEvent{ clientID: clientID, operation: clientDisconnect, + connID: connID, instance: es.instance, } @@ -77,11 +79,12 @@ func (es *eventStore) Publish(ctx context.Context, clientID, channelID, topic st } // Subscribe issues event on MQTT SUBSCRIBE. -func (es *eventStore) Subscribe(ctx context.Context, clientID, channelID, subtopic string) error { +func (es *eventStore) Subscribe(ctx context.Context, clientID, channelID, connID, subtopic string) error { ev := mqttEvent{ clientID: clientID, operation: clientSubscribe, channelID: channelID, + connID: connID, topic: subtopic, instance: es.instance, } @@ -90,11 +93,12 @@ func (es *eventStore) Subscribe(ctx context.Context, clientID, channelID, subtop } // Unsubscribe issues event on MQTT UNSUBSCRIBE. -func (es *eventStore) Unsubscribe(ctx context.Context, clientID, channelID, subtopic string) error { +func (es *eventStore) Unsubscribe(ctx context.Context, clientID, channelID, connID, subtopic string) error { ev := mqttEvent{ clientID: clientID, operation: clientUnsubscribe, channelID: channelID, + connID: connID, topic: subtopic, instance: es.instance, } diff --git a/mqtt/handler.go b/mqtt/handler.go index 9a488e5bb8..6439a74b92 100644 --- a/mqtt/handler.go +++ b/mqtt/handler.go @@ -110,7 +110,7 @@ func (h *handler) AuthConnect(ctx context.Context) error { return errInvalidUserId } - if err := h.es.Connect(ctx, pwd); err != nil { + if err := h.es.Connect(ctx, s.Username, s.ID); err != nil { h.logger.Error(errors.Wrap(ErrFailedPublishConnectEvent, err).Error()) } @@ -129,7 +129,6 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt } return h.authAccess(ctx, string(s.Username), *topic, connections.Publish) - } // AuthSubscribe is called on device subscribe, @@ -144,17 +143,9 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { } for _, topic := range *topics { - err := h.authAccess(ctx, s.Username, topic, connections.Subscribe) - if err != nil { - return err - } - channelID, subTopic, err := parseTopic(topic) - if err != nil { + if err := h.authAccess(ctx, string(s.Username), topic, connections.Subscribe); err != nil { return err } - if err := h.es.Subscribe(ctx, s.Username, channelID, subTopic); err != nil { - return errors.Wrap(ErrFailedSubscribeEvent, err) - } } return nil @@ -211,6 +202,16 @@ func (h *handler) Subscribe(ctx context.Context, topics *[]string) error { return errors.Wrap(ErrFailedSubscribe, ErrClientNotInitialized) } + for _, topic := range *topics { + channelID, subTopic, err := parseTopic(topic) + if err != nil { + return err + } + if err := h.es.Subscribe(ctx, s.Username, channelID, s.ID, subTopic); err != nil { + return errors.Wrap(ErrFailedUnsubscribeEvent, err) + } + } + h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ","))) return nil @@ -228,15 +229,11 @@ func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error { } for _, topic := range *topics { - err := h.authAccess(ctx, s.Username, topic, connections.Subscribe) - if err != nil { - return err - } channelID, subTopic, err := parseTopic(topic) if err != nil { return err } - if err := h.es.Unsubscribe(ctx, s.Username, channelID, subTopic); err != nil { + if err := h.es.Unsubscribe(ctx, s.Username, channelID, s.ID, subTopic); err != nil { return errors.Wrap(ErrFailedUnsubscribeEvent, err) } } @@ -252,7 +249,7 @@ func (h *handler) Disconnect(ctx context.Context) error { return errors.Wrap(ErrFailedDisconnect, ErrClientNotInitialized) } h.logger.Error(fmt.Sprintf(LogInfoDisconnected, s.ID, s.Password)) - if err := h.es.Disconnect(ctx, string(s.Password)); err != nil { + if err := h.es.Disconnect(ctx, s.Username, s.ID); err != nil { return errors.Wrap(ErrFailedPublishDisconnectEvent, err) } return nil diff --git a/mqtt/handler_test.go b/mqtt/handler_test.go index 86293d253a..21d250fd51 100644 --- a/mqtt/handler_test.go +++ b/mqtt/handler_test.go @@ -147,7 +147,7 @@ func TestAuthConnect(t *testing.T) { password = string(tc.session.Password) } clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: password}).Return(tc.authNRes, tc.authNErr) - svcCall := eventStore.On("Connect", mock.Anything, password).Return(tc.err) + svcCall := eventStore.On("Connect", mock.Anything, mock.Anything, clientID).Return(tc.err) err := handler.AuthConnect(ctx) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) svcCall.Unset() @@ -449,7 +449,7 @@ func TestSubscribe(t *testing.T) { if tc.session != nil { ctx = session.NewContext(ctx, tc.session) } - eventsCall := eventStore.On("Subscribe", mock.Anything, clientID, mock.Anything, mock.Anything).Return(nil) + eventsCall := eventStore.On("Subscribe", mock.Anything, clientID, mock.Anything, clientID, mock.Anything).Return(nil) err := handler.Subscribe(ctx, &tc.topic) assert.Contains(t, logBuffer.String(), tc.logMsg) assert.Equal(t, tc.err, err) @@ -476,7 +476,7 @@ func TestUnsubscribe(t *testing.T) { session: nil, topic: topics, channelID: chanID, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, err: errors.Wrap(mqtt.ErrFailedUnsubscribe, mqtt.ErrClientNotInitialized), }, { @@ -484,7 +484,7 @@ func TestUnsubscribe(t *testing.T) { session: &sessionClient, topic: topics, channelID: chanID, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, logMsg: fmt.Sprintf(mqtt.LogInfoUnsubscribed, clientID, topics[0]), }, } @@ -494,7 +494,7 @@ func TestUnsubscribe(t *testing.T) { if tc.session != nil { ctx = session.NewContext(ctx, tc.session) } - eventsCall := eventStore.On("Unsubscribe", mock.Anything, clientID, mock.Anything, mock.Anything).Return(nil) + eventsCall := eventStore.On("Unsubscribe", mock.Anything, clientID, mock.Anything, clientID, mock.Anything).Return(nil) channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{ ChannelId: tc.channelID, ClientId: clientID, @@ -536,12 +536,10 @@ func TestDisconnect(t *testing.T) { for _, tc := range cases { ctx := context.TODO() - password := "" if tc.session != nil { ctx = session.NewContext(ctx, tc.session) - password = string(tc.session.Password) } - svcCall := eventStore.On("Disconnect", mock.Anything, password).Return(tc.err) + svcCall := eventStore.On("Disconnect", mock.Anything, sessionClient.Username, sessionClient.ID).Return(tc.err) err := handler.Disconnect(ctx) assert.Contains(t, logBuffer.String(), tc.logMsg) assert.Equal(t, tc.err, err) diff --git a/mqtt/mocks/events.go b/mqtt/mocks/events.go index 1ad11e9ae8..7f26fb360e 100644 --- a/mqtt/mocks/events.go +++ b/mqtt/mocks/events.go @@ -15,17 +15,17 @@ type EventStore struct { mock.Mock } -// Connect provides a mock function with given fields: ctx, clientID -func (_m *EventStore) Connect(ctx context.Context, clientID string) error { - ret := _m.Called(ctx, clientID) +// Connect provides a mock function with given fields: ctx, clientID, connID +func (_m *EventStore) Connect(ctx context.Context, clientID string, connID string) error { + ret := _m.Called(ctx, clientID, connID) if len(ret) == 0 { panic("no return value specified for Connect") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, clientID) + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, clientID, connID) } else { r0 = ret.Error(0) } @@ -33,17 +33,17 @@ func (_m *EventStore) Connect(ctx context.Context, clientID string) error { return r0 } -// Disconnect provides a mock function with given fields: ctx, clientID -func (_m *EventStore) Disconnect(ctx context.Context, clientID string) error { - ret := _m.Called(ctx, clientID) +// Disconnect provides a mock function with given fields: ctx, clientID, connID +func (_m *EventStore) Disconnect(ctx context.Context, clientID string, connID string) error { + ret := _m.Called(ctx, clientID, connID) if len(ret) == 0 { panic("no return value specified for Disconnect") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, clientID) + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, clientID, connID) } else { r0 = ret.Error(0) } @@ -69,17 +69,17 @@ func (_m *EventStore) Publish(ctx context.Context, clientID string, channelID st return r0 } -// Subscribe provides a mock function with given fields: ctx, clientID, channelID, subtopic -func (_m *EventStore) Subscribe(ctx context.Context, clientID string, channelID string, subtopic string) error { - ret := _m.Called(ctx, clientID, channelID, subtopic) +// Subscribe provides a mock function with given fields: ctx, clientID, channelID, connID, subtopic +func (_m *EventStore) Subscribe(ctx context.Context, clientID string, channelID string, connID string, subtopic string) error { + ret := _m.Called(ctx, clientID, channelID, connID, subtopic) if len(ret) == 0 { panic("no return value specified for Subscribe") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { - r0 = rf(ctx, clientID, channelID, subtopic) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, clientID, channelID, connID, subtopic) } else { r0 = ret.Error(0) } @@ -87,17 +87,17 @@ func (_m *EventStore) Subscribe(ctx context.Context, clientID string, channelID return r0 } -// Unsubscribe provides a mock function with given fields: ctx, clientID, channelID, subtopic -func (_m *EventStore) Unsubscribe(ctx context.Context, clientID string, channelID string, subtopic string) error { - ret := _m.Called(ctx, clientID, channelID, subtopic) +// Unsubscribe provides a mock function with given fields: ctx, clientID, channelID, connID, subtopic +func (_m *EventStore) Unsubscribe(ctx context.Context, clientID string, channelID string, connID string, subtopic string) error { + ret := _m.Called(ctx, clientID, channelID, connID, subtopic) if len(ret) == 0 { panic("no return value specified for Unsubscribe") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { - r0 = rf(ctx, clientID, channelID, subtopic) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, clientID, channelID, connID, subtopic) } else { r0 = ret.Error(0) } diff --git a/ws/handler.go b/ws/handler.go index 5d54de6029..611b0ce3f7 100644 --- a/ws/handler.go +++ b/ws/handler.go @@ -49,7 +49,7 @@ var ( errFailedPublish = errors.New("failed to publish") errFailedParseSubtopic = errors.New("failed to parse subtopic") errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker") - errFailedPublishEvent = errors.New("failed to publish publish event") + errFailedPublishEvent = errors.New("failed to publish event") errFailedSubscribeEvent = errors.New("failed to publish subscribe event") errFailedUnsubscribeEvent = errors.New("failed to publish unsubscribe event") ) @@ -147,7 +147,7 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { return err } if err := h.es.Subscribe(ctx, clientID, channelID, subTopic); err != nil { - return errors.Wrap(errFailedSubscribeEvent, err) + return errors.Wrap(errFailedUnsubscribeEvent, err) } }