Skip to content

Commit

Permalink
Merge pull request #1420 from openmeterio/fix-notification-api
Browse files Browse the repository at this point in the history
fix: update of channel assignment for rules
  • Loading branch information
chrisgacsal authored Aug 23, 2024
2 parents f02274e + f15561a commit 5693abe
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 8 deletions.
1 change: 1 addition & 0 deletions internal/notification/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ func (r repository) UpdateRule(ctx context.Context, params notification.UpdateRu
SetDisabled(params.Disabled).
SetConfig(params.Config).
SetName(params.Name).
ClearChannels().
AddChannelIDs(params.Channels...)

queryRow, err := query.Save(ctx)
Expand Down
14 changes: 14 additions & 0 deletions internal/notification/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ type CreateRuleInput struct {
Channels []string
}

const MaxChannelsPerRule = 5

func (i CreateRuleInput) Validate(ctx context.Context, service Service) error {
if i.Namespace == "" {
return ValidationError{
Expand All @@ -315,6 +317,12 @@ func (i CreateRuleInput) Validate(ctx context.Context, service Service) error {
return err
}

if len(i.Channels) > MaxChannelsPerRule {
return ValidationError{
Err: fmt.Errorf("too many channels: %d > %d", len(i.Channels), MaxChannelsPerRule),
}
}

return nil
}

Expand Down Expand Up @@ -386,6 +394,12 @@ func (i UpdateRuleInput) Validate(ctx context.Context, service Service) error {
}
}

if len(i.Channels) > MaxChannelsPerRule {
return ValidationError{
Err: fmt.Errorf("too many channels: %d > %d", len(i.Channels), MaxChannelsPerRule),
}
}

return nil
}

Expand Down
100 changes: 93 additions & 7 deletions internal/notification/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ import (
"fmt"
"log/slog"

"github.com/samber/lo"

"github.com/openmeterio/openmeter/internal/notification/webhook"
"github.com/openmeterio/openmeter/internal/productcatalog"
"github.com/openmeterio/openmeter/pkg/convert"
"github.com/openmeterio/openmeter/pkg/models"
"github.com/openmeterio/openmeter/pkg/pagination"
)

type Service interface {
Expand Down Expand Up @@ -52,13 +56,19 @@ type FeatureService interface {

var _ Service = (*service)(nil)

const (
ChannelIDMetadataKey = "om-channel-id"
)

type service struct {
feature productcatalog.FeatureConnector

repo Repository
webhook webhook.Handler

eventHandler EventHandler

logger *slog.Logger
}

func (c service) Close() error {
Expand Down Expand Up @@ -105,6 +115,7 @@ func New(config Config) (Service, error) {
feature: config.FeatureConnector,
webhook: config.Webhook,
eventHandler: eventHandler,
logger: config.Logger,
}, nil
}

Expand Down Expand Up @@ -157,6 +168,10 @@ func (c service) CreateChannel(ctx context.Context, params CreateChannelInput) (
CustomHeaders: headers,
Disabled: channel.Disabled,
Secret: &channel.Config.WebHook.SigningSecret,
Metadata: map[string]string{
ChannelIDMetadataKey: channel.ID,
},
Description: convert.ToPointer("Notification Channel: " + channel.ID),
})
if err != nil {
return nil, fmt.Errorf("failed to create webhook for channel: %w", err)
Expand Down Expand Up @@ -379,18 +394,89 @@ func (c service) UpdateRule(ctx context.Context, params UpdateRuleInput) (*Rule,
return nil, fmt.Errorf("invalid params: %w", err)
}

logger := c.logger.WithGroup("rule").With(
"operation", "update",
"id", params.ID,
"namespace", params.Namespace,
)

rule, err := c.repo.GetRule(ctx, GetRuleInput{
ID: params.ID,
Namespace: params.Namespace,
})
if err != nil {
return nil, fmt.Errorf("failed to get rule: %w", err)
}

if rule.DeletedAt != nil {
return nil, UpdateAfterDeleteError{
Err: errors.New("not allowed to update deleted rule"),
}
}

// Get list of channel IDs currently assigned to rule
oldChannelIDs := lo.Map(rule.Channels, func(channel Channel, _ int) string {
return channel.ID
})
logger.Debug("currently assigned channels", "channels", oldChannelIDs)

// Calculate channels diff for the update
channelIDsDiff := NewChannelIDsDifference(params.Channels, oldChannelIDs)

logger.WithGroup("channels").Debug("difference in channels assignment",
"changed", channelIDsDiff.HasChanged(),
"additions", channelIDsDiff.Additions(),
"removals", channelIDsDiff.Removals(),
)

// We can return early ff there is no change in the list of channels assigned to rule.
if !channelIDsDiff.HasChanged() {
return c.repo.UpdateRule(ctx, params)
}

txFunc := func(ctx context.Context, repo TxRepository) (*Rule, error) {
channel, err := repo.GetRule(ctx, GetRuleInput{
ID: params.ID,
Namespace: params.Namespace,
// Fetch all the channels from repo which are either added or removed from rule
channels, err := repo.ListChannels(ctx, ListChannelsInput{
Page: pagination.Page{
// In order to avoid under-fetching. There cannot be more affected channels than
// twice as the maximum number of allowed channels per rule.
PageSize: 2 * MaxChannelsPerRule,
PageNumber: 1,
},
Namespaces: []string{params.Namespace},
Channels: channelIDsDiff.All(),
IncludeDisabled: true,
})
if err != nil {
return nil, fmt.Errorf("failed to get rule: %w", err)
return nil, fmt.Errorf("failed to list channels for rule: %w", err)
}
logger.Debug("fetched all affected channels", "channels", channels.Items)

// Update affected channels
for _, channel := range channels.Items {
switch channel.Type {
case ChannelTypeWebhook:
input := webhook.UpdateWebhookChannelsInput{
Namespace: params.Namespace,
ID: channel.ID,
}

if channelIDsDiff.InAdditions(channel.ID) {
input.AddChannels = []string{rule.ID}
}

if channelIDsDiff.InRemovals(channel.ID) {
input.RemoveChannels = []string{rule.ID}
}

logger.Debug("updating webhook for channel", "id", channel.ID, "input", input)

if channel.DeletedAt != nil {
return nil, UpdateAfterDeleteError{
Err: errors.New("not allowed to update deleted rule"),
_, err = c.webhook.UpdateWebhookChannels(ctx, input)
if err != nil {
return nil, fmt.Errorf("failed to update webhook for channel: %w", err)
}
default:
return nil, fmt.Errorf("invalid channel type: %s", channel.Type)
}
}

Expand Down
103 changes: 102 additions & 1 deletion internal/notification/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package notification

import "fmt"
import (
"fmt"

"github.com/samber/lo"
)

// ChannelTypes returns a set of ChannelType from Channel slice
func ChannelTypes(channels []Channel) []ChannelType {
Expand Down Expand Up @@ -90,3 +94,100 @@ func InterfaceMapToStringMap(m map[string]interface{}) map[string]string {

return s
}

type difference[T comparable] struct {
leftMap map[T]struct{}
left []T

rightMap map[T]struct{}
right []T
}

func (d difference[T]) Has(item T) bool {
return d.HasLeft(item) || d.HasRight(item)
}

func (d difference[T]) HasLeft(item T) bool {
if _, ok := d.leftMap[item]; ok {
return true
}

return false
}

func (d difference[T]) HasRight(item T) bool {
if _, ok := d.rightMap[item]; ok {
return true
}

return false
}

func (d difference[T]) Left() []T {
return d.left
}

func (d difference[T]) Right() []T {
return d.right
}

func (d difference[T]) HasChanged() bool {
return len(d.left) > 0 || len(d.right) > 0
}

func (d difference[T]) All() []T {
return append(d.left, d.right...)
}

type ChannelIDsDifference struct {
diff difference[string]
}

func (d ChannelIDsDifference) Has(id string) bool {
return d.diff.Has(id)
}

func (d ChannelIDsDifference) HasChanged() bool {
return d.diff.HasChanged()
}

func (d ChannelIDsDifference) InAdditions(id string) bool {
return d.diff.HasLeft(id)
}

func (d ChannelIDsDifference) InRemovals(id string) bool {
return d.diff.HasRight(id)
}

func (d ChannelIDsDifference) Additions() []string {
return d.diff.Left()
}

func (d ChannelIDsDifference) Removals() []string {
return d.diff.Right()
}

func (d ChannelIDsDifference) All() []string {
return d.diff.All()
}

func NewChannelIDsDifference(new, old []string) *ChannelIDsDifference {
left, right := lo.Difference(new, old)

leftMap := lo.SliceToMap(left, func(item string) (string, struct{}) {
return item, struct{}{}
})

rightMap := lo.SliceToMap(right, func(item string) (string, struct{}) {
return item, struct{}{}
})

return &ChannelIDsDifference{
diff: difference[string]{
leftMap: leftMap,
left: left,
rightMap: rightMap,
right: right,
},
}
}
85 changes: 85 additions & 0 deletions internal/notification/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package notification

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestChannelIDsDifference(t *testing.T) {
tests := []struct {
Name string

New []string
Old []string

ExpectedAdditions []string
ExpectedRemovals []string
}{
{
Name: "No change",
New: []string{
"channel-1",
"channel-2",
},
Old: []string{
"channel-2",
"channel-1",
},
},
{
Name: "Add new channel",
New: []string{
"channel-1",
"channel-2",
"channel-3",
},
Old: []string{
"channel-2",
"channel-1",
},
ExpectedAdditions: []string{
"channel-3",
},
},
{
Name: "Remove old channel",
New: []string{
"channel-2",
},
Old: []string{
"channel-2",
"channel-1",
},
ExpectedRemovals: []string{
"channel-1",
},
},
{
Name: "Add and remove channels",
New: []string{
"channel-1",
"channel-3",
},
Old: []string{
"channel-2",
"channel-1",
},
ExpectedAdditions: []string{
"channel-3",
},
ExpectedRemovals: []string{
"channel-2",
},
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
diff := NewChannelIDsDifference(test.New, test.Old)

assert.ElementsMatch(t, test.ExpectedAdditions, diff.Additions())
assert.ElementsMatch(t, test.ExpectedRemovals, diff.Removals())
})
}
}

0 comments on commit 5693abe

Please sign in to comment.