diff --git a/cmd/mqtt/main.go b/cmd/mqtt/main.go index 8f527daf04..76162c959d 100644 --- a/cmd/mqtt/main.go +++ b/cmd/mqtt/main.go @@ -135,13 +135,6 @@ func main() { defer bsub.Close() bsub = brokerstracing.NewPubSub(serverConfig, tracer, bsub) - bsub, err = msgevents.NewPubSubMiddleware(ctx, bsub, cfg.ESURL) - if err != nil { - logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) - exitCode = 1 - return - } - mpub, err := mqttpub.NewPublisher(fmt.Sprintf("mqtt://%s:%s", cfg.MQTTTargetHost, cfg.MQTTTargetPort), cfg.MQTTQoS, cfg.MQTTForwarderTimeout) if err != nil { logger.Error(fmt.Sprintf("failed to create MQTT publisher: %s", err)) diff --git a/coap/adapter.go b/coap/adapter.go index 806f888262..9d6e21f6d1 100644 --- a/coap/adapter.go +++ b/coap/adapter.go @@ -121,9 +121,10 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, chanID, subtopic authzc := newAuthzClient(clientID, chanID, subtopic, svc.channels, c) subCfg := messaging.SubscriberConfig{ - ID: c.Token(), - Topic: subject, - Handler: authzc, + ID: c.Token(), + ClientID: clientID, + Topic: subject, + Handler: authzc, } return svc.pubsub.Subscribe(ctx, subCfg) } diff --git a/journal/journal.go b/journal/journal.go index 8f694243f3..df4e0e0d25 100644 --- a/journal/journal.go +++ b/journal/journal.go @@ -140,13 +140,21 @@ func (page JournalsPage) MarshalJSON() ([]byte, error) { type ClientTelemetry struct { ClientID string `json:"client_id"` DomainID string `json:"domain_id"` - Subscriptions []string `json:"subscriptions"` + Subscriptions uint64 `json:"subscriptions"` InboundMessages uint64 `json:"inbound_messages"` OutboundMessages uint64 `json:"outbound_messages"` FirstSeen time.Time `json:"first_seen"` LastSeen time.Time `json:"last_seen"` } +type ClientSubscription struct { + ID string `json:"id" db:"id"` + SubscriberID string `json:"subscriber_id" db:"subscriber_id"` + ChannelID string `json:"channel_id" db:"channel_id"` + Subtopic string `json:"subtopic" db:"subtopic"` + ClientID string `json:"client_id" db:"client_id"` +} + // Service provides access to the journal log service. // //go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" @@ -181,13 +189,13 @@ type Repository interface { DeleteClientTelemetry(ctx context.Context, clientID, domainID string) error // AddSubscription adds a subscription to the client telemetry. - AddSubscription(ctx context.Context, clientID, sub string) error + AddSubscription(ctx context.Context, sub ClientSubscription) error - // RemoveSubscription removes a subscription from the client telemetry. - RemoveSubscription(ctx context.Context, clientID, sub string) error + // CountSubscriptions returns the number of subscriptions for a client. + CountSubscriptions(ctx context.Context, clientID string) (uint64, error) - // RemoveSubscriptionWithConnID removes a subscription from the client telemetry using the connection ID. - RemoveSubscriptionWithConnID(ctx context.Context, connID, clientID string) error + // RemoveSubscription removes a subscription from the client telemetry. + RemoveSubscription(ctx context.Context, subscriberID string) error // IncrementInboundMessages increments the inbound messages count for a client. IncrementInboundMessages(ctx context.Context, clientID string) error diff --git a/journal/mocks/repository.go b/journal/mocks/repository.go index 4e58d26a44..10170dc2d7 100644 --- a/journal/mocks/repository.go +++ b/journal/mocks/repository.go @@ -16,17 +16,17 @@ type Repository struct { mock.Mock } -// AddSubscription provides a mock function with given fields: ctx, clientID, sub -func (_m *Repository) AddSubscription(ctx context.Context, clientID string, sub string) error { - ret := _m.Called(ctx, clientID, sub) +// AddSubscription provides a mock function with given fields: ctx, sub +func (_m *Repository) AddSubscription(ctx context.Context, sub journal.ClientSubscription) error { + ret := _m.Called(ctx, sub) if len(ret) == 0 { panic("no return value specified for AddSubscription") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, clientID, sub) + if rf, ok := ret.Get(0).(func(context.Context, journal.ClientSubscription) error); ok { + r0 = rf(ctx, sub) } else { r0 = ret.Error(0) } @@ -34,6 +34,34 @@ func (_m *Repository) AddSubscription(ctx context.Context, clientID string, sub return r0 } +// CountSubscriptions provides a mock function with given fields: ctx, clientID +func (_m *Repository) CountSubscriptions(ctx context.Context, clientID string) (uint64, error) { + ret := _m.Called(ctx, clientID) + + if len(ret) == 0 { + panic("no return value specified for CountSubscriptions") + } + + var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (uint64, error)); ok { + return rf(ctx, clientID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) uint64); ok { + r0 = rf(ctx, clientID) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, clientID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // DeleteClientTelemetry provides a mock function with given fields: ctx, clientID, domainID func (_m *Repository) DeleteClientTelemetry(ctx context.Context, clientID string, domainID string) error { ret := _m.Called(ctx, clientID, domainID) @@ -88,35 +116,17 @@ func (_m *Repository) IncrementOutboundMessages(ctx context.Context, channelID s return r0 } -// RemoveSubscription provides a mock function with given fields: ctx, clientID, sub -func (_m *Repository) RemoveSubscription(ctx context.Context, clientID string, sub string) error { - ret := _m.Called(ctx, clientID, sub) +// RemoveSubscription provides a mock function with given fields: ctx, subscriberID +func (_m *Repository) RemoveSubscription(ctx context.Context, subscriberID string) error { + ret := _m.Called(ctx, subscriberID) if len(ret) == 0 { panic("no return value specified for RemoveSubscription") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, clientID, sub) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// RemoveSubscriptionWithConnID provides a mock function with given fields: ctx, connID, clientID -func (_m *Repository) RemoveSubscriptionWithConnID(ctx context.Context, connID string, clientID string) error { - ret := _m.Called(ctx, connID, clientID) - - if len(ret) == 0 { - panic("no return value specified for RemoveSubscriptionWithConnID") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, connID, clientID) + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, subscriberID) } else { r0 = ret.Error(0) } diff --git a/journal/postgres/init.go b/journal/postgres/init.go index 69d63b2125..00d22cf5e9 100644 --- a/journal/postgres/init.go +++ b/journal/postgres/init.go @@ -28,19 +28,25 @@ func Migration() *migrate.MemoryMigrationSource { `CREATE INDEX idx_journal_default_client_filter ON journal(operation, (attributes->>'id'), (attributes->>'client_id'), occurred_at DESC);`, `CREATE INDEX idx_journal_default_channel_filter ON journal(operation, (attributes->>'id'), (attributes->>'channel_id'), occurred_at DESC);`, `CREATE TABLE IF NOT EXISTS clients_telemetry ( - client_id VARCHAR(36) NOT NULL, + client_id VARCHAR(36) PRIMARY KEY, domain_id VARCHAR(36) NOT NULL, - subscriptions TEXT[] DEFAULT '{}', inbound_messages BIGINT DEFAULT 0, outbound_messages BIGINT DEFAULT 0, first_seen TIMESTAMP, - last_seen TIMESTAMP, - PRIMARY KEY (client_id, domain_id) + last_seen TIMESTAMP + )`, + `CREATE TABLE IF NOT EXISTS subscriptions ( + id VARCHAR(36) PRIMARY KEY, + subscriber_id VARCHAR(1024) NOT NULL, + channel_id VARCHAR(36) NOT NULL, + subtopic VARCHAR(1024), + client_id VARCHAR(36), + FOREIGN KEY (client_id) REFERENCES clients_telemetry(client_id) ON DELETE CASCADE ON UPDATE CASCADE )`, - `CREATE INDEX idx_subscriptions_gin ON clients_telemetry USING GIN (subscriptions);`, }, Down: []string{ `DROP TABLE IF EXISTS clients_telemetry`, + `DROP TABLE IF EXISTS subscriptions`, `DROP TABLE IF EXISTS journal`, }, }, diff --git a/journal/postgres/telemetry.go b/journal/postgres/telemetry.go index 8ed765d564..a820397897 100644 --- a/journal/postgres/telemetry.go +++ b/journal/postgres/telemetry.go @@ -16,8 +16,8 @@ import ( ) func (repo *repository) SaveClientTelemetry(ctx context.Context, ct journal.ClientTelemetry) error { - q := `INSERT INTO clients_telemetry (client_id, domain_id, inbound_messages, outbound_messages, subscriptions, first_seen, last_seen) - VALUES (:client_id, :domain_id, :inbound_messages, :outbound_messages, :subscriptions, :first_seen, :last_seen);` + q := `INSERT INTO clients_telemetry (client_id, domain_id, inbound_messages, outbound_messages, first_seen, last_seen) + VALUES (:client_id, :domain_id, :inbound_messages, :outbound_messages, :first_seen, :last_seen);` dbct, err := toDBClientsTelemetry(ct) if err != nil { @@ -32,7 +32,7 @@ func (repo *repository) SaveClientTelemetry(ctx context.Context, ct journal.Clie } func (repo *repository) DeleteClientTelemetry(ctx context.Context, clientID, domainID string) error { - q := "DELETE FROM clients_telemetry AS ct WHERE ct.client_id = :client_id AND ct.domain_id = :domain_id;" + q := `DELETE FROM clients_telemetry AS ct WHERE ct.client_id = :client_id AND ct.domain_id = :domain_id;` dbct := dbClientTelemetry{ ClientID: clientID, @@ -50,7 +50,7 @@ func (repo *repository) DeleteClientTelemetry(ctx context.Context, clientID, dom } func (repo *repository) RetrieveClientTelemetry(ctx context.Context, clientID, domainID string) (journal.ClientTelemetry, error) { - q := "SELECT * FROM clients_telemetry WHERE client_id = :client_id AND domain_id = :domain_id;" + q := `SELECT * FROM clients_telemetry WHERE client_id = :client_id AND domain_id = :domain_id;` dbct := dbClientTelemetry{ ClientID: clientID, @@ -80,24 +80,12 @@ func (repo *repository) RetrieveClientTelemetry(ctx context.Context, clientID, d return journal.ClientTelemetry{}, repoerr.ErrNotFound } -func (repo *repository) AddSubscription(ctx context.Context, clientID, sub string) error { - q := ` - UPDATE clients_telemetry - SET subscriptions = ARRAY_APPEND(subscriptions, :subscriptions), - last_seen = :last_seen - WHERE client_id = :client_id; +func (repo *repository) AddSubscription(ctx context.Context, sub journal.ClientSubscription) error { + q := `INSERT INTO subscriptions (id, subscriber_id, channel_id, subtopic, client_id) + VALUES (:id, :subscriber_id, :channel_id, :subtopic, :client_id); ` - ct := journal.ClientTelemetry{ - ClientID: clientID, - Subscriptions: []string{sub}, - LastSeen: time.Now(), - } - dbct, err := toDBClientsTelemetry(ct) - if err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } - result, err := repo.db.NamedExecContext(ctx, q, dbct) + result, err := repo.db.NamedExecContext(ctx, q, sub) if err != nil { return postgres.HandleError(repoerr.ErrUpdateEntity, err) } @@ -109,45 +97,29 @@ func (repo *repository) AddSubscription(ctx context.Context, clientID, sub strin return nil } -func (repo *repository) RemoveSubscription(ctx context.Context, clientID, sub string) error { - q := ` - UPDATE clients_telemetry - SET subscriptions = ARRAY_REMOVE(subscriptions, :subscriptions) - WHERE client_id = :client_id - AND (:subscriptions = ANY(subscriptions)) - ` - ct := journal.ClientTelemetry{ - ClientID: clientID, - Subscriptions: []string{sub}, - } - dbct, err := toDBClientsTelemetry(ct) - if err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) - } +func (repo *repository) CountSubscriptions(ctx context.Context, clientID string) (uint64, error) { + q := `SELECT COUNT(*) FROM subscriptions WHERE client_id = :client_id;` - result, err := repo.db.NamedExecContext(ctx, q, dbct) - if err != nil { - return postgres.HandleError(repoerr.ErrUpdateEntity, err) + sb := journal.ClientSubscription{ + ClientID: clientID, } - if rows, _ := result.RowsAffected(); rows == 0 { - return repoerr.ErrNotFound + total, err := postgres.Total(ctx, repo.db, q, sb) + if err != nil { + return 0, postgres.HandleError(repoerr.ErrViewEntity, err) } - return nil + return total, nil } -func (repo *repository) RemoveSubscriptionWithConnID(ctx context.Context, connID, clientID string) error { - q := ` - UPDATE clients_telemetry - SET subscriptions = ARRAY( - SELECT sub - FROM unnest(subscriptions) AS sub - WHERE sub NOT LIKE '%' || $1 || '%' - ) - WHERE client_id = $2; - ` - _, err := repo.db.ExecContext(ctx, q, connID, clientID) +func (repo *repository) RemoveSubscription(ctx context.Context, subscriberID string) error { + q := `DELETE FROM subscriptions WHERE subscriber_id = :subscriber_id;` + + sb := journal.ClientSubscription{ + SubscriberID: subscriberID, + } + + _, err := repo.db.NamedExecContext(ctx, q, sb) if err != nil { return postgres.HandleError(repoerr.ErrUpdateEntity, err) } @@ -185,43 +157,67 @@ func (repo *repository) IncrementInboundMessages(ctx context.Context, clientID s } func (repo *repository) IncrementOutboundMessages(ctx context.Context, channelID, subtopic string) error { - q := ` - WITH matched_clients AS ( - SELECT - client_id, - domain_id, - COUNT(*) AS match_count - FROM - clients_telemetry, - unnest(subscriptions) AS sub - WHERE - sub LIKE '%' || $1 || ':' || $2 || '%' - GROUP BY - client_id, domain_id - ) - UPDATE clients_telemetry - SET outbound_messages = outbound_messages + matched_clients.match_count - FROM matched_clients - WHERE clients_telemetry.client_id = matched_clients.client_id - AND clients_telemetry.domain_id = matched_clients.domain_id; + query := ` + SELECT client_id, COUNT(*) AS match_count + FROM subscriptions + WHERE channel_id = :channel_id AND subtopic = :subtopic + GROUP BY client_id ` + sb := journal.ClientSubscription{ + ChannelID: channelID, + Subtopic: subtopic, + } - _, err := repo.db.ExecContext(ctx, q, channelID, subtopic) + rows, err := repo.db.NamedQueryContext(ctx, query, sb) if err != nil { return postgres.HandleError(repoerr.ErrUpdateEntity, err) } + defer rows.Close() + + tx, err := repo.db.BeginTxx(ctx, nil) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + q := `UPDATE clients_telemetry + SET outbound_messages = outbound_messages + $1 + WHERE client_id = $2; + ` + + for rows.Next() { + var clientID string + var count uint64 + if err = rows.Scan(&clientID, &count); err != nil { + err := tx.Rollback() + if err == nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + return errors.Wrap(errors.ErrRollbackTx, err) + } + + if _, err = repo.db.ExecContext(ctx, q, count, clientID); err != nil { + err := tx.Rollback() + if err == nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + return errors.Wrap(errors.ErrRollbackTx, err) + } + } + + if err = tx.Commit(); err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } return nil } type dbClientTelemetry struct { - ClientID string `db:"client_id"` - DomainID string `db:"domain_id"` - Subscriptions pgtype.TextArray `db:"subscriptions"` - InboundMessages uint64 `db:"inbound_messages"` - OutboundMessages uint64 `db:"outbound_messages"` - FirstSeen time.Time `db:"first_seen"` - LastSeen sql.NullTime `db:"last_seen"` + ClientID string `db:"client_id"` + DomainID string `db:"domain_id"` + InboundMessages uint64 `db:"inbound_messages"` + OutboundMessages uint64 `db:"outbound_messages"` + FirstSeen time.Time `db:"first_seen"` + LastSeen sql.NullTime `db:"last_seen"` } func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error) { @@ -238,7 +234,6 @@ func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error) return dbClientTelemetry{ ClientID: ct.ClientID, DomainID: ct.DomainID, - Subscriptions: subs, InboundMessages: ct.InboundMessages, OutboundMessages: ct.OutboundMessages, FirstSeen: ct.FirstSeen, @@ -247,11 +242,6 @@ func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error) } func toClientsTelemetry(dbct dbClientTelemetry) (journal.ClientTelemetry, error) { - var subs []string - for _, e := range dbct.Subscriptions.Elements { - subs = append(subs, e.String) - } - var lastSeen time.Time if dbct.LastSeen.Valid { lastSeen = dbct.LastSeen.Time @@ -260,7 +250,6 @@ func toClientsTelemetry(dbct dbClientTelemetry) (journal.ClientTelemetry, error) return journal.ClientTelemetry{ ClientID: dbct.ClientID, DomainID: dbct.DomainID, - Subscriptions: subs, InboundMessages: dbct.InboundMessages, OutboundMessages: dbct.OutboundMessages, FirstSeen: dbct.FirstSeen, diff --git a/journal/service.go b/journal/service.go index 066093e3aa..9367392b7c 100644 --- a/journal/service.go +++ b/journal/service.go @@ -6,6 +6,7 @@ package journal import ( "context" "fmt" + "strings" "time" "github.com/absmach/supermq" @@ -15,19 +16,13 @@ import ( ) const ( - clientCreate = "client.create" - clientRemove = "client.remove" - coapSubscribe = "coap.client_subscribe" - coapUnsubscribe = "coap.client_unsubscribe" - coapPublish = "coap.client_publish" - httpPublish = "http.client_publish" - mqttSubscribe = "mqtt.client_subscribe" - mqttUnsubscribe = "mqtt.client_unsubscribe" - mqttPublish = "mqtt.client_publish" - mqttDisconnect = "mqtt.client_disconnect" - wsSubscribe = "ws.client_subscribe" - wsUnsubscribe = "ws.client_unsubscribe" - wsPublish = "ws.client_publish" + clientCreate = "client.create" + clientRemove = "client.remove" + mqttSubscribe = "mqtt.client_subscribe" + mqttDisconnect = "mqtt.client_disconnect" + messagingPublish = "messaging.client_publish" + messagingSubscribe = "messaging.client_subscribe" + messagingUnsubscribe = "messaging.client_unsubscribe" ) type service struct { @@ -74,6 +69,13 @@ func (svc *service) RetrieveClientTelemetry(ctx context.Context, session smqauth return ClientTelemetry{}, errors.Wrap(svcerr.ErrViewEntity, err) } + subs, err := svc.repository.CountSubscriptions(ctx, clientID) + if err != nil { + return ClientTelemetry{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + + ct.Subscriptions = subs + return ct, nil } @@ -85,17 +87,20 @@ func (svc *service) handleTelemetry(ctx context.Context, journal Journal) error case clientRemove: return svc.removeClientTelemetry(ctx, journal) - case coapSubscribe, mqttSubscribe, wsSubscribe: + case mqttSubscribe: + return svc.addMqttSubscription(ctx, journal) + + case messagingSubscribe: return svc.addSubscription(ctx, journal) - case coapUnsubscribe, mqttUnsubscribe, wsUnsubscribe: + case messagingUnsubscribe: return svc.removeSubscription(ctx, journal) - case coapPublish, httpPublish, mqttPublish, wsPublish: + case messagingPublish: return svc.updateMessageCount(ctx, journal) case mqttDisconnect: - return svc.removeSubscriptionWithConnID(ctx, journal) + return svc.removeMqttSubscription(ctx, journal) default: return nil @@ -125,44 +130,84 @@ func (svc *service) removeClientTelemetry(ctx context.Context, journal Journal) } func (svc *service) addSubscription(ctx context.Context, journal Journal) error { - ae, err := toAdapterEvent(journal) + ae, err := toSubscribeEvent(journal) if err != nil { return err } - sub := fmt.Sprintf("%s:%s:%s", ae.connID, ae.channelID, ae.topic) - return svc.repository.AddSubscription(ctx, ae.clientID, sub) -} + var subtopic string + topics := strings.Split(ae.topic, ".") + if len(topics) > 2 { + subtopic = topics[2] + } -func (svc *service) removeSubscription(ctx context.Context, journal Journal) error { - ae, err := toAdapterEvent(journal) + id, err := svc.idProvider.ID() if err != nil { return err } - sub := fmt.Sprintf("%s:%s:%s", ae.connID, ae.channelID, ae.topic) - return svc.repository.RemoveSubscription(ctx, ae.clientID, sub) + + sub := ClientSubscription{ + ID: id, + SubscriberID: ae.subscriberID, + ChannelID: topics[1], + Subtopic: subtopic, + ClientID: ae.clientID, + } + + return svc.repository.AddSubscription(ctx, sub) } -func (svc *service) updateMessageCount(ctx context.Context, journal Journal) error { - ae, err := toAdapterEvent(journal) +func (svc *service) addMqttSubscription(ctx context.Context, journal Journal) error { + ae, err := toMqttSubscribeEvent(journal) if err != nil { return err } - if err := svc.repository.IncrementInboundMessages(ctx, ae.clientID); err != nil { + + id, err := svc.idProvider.ID() + if err != nil { return err } - if err := svc.repository.IncrementOutboundMessages(ctx, ae.channelID, ae.topic); err != nil { + + sub := ClientSubscription{ + ID: id, + SubscriberID: ae.subscriberID, + ChannelID: ae.channelID, + Subtopic: ae.subtopic, + ClientID: ae.clientID, + } + + return svc.repository.AddSubscription(ctx, sub) +} + +func (svc *service) removeSubscription(ctx context.Context, journal Journal) error { + ae, err := toUnsubscribeEvent(journal) + if err != nil { return err } - return nil + + return svc.repository.RemoveSubscription(ctx, ae.subscriberID) } -func (svc *service) removeSubscriptionWithConnID(ctx context.Context, journal Journal) error { - ae, err := toAdapterEvent(journal) +func (svc *service) removeMqttSubscription(ctx context.Context, journal Journal) error { + ae, err := toMqttDisconnectEvent(journal) if err != nil { return err } - return svc.repository.RemoveSubscriptionWithConnID(ctx, ae.connID, ae.clientID) + return svc.repository.RemoveSubscription(ctx, ae.subscriberID) +} + +func (svc *service) updateMessageCount(ctx context.Context, journal Journal) error { + ae, err := toPublishEvent(journal) + if err != nil { + return err + } + if err := svc.repository.IncrementInboundMessages(ctx, ae.clientID); err != nil { + return err + } + if err := svc.repository.IncrementOutboundMessages(ctx, ae.channelID, ae.subtopic); err != nil { + return err + } + return nil } type clientEvent struct { @@ -174,14 +219,15 @@ type clientEvent struct { func toClientEvent(journal Journal) (clientEvent, error) { var createdAt time.Time var err error - id, ok := journal.Attributes["id"].(string) - if !ok { - return clientEvent{}, fmt.Errorf("invalid id attribute") + id, err := getStringAttribute(journal, "id") + if err != nil { + return clientEvent{}, err } - domain, ok := journal.Attributes["domain"].(string) - if !ok { - return clientEvent{}, fmt.Errorf("invalid domain attribute") + domain, err := getStringAttribute(journal, "domain") + if err != nil { + return clientEvent{}, err } + createdAtStr := journal.Attributes["created_at"].(string) if createdAtStr != "" { createdAt, err = time.Parse(time.RFC3339, createdAtStr) @@ -197,21 +243,118 @@ func toClientEvent(journal Journal) (clientEvent, error) { } type adapterEvent struct { - clientID string - connID string - channelID string - topic string + clientID string + channelID string + subscriberID string + topic string + subtopic string } -func toAdapterEvent(journal Journal) (adapterEvent, error) { - clientID := journal.Attributes["client_id"].(string) - connID := journal.Attributes["conn_id"].(string) - channelID := journal.Attributes["channel_id"].(string) - topic := journal.Attributes["topic"].(string) +func toPublishEvent(journal Journal) (adapterEvent, error) { + clientID, err := getStringAttribute(journal, "client_id") + if err != nil { + return adapterEvent{}, err + } + channelID, err := getStringAttribute(journal, "channel_id") + if err != nil { + return adapterEvent{}, err + } + subtopic, err := getStringAttribute(journal, "subtopic") + if err != nil { + return adapterEvent{}, err + } + return adapterEvent{ clientID: clientID, - connID: connID, channelID: channelID, - topic: topic, + subtopic: subtopic, + }, nil +} + +func toSubscribeEvent(journal Journal) (adapterEvent, error) { + subscriberID, err := getStringAttribute(journal, "subscriber_id") + if err != nil { + return adapterEvent{}, err + } + topic, err := getStringAttribute(journal, "topic") + if err != nil { + return adapterEvent{}, err + } + var clientID string + clientID, err = getStringAttribute(journal, "client_id") + if err != nil { + clientID = "" + } + + return adapterEvent{ + clientID: clientID, + subscriberID: subscriberID, + topic: topic, + }, nil +} + +func toUnsubscribeEvent(journal Journal) (adapterEvent, error) { + subscriberID, err := getStringAttribute(journal, "subscriber_id") + if err != nil { + return adapterEvent{}, err + } + topic, err := getStringAttribute(journal, "topic") + if err != nil { + return adapterEvent{}, err + } + + return adapterEvent{ + subscriberID: subscriberID, + topic: topic, + }, nil +} + +func toMqttSubscribeEvent(journal Journal) (adapterEvent, error) { + clientID, err := getStringAttribute(journal, "client_id") + if err != nil { + return adapterEvent{}, err + } + subscriberID, err := getStringAttribute(journal, "subscriber_id") + if err != nil { + return adapterEvent{}, err + } + channelID, err := getStringAttribute(journal, "channel_id") + if err != nil { + return adapterEvent{}, err + } + subtopic, err := getStringAttribute(journal, "subtopic") + if err != nil { + return adapterEvent{}, err + } + + return adapterEvent{ + clientID: clientID, + subscriberID: subscriberID, + channelID: channelID, + subtopic: subtopic, + }, nil +} + +func toMqttDisconnectEvent(journal Journal) (adapterEvent, error) { + subscriberID, err := getStringAttribute(journal, "subscriber_id") + if err != nil { + return adapterEvent{}, err + } + clientID, err := getStringAttribute(journal, "client_id") + if err != nil { + return adapterEvent{}, err + } + + return adapterEvent{ + subscriberID: subscriberID, + channelID: clientID, }, nil } + +func getStringAttribute(journal Journal, key string) (string, error) { + value, ok := journal.Attributes[key].(string) + if !ok { + return "", fmt.Errorf("missing or invalid %s attribute", key) + } + return value, nil +} diff --git a/pkg/messaging/events/events.go b/pkg/messaging/events/events.go index 12f6ce3df4..ceead22212 100644 --- a/pkg/messaging/events/events.go +++ b/pkg/messaging/events/events.go @@ -35,13 +35,15 @@ func (pe publishEvent) Encode() (map[string]interface{}, error) { type subscribeEvent struct { operation string subscriberID string - subtopic string + clientID string + topic string } func (se subscribeEvent) Encode() (map[string]interface{}, error) { return map[string]interface{}{ "operation": se.operation, "subscriber_id": se.subscriberID, - "subtopic": se.subtopic, + "client_id": se.clientID, + "topic": se.topic, }, nil } diff --git a/pkg/messaging/events/pubsub.go b/pkg/messaging/events/pubsub.go index 8e792ae891..657f24908b 100644 --- a/pkg/messaging/events/pubsub.go +++ b/pkg/messaging/events/pubsub.go @@ -54,7 +54,8 @@ func (es *pubsubES) Subscribe(ctx context.Context, cfg messaging.SubscriberConfi se := subscribeEvent{ operation: clientSubscribe, subscriberID: cfg.ID, - subtopic: cfg.Topic, + clientID: cfg.ClientID, + topic: cfg.Topic, } return es.ep.Publish(ctx, se) @@ -68,7 +69,7 @@ func (es *pubsubES) Unsubscribe(ctx context.Context, id string, topic string) er se := subscribeEvent{ operation: clientUnsubscribe, subscriberID: id, - subtopic: topic, + topic: topic, } return es.ep.Publish(ctx, se) diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 393de64fef..acdc0e146e 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -36,6 +36,7 @@ type MessageHandler interface { type SubscriberConfig struct { ID string + ClientID string Topic string Handler MessageHandler DeliveryPolicy DeliveryPolicy diff --git a/ws/adapter.go b/ws/adapter.go index f92fe15074..02c4cfe39e 100644 --- a/ws/adapter.go +++ b/ws/adapter.go @@ -75,9 +75,10 @@ func (svc *adapterService) Subscribe(ctx context.Context, clientKey, chanID, sub } subCfg := messaging.SubscriberConfig{ - ID: clientID, - Topic: subject, - Handler: c, + ID: clientID, + ClientID: clientID, + Topic: subject, + Handler: c, } if err := svc.pubsub.Subscribe(ctx, subCfg); err != nil { return ErrFailedSubscription