From 07dbb86203c75da783612662590c1c4539dc6268 Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Thu, 30 Jan 2025 14:45:40 +0300 Subject: [PATCH] SMQ-2546 - Add events to adapters (#2659) Signed-off-by: Felix Gateru --- cmd/coap/main.go | 9 ++++ cmd/http/main.go | 9 ++++ cmd/mqtt/main.go | 22 +++++++++ cmd/ws/main.go | 9 ++++ docker/docker-compose.yml | 5 +- mqtt/events/events.go | 48 +++++++++++++++---- mqtt/events/streams.go | 50 ++++++++++++------- mqtt/handler.go | 33 ++++++++++++- mqtt/handler_test.go | 8 ++-- mqtt/mocks/events.go | 38 +++++++++++---- pkg/messaging/events/events.go | 47 ++++++++++++++++++ pkg/messaging/events/publisher.go | 49 +++++++++++++++++++ pkg/messaging/events/pubsub.go | 79 +++++++++++++++++++++++++++++++ 13 files changed, 362 insertions(+), 44 deletions(-) create mode 100644 pkg/messaging/events/events.go create mode 100644 pkg/messaging/events/publisher.go create mode 100644 pkg/messaging/events/pubsub.go diff --git a/cmd/coap/main.go b/cmd/coap/main.go index 0ad2b73727..8dae202479 100644 --- a/cmd/coap/main.go +++ b/cmd/coap/main.go @@ -21,6 +21,7 @@ import ( jaegerclient "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/messaging/brokers" brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + msgevents "github.com/absmach/supermq/pkg/messaging/events" "github.com/absmach/supermq/pkg/prometheus" "github.com/absmach/supermq/pkg/server" coapserver "github.com/absmach/supermq/pkg/server/coap" @@ -47,6 +48,7 @@ type config struct { SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` InstanceID string `env:"SMQ_COAP_ADAPTER_INSTANCE_ID" envDefault:""` TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` + ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` } func main() { @@ -143,6 +145,13 @@ func main() { defer nps.Close() nps = brokerstracing.NewPubSub(coapServerConfig, tracer, nps) + nps, err = msgevents.NewPubSubMiddleware(ctx, nps, cfg.ESURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) + exitCode = 1 + return + } + svc := coap.New(clientsClient, channelsClient, nps) svc = tracing.New(tracer, svc) diff --git a/cmd/http/main.go b/cmd/http/main.go index 992024dd40..4b03eb2426 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -31,6 +31,7 @@ import ( "github.com/absmach/supermq/pkg/messaging" "github.com/absmach/supermq/pkg/messaging/brokers" brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + msgevents "github.com/absmach/supermq/pkg/messaging/events" "github.com/absmach/supermq/pkg/messaging/handler" "github.com/absmach/supermq/pkg/prometheus" "github.com/absmach/supermq/pkg/server" @@ -59,6 +60,7 @@ type config struct { SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` InstanceID string `env:"SMQ_HTTP_ADAPTER_INSTANCE_ID" envDefault:""` TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` + ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` } func main() { @@ -163,6 +165,13 @@ func main() { defer pub.Close() pub = brokerstracing.NewPublisher(httpServerConfig, tracer, pub) + pub, err = msgevents.NewPublisherMiddleware(ctx, pub, cfg.ESURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) + exitCode = 1 + return + } + svc := newService(pub, authn, clientsClient, channelsClient, logger, tracer) targetServerCfg := server.Config{Port: targetHTTPPort} diff --git a/cmd/mqtt/main.go b/cmd/mqtt/main.go index 496393d28e..8f527daf04 100644 --- a/cmd/mqtt/main.go +++ b/cmd/mqtt/main.go @@ -32,6 +32,7 @@ import ( jaegerclient "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/messaging/brokers" brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + msgevents "github.com/absmach/supermq/pkg/messaging/events" "github.com/absmach/supermq/pkg/messaging/handler" mqttpub "github.com/absmach/supermq/pkg/messaging/mqtt" "github.com/absmach/supermq/pkg/server" @@ -134,6 +135,13 @@ func main() { defer bsub.Close() bsub = brokerstracing.NewPubSub(serverConfig, tracer, bsub) + bsub, err = msgevents.NewPubSubMiddleware(ctx, bsub, cfg.ESURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) + exitCode = 1 + return + } + mpub, err := mqttpub.NewPublisher(fmt.Sprintf("mqtt://%s:%s", cfg.MQTTTargetHost, cfg.MQTTTargetPort), cfg.MQTTQoS, cfg.MQTTForwarderTimeout) if err != nil { logger.Error(fmt.Sprintf("failed to create MQTT publisher: %s", err)) @@ -142,6 +150,13 @@ func main() { } defer mpub.Close() + mpub, err = msgevents.NewPublisherMiddleware(ctx, mpub, cfg.ESURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) + exitCode = 1 + return + } + fwd := mqtt.NewForwarder(brokers.SubjectAllChannels, logger) fwd = mqtttracing.New(serverConfig, tracer, fwd, brokers.SubjectAllChannels) if err := fwd.Forward(ctx, svcName, bsub, mpub); err != nil { @@ -159,6 +174,13 @@ func main() { defer np.Close() np = brokerstracing.NewPublisher(serverConfig, tracer, np) + np, err = msgevents.NewPublisherMiddleware(ctx, np, cfg.ESURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) + exitCode = 1 + return + } + es, err := events.NewEventStore(ctx, cfg.ESURL, cfg.Instance) if err != nil { logger.Error(fmt.Sprintf("failed to create %s event store : %s", svcName, err)) diff --git a/cmd/ws/main.go b/cmd/ws/main.go index 9f54fd1eb9..8dc5223f92 100644 --- a/cmd/ws/main.go +++ b/cmd/ws/main.go @@ -25,6 +25,7 @@ import ( "github.com/absmach/supermq/pkg/messaging" "github.com/absmach/supermq/pkg/messaging/brokers" brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + msgevents "github.com/absmach/supermq/pkg/messaging/events" "github.com/absmach/supermq/pkg/prometheus" "github.com/absmach/supermq/pkg/server" httpserver "github.com/absmach/supermq/pkg/server/http" @@ -55,6 +56,7 @@ type config struct { SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` InstanceID string `env:"SMQ_WS_ADAPTER_INSTANCE_ID" envDefault:""` TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` + ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` } func main() { @@ -165,6 +167,13 @@ func main() { defer nps.Close() nps = brokerstracing.NewPubSub(targetServerConfig, tracer, nps) + nps, err = msgevents.NewPubSubMiddleware(ctx, nps, cfg.ESURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) + exitCode = 1 + return + } + svc := newService(clientsClient, channelsClient, nps, logger, tracer) hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, logger, cfg.InstanceID), logger) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 37a5dce0df..ae702bebea 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -845,7 +845,6 @@ services: bind: create_host_path: true - groups-db: image: postgres:16.2-alpine container_name: supermq-groups-db @@ -948,7 +947,6 @@ services: bind: create_host_path: true - jaeger: image: jaegertracing/all-in-one:1.60 container_name: supermq-jaeger @@ -1067,6 +1065,7 @@ services: SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} SMQ_HTTP_ADAPTER_INSTANCE_ID: ${SMQ_HTTP_ADAPTER_INSTANCE_ID} + SMQ_ES_URL: ${SMQ_ES_URL} ports: - ${SMQ_HTTP_ADAPTER_PORT}:${SMQ_HTTP_ADAPTER_PORT} networks: @@ -1153,6 +1152,7 @@ services: SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} SMQ_COAP_ADAPTER_INSTANCE_ID: ${SMQ_COAP_ADAPTER_INSTANCE_ID} + SMQ_ES_URL: ${SMQ_ES_URL} ports: - ${SMQ_COAP_ADAPTER_PORT}:${SMQ_COAP_ADAPTER_PORT}/udp - ${SMQ_COAP_ADAPTER_HTTP_PORT}:${SMQ_COAP_ADAPTER_HTTP_PORT}/tcp @@ -1230,6 +1230,7 @@ services: SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} SMQ_WS_ADAPTER_INSTANCE_ID: ${SMQ_WS_ADAPTER_INSTANCE_ID} + SMQ_ES_URL: ${SMQ_ES_URL} ports: - ${SMQ_WS_ADAPTER_HTTP_PORT}:${SMQ_WS_ADAPTER_HTTP_PORT} networks: diff --git a/mqtt/events/events.go b/mqtt/events/events.go index b8ea682bf4..f5e6b7e02c 100644 --- a/mqtt/events/events.go +++ b/mqtt/events/events.go @@ -5,18 +5,48 @@ package events import "github.com/absmach/supermq/pkg/events" -var _ events.Event = (*mqttEvent)(nil) +const ( + mqttPrefix = "mqtt" + clientSubscribe = mqttPrefix + ".client_subscribe" + clientConnect = mqttPrefix + ".client_connect" + clientDisconnect = mqttPrefix + ".client_disconnect" +) -type mqttEvent struct { - clientID string - operation string - instance string +var ( + _ events.Event = (*connectEvent)(nil) + _ events.Event = (*subscribeEvent)(nil) +) + +type connectEvent struct { + operation string + clientID string + subscriberID string + instance string +} + +func (ce connectEvent) Encode() (map[string]interface{}, error) { + return map[string]interface{}{ + "operation": ce.operation, + "client_id": ce.clientID, + "subscriber_id": ce.subscriberID, + "instance": ce.instance, + }, nil +} + +type subscribeEvent struct { + operation string + clientID string + subscriberID string + channelID string + subtopic string } -func (me mqttEvent) Encode() (map[string]interface{}, error) { +func (se subscribeEvent) Encode() (map[string]interface{}, error) { return map[string]interface{}{ - "client_id": me.clientID, - "operation": me.operation, - "instance": me.instance, + "operation": se.operation, + "client_id": se.clientID, + "subscriber_id": se.subscriberID, + "channel_id": se.channelID, + "subtopic": se.subtopic, }, nil } diff --git a/mqtt/events/streams.go b/mqtt/events/streams.go index 515ccd6457..b81316e437 100644 --- a/mqtt/events/streams.go +++ b/mqtt/events/streams.go @@ -14,13 +14,14 @@ const streamID = "supermq.mqtt" //go:generate mockery --name EventStore --output=../mocks --filename events.go --quiet --note "Copyright (c) Abstract Machines" type EventStore interface { - Connect(ctx context.Context, clientID string) error - Disconnect(ctx context.Context, clientID string) error + Connect(ctx context.Context, clientID, subscriberID string) error + Disconnect(ctx context.Context, clientID, subscriberID string) error + Subscribe(ctx context.Context, clientID, channelID, subscriberID, subtopic string) error } // EventStore is a struct used to store event streams in Redis. type eventStore struct { - events.Publisher + ep events.Publisher instance string } @@ -33,29 +34,44 @@ func NewEventStore(ctx context.Context, url, instance string) (EventStore, error } return &eventStore{ - instance: instance, - Publisher: publisher, + instance: instance, + ep: publisher, }, nil } // Connect issues event on MQTT CONNECT. -func (es *eventStore) Connect(ctx context.Context, clientID string) error { - ev := mqttEvent{ - clientID: clientID, - operation: "connect", - instance: es.instance, +func (es *eventStore) Connect(ctx context.Context, clientID, subscriberID string) error { + ev := connectEvent{ + clientID: clientID, + operation: clientConnect, + subscriberID: subscriberID, + instance: es.instance, } - return es.Publish(ctx, ev) + return es.ep.Publish(ctx, ev) } // Disconnect issues event on MQTT CONNECT. -func (es *eventStore) Disconnect(ctx context.Context, clientID string) error { - ev := mqttEvent{ - clientID: clientID, - operation: "disconnect", - instance: es.instance, +func (es *eventStore) Disconnect(ctx context.Context, clientID, subscriberID string) error { + ev := connectEvent{ + clientID: clientID, + operation: clientDisconnect, + subscriberID: subscriberID, + instance: es.instance, } - return es.Publish(ctx, ev) + return es.ep.Publish(ctx, ev) +} + +// Subscribe issues event on MQTT SUBSCRIBE. +func (es *eventStore) Subscribe(ctx context.Context, clientID, channelID, subscriberID, subtopic string) error { + ev := subscribeEvent{ + operation: clientSubscribe, + clientID: clientID, + channelID: channelID, + subscriberID: subscriberID, + subtopic: subtopic, + } + + return es.ep.Publish(ctx, ev) } diff --git a/mqtt/handler.go b/mqtt/handler.go index d0ba090519..47436331c3 100644 --- a/mqtt/handler.go +++ b/mqtt/handler.go @@ -52,6 +52,7 @@ var ( ErrFailedPublishDisconnectEvent = errors.New("failed to publish disconnect event") ErrFailedParseSubtopic = errors.New("failed to parse subtopic") ErrFailedPublishConnectEvent = errors.New("failed to publish connect event") + ErrFailedSubscribeEvent = errors.New("failed to publish subscribe event") ErrFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker") ) @@ -106,7 +107,7 @@ func (h *handler) AuthConnect(ctx context.Context) error { return errInvalidUserId } - if err := h.es.Connect(ctx, pwd); err != nil { + if err := h.es.Connect(ctx, s.Username, s.ID); err != nil { h.logger.Error(errors.Wrap(ErrFailedPublishConnectEvent, err).Error()) } @@ -202,6 +203,17 @@ func (h *handler) Subscribe(ctx context.Context, topics *[]string) error { if !ok { return errors.Wrap(ErrFailedSubscribe, ErrClientNotInitialized) } + + for _, topic := range *topics { + channelID, subTopic, err := parseTopic(topic) + if err != nil { + return err + } + if err := h.es.Subscribe(ctx, s.Username, channelID, s.ID, subTopic); err != nil { + return errors.Wrap(ErrFailedSubscribeEvent, err) + } + } + h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ","))) return nil } @@ -223,7 +235,7 @@ func (h *handler) Disconnect(ctx context.Context) error { return errors.Wrap(ErrFailedDisconnect, ErrClientNotInitialized) } h.logger.Error(fmt.Sprintf(LogInfoDisconnected, s.ID, s.Password)) - if err := h.es.Disconnect(ctx, string(s.Password)); err != nil { + if err := h.es.Disconnect(ctx, s.Username, s.ID); err != nil { return errors.Wrap(ErrFailedPublishDisconnectEvent, err) } return nil @@ -260,6 +272,23 @@ func (h *handler) authAccess(ctx context.Context, clientID, topic string, msgTyp return nil } +func parseTopic(topic string) (string, string, error) { + channelParts := channelRegExp.FindStringSubmatch(topic) + if len(channelParts) < 2 { + return "", "", errors.Wrap(ErrFailedPublish, ErrMalformedTopic) + } + + chanID := channelParts[1] + subtopic := channelParts[2] + + subtopic, err := parseSubtopic(subtopic) + if err != nil { + return "", "", errors.Wrap(ErrFailedParseSubtopic, err) + } + + return chanID, subtopic, nil +} + func parseSubtopic(subtopic string) (string, error) { if subtopic == "" { return subtopic, nil diff --git a/mqtt/handler_test.go b/mqtt/handler_test.go index 1b4d8ebe96..c86ba9fb00 100644 --- a/mqtt/handler_test.go +++ b/mqtt/handler_test.go @@ -147,7 +147,7 @@ func TestAuthConnect(t *testing.T) { password = string(tc.session.Password) } clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: password}).Return(tc.authNRes, tc.authNErr) - svcCall := eventStore.On("Connect", mock.Anything, password).Return(tc.err) + svcCall := eventStore.On("Connect", mock.Anything, clientID, mock.Anything).Return(tc.err) err := handler.AuthConnect(ctx) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) svcCall.Unset() @@ -445,9 +445,11 @@ func TestSubscribe(t *testing.T) { if tc.session != nil { ctx = session.NewContext(ctx, tc.session) } + eventsCall := eventStore.On("Subscribe", mock.Anything, clientID, chanID, clientID, mock.Anything).Return(nil) err := handler.Subscribe(ctx, &tc.topic) assert.Contains(t, logBuffer.String(), tc.logMsg) assert.Equal(t, tc.err, err) + eventsCall.Unset() } } @@ -514,12 +516,10 @@ func TestDisconnect(t *testing.T) { for _, tc := range cases { ctx := context.TODO() - password := "" if tc.session != nil { ctx = session.NewContext(ctx, tc.session) - password = string(tc.session.Password) } - svcCall := eventStore.On("Disconnect", mock.Anything, password).Return(tc.err) + svcCall := eventStore.On("Disconnect", mock.Anything, clientID, mock.Anything).Return(tc.err) err := handler.Disconnect(ctx) assert.Contains(t, logBuffer.String(), tc.logMsg) assert.Equal(t, tc.err, err) diff --git a/mqtt/mocks/events.go b/mqtt/mocks/events.go index 7dcebfd763..30c64e42d8 100644 --- a/mqtt/mocks/events.go +++ b/mqtt/mocks/events.go @@ -15,17 +15,17 @@ type EventStore struct { mock.Mock } -// Connect provides a mock function with given fields: ctx, clientID -func (_m *EventStore) Connect(ctx context.Context, clientID string) error { - ret := _m.Called(ctx, clientID) +// Connect provides a mock function with given fields: ctx, clientID, subscriberID +func (_m *EventStore) Connect(ctx context.Context, clientID string, subscriberID string) error { + ret := _m.Called(ctx, clientID, subscriberID) if len(ret) == 0 { panic("no return value specified for Connect") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, clientID) + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, clientID, subscriberID) } else { r0 = ret.Error(0) } @@ -33,17 +33,35 @@ func (_m *EventStore) Connect(ctx context.Context, clientID string) error { return r0 } -// Disconnect provides a mock function with given fields: ctx, clientID -func (_m *EventStore) Disconnect(ctx context.Context, clientID string) error { - ret := _m.Called(ctx, clientID) +// Disconnect provides a mock function with given fields: ctx, clientID, subscriberID +func (_m *EventStore) Disconnect(ctx context.Context, clientID string, subscriberID string) error { + ret := _m.Called(ctx, clientID, subscriberID) if len(ret) == 0 { panic("no return value specified for Disconnect") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, clientID) + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, clientID, subscriberID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Subscribe provides a mock function with given fields: ctx, clientID, channelID, subscriberID, subtopic +func (_m *EventStore) Subscribe(ctx context.Context, clientID string, channelID string, subscriberID string, subtopic string) error { + ret := _m.Called(ctx, clientID, channelID, subscriberID, subtopic) + + if len(ret) == 0 { + panic("no return value specified for Subscribe") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, clientID, channelID, subscriberID, subtopic) } else { r0 = ret.Error(0) } diff --git a/pkg/messaging/events/events.go b/pkg/messaging/events/events.go new file mode 100644 index 0000000000..12f6ce3df4 --- /dev/null +++ b/pkg/messaging/events/events.go @@ -0,0 +1,47 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package events + +import "github.com/absmach/supermq/pkg/events" + +const ( + messagingPrefix = "messaging" + clientPublish = messagingPrefix + ".client_publish" + clientSubscribe = messagingPrefix + ".client_subscribe" + clientUnsubscribe = messagingPrefix + ".client_unsubscribe" +) + +var ( + _ events.Event = (*publishEvent)(nil) + _ events.Event = (*subscribeEvent)(nil) +) + +type publishEvent struct { + channelID string + clientID string + subtopic string +} + +func (pe publishEvent) Encode() (map[string]interface{}, error) { + return map[string]interface{}{ + "operation": clientPublish, + "channel_id": pe.channelID, + "client_id": pe.clientID, + "subtopic": pe.subtopic, + }, nil +} + +type subscribeEvent struct { + operation string + subscriberID string + subtopic string +} + +func (se subscribeEvent) Encode() (map[string]interface{}, error) { + return map[string]interface{}{ + "operation": se.operation, + "subscriber_id": se.subscriberID, + "subtopic": se.subtopic, + }, nil +} diff --git a/pkg/messaging/events/publisher.go b/pkg/messaging/events/publisher.go new file mode 100644 index 0000000000..cb59398e5d --- /dev/null +++ b/pkg/messaging/events/publisher.go @@ -0,0 +1,49 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package events + +import ( + "context" + + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/events/store" + "github.com/absmach/supermq/pkg/messaging" +) + +var _ messaging.Publisher = (*publisherES)(nil) + +type publisherES struct { + ep events.Publisher + pub messaging.Publisher +} + +func NewPublisherMiddleware(ctx context.Context, pub messaging.Publisher, url string) (messaging.Publisher, error) { + publisher, err := store.NewPublisher(ctx, url, streamID) + if err != nil { + return nil, err + } + + return &publisherES{ + ep: publisher, + pub: pub, + }, nil +} + +func (es *publisherES) Publish(ctx context.Context, topic string, msg *messaging.Message) error { + if err := es.pub.Publish(ctx, topic, msg); err != nil { + return err + } + + me := publishEvent{ + channelID: msg.Channel, + clientID: msg.Publisher, + subtopic: msg.Subtopic, + } + + return es.ep.Publish(ctx, me) +} + +func (es *publisherES) Close() error { + return es.pub.Close() +} diff --git a/pkg/messaging/events/pubsub.go b/pkg/messaging/events/pubsub.go new file mode 100644 index 0000000000..8e792ae891 --- /dev/null +++ b/pkg/messaging/events/pubsub.go @@ -0,0 +1,79 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package events + +import ( + "context" + + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/events/store" + "github.com/absmach/supermq/pkg/messaging" +) + +const streamID = "supermq.messaging" + +var _ messaging.PubSub = (*pubsubES)(nil) + +type pubsubES struct { + ep events.Publisher + pubsub messaging.PubSub +} + +func NewPubSubMiddleware(ctx context.Context, pubsub messaging.PubSub, url string) (messaging.PubSub, error) { + publisher, err := store.NewPublisher(ctx, url, streamID) + if err != nil { + return nil, err + } + + return &pubsubES{ + ep: publisher, + pubsub: pubsub, + }, nil +} + +func (es *pubsubES) Publish(ctx context.Context, topic string, msg *messaging.Message) error { + if err := es.pubsub.Publish(ctx, topic, msg); err != nil { + return err + } + + me := publishEvent{ + channelID: msg.Channel, + clientID: msg.Publisher, + subtopic: msg.Subtopic, + } + + return es.ep.Publish(ctx, me) +} + +func (es *pubsubES) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error { + if err := es.pubsub.Subscribe(ctx, cfg); err != nil { + return err + } + + se := subscribeEvent{ + operation: clientSubscribe, + subscriberID: cfg.ID, + subtopic: cfg.Topic, + } + + return es.ep.Publish(ctx, se) +} + +func (es *pubsubES) Unsubscribe(ctx context.Context, id string, topic string) error { + if err := es.pubsub.Unsubscribe(ctx, id, topic); err != nil { + return err + } + + se := subscribeEvent{ + operation: clientUnsubscribe, + subscriberID: id, + subtopic: topic, + } + + return es.ep.Publish(ctx, se) +} + +func (es *pubsubES) Close() error { + return es.pubsub.Close() +}