diff --git a/components/cqrs/command_config.go b/components/cqrs/command_config.go index 1a0dc59e1..e0a83d354 100644 --- a/components/cqrs/command_config.go +++ b/components/cqrs/command_config.go @@ -18,6 +18,12 @@ type CommandConfig struct { OnSend OnCommandSendFn OnHandle OnCommandHandleFn + // RequestReplyEnabled enables request-reply pattern for commands. + // Reply is sent **just** from the CommandBus.SendAndWait method. + // This configuration doesn't affect CommandBus.Send method. + RequestReplyEnabled bool + RequestReplyBackend RequestReplyBackend + Marshaler CommandEventMarshaler Logger watermill.LoggerAdapter @@ -45,6 +51,10 @@ func (c CommandConfig) Validate() error { err = stdErrors.Join(err, errors.New("missing Marshaler")) } + if c.RequestReplyEnabled && c.RequestReplyBackend == nil { + err = stdErrors.Join(err, errors.New("missing RequestReply.Backend")) + } + return err } diff --git a/components/cqrs/command_processor.go b/components/cqrs/command_processor.go index ef6d01cf5..d1acf9216 100644 --- a/components/cqrs/command_processor.go +++ b/components/cqrs/command_processor.go @@ -1,6 +1,7 @@ package cqrs import ( + stdErrors "errors" "fmt" "github.com/pkg/errors" @@ -188,10 +189,28 @@ func (p CommandProcessor) routerHandlerFunc(handler CommandHandler, logger water Message: msg, }) + var replyErr error + // todo: test + if p.config.RequestReplyEnabled { + replyErr = p.config.RequestReplyBackend.OnCommandProcessed(msg, cmd, err) + } + + if p.config.AckCommandHandlingErrors && err != nil { + // we want to nack if we are using request-reply, + // and we failed to send information about failure + // todo: test + if replyErr != nil { + return replyErr + } + logger.Error("Error when handling command", err, nil) return nil + } else if replyErr != nil { + // todo: test + err = stdErrors.Join(err, replyErr) } + if err != nil { logger.Debug("Error when handling command", watermill.LogFields{"err": err}) return err diff --git a/components/cqrs/command_request_reply.go b/components/cqrs/command_request_reply.go new file mode 100644 index 000000000..8098d1a02 --- /dev/null +++ b/components/cqrs/command_request_reply.go @@ -0,0 +1,40 @@ +package cqrs + +import ( + "context" + "fmt" + "time" + + "github.com/ThreeDotsLabs/watermill/message" +) + +type RequestReplyBackend interface { + ModifyCommandMessageBeforePublish(cmdMsg *message.Message, command any) error + + ListenForReply(ctx context.Context, cmdMsg *message.Message, cmd any) (<-chan CommandReply, error) + + OnCommandProcessed(cmdMsg *message.Message, cmd any, handleErr error) error +} + +// ReplyTimeoutError is returned when the reply timeout is exceeded. +type ReplyTimeoutError struct { + Duration time.Duration + Err error +} + +func (e ReplyTimeoutError) Error() string { + return fmt.Sprintf("reply timeout after %s: %s", e.Duration, e.Err) +} + +// CommandHandlerError is returned when the command handler returns an error. +type CommandHandlerError struct { + Err error +} + +func (e CommandHandlerError) Error() string { + return e.Err.Error() +} + +func (e CommandHandlerError) Unwrap() error { + return e.Err +} diff --git a/components/cqrs/command_request_reply_bus.go b/components/cqrs/command_request_reply_bus.go new file mode 100644 index 000000000..acfcfc7e9 --- /dev/null +++ b/components/cqrs/command_request_reply_bus.go @@ -0,0 +1,51 @@ +package cqrs + +import ( + "context" + + "github.com/ThreeDotsLabs/watermill/message" + "github.com/pkg/errors" +) + +type CommandReply struct { + // HandlerErr contains the error returned by the command handler or by RequestReplyBackend if sending reply failed. + // + // If error from handler is returned, CommandHandlerError is returned. + // If listening for reply timed out, HandlerErr is ReplyTimeoutError. + // If processing was successful, HandlerErr is nil. + Err error + + // ReplyMsg contains the reply message from the command handler. + // Warning: ReplyMsg is nil if timeout occurred. + ReplyMsg *message.Message +} + +// todo: test +// todo: add cancel func? +// SendAndWait sends command to the command bus and waits for the command execution. +func (c CommandBus) SendAndWait(ctx context.Context, cmd interface{}) (<-chan CommandReply, error) { + if !c.config.RequestReplyEnabled { + return nil, errors.New("RequestReply is not enabled in config") + } + + msg, topicName, err := c.newMessage(ctx, cmd) + if err != nil { + return nil, err + } + + if err := c.config.RequestReplyBackend.ModifyCommandMessageBeforePublish(msg, cmd); err != nil { + return nil, errors.Wrap(err, "cannot modify command message before publish") + } + + // todo: wait for 1 reply by default? + replyChan, err := c.config.RequestReplyBackend.ListenForReply(ctx, msg, cmd) + if err != nil { + return nil, errors.Wrap(err, "cannot listen for reply") + } + + if err := c.publisher.Publish(topicName, msg); err != nil { + return nil, err + } + + return replyChan, nil +} diff --git a/components/cqrs/command_request_reply_pubsub.go b/components/cqrs/command_request_reply_pubsub.go new file mode 100644 index 000000000..06fddf2be --- /dev/null +++ b/components/cqrs/command_request_reply_pubsub.go @@ -0,0 +1,280 @@ +package cqrs + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/ThreeDotsLabs/watermill" + "github.com/ThreeDotsLabs/watermill/message" + "github.com/pkg/errors" +) + +//go:generate protoc --proto_path=. command_request_reply_pubsub.proto --go_out=. --go_opt=paths=source_relative --go-grpc_opt=require_unimplemented_servers=false --go-grpc_out=. --go-grpc_opt=paths=source_relative + +type PubSubRequestReplyMarshaler interface { + Marshal(v interface{}) (*message.Message, error) + Unmarshal(msg *message.Message, v interface{}) (err error) +} + +// PubSubRequestReply is a RequestReplyBackend that uses Pub/Sub to transport commands and replies. +type PubSubRequestReply struct { + config PubSubRequestReplyConfig +} + +// NewPubSubRequestReply creates a new PubSubRequestReply. +func NewPubSubRequestReply(config PubSubRequestReplyConfig) (*PubSubRequestReply, error) { + config.setDefaults() + + if err := config.Validate(); err != nil { + return nil, errors.Wrap(err, "invalid config") + } + + return &PubSubRequestReply{ + config: config, + }, nil +} + +type PubSubRequestReplySubscriberContext struct { + CommandMessage *message.Message + Command any +} + +type PubSubRequestReplyOnCommandProcessedContext struct { + HandleErr error + + PubSubRequestReplySubscriberContext +} + +type PubSubRequestReplySubscriberConstructorFn func(PubSubRequestReplySubscriberContext) (message.Subscriber, error) + +type PubSubRequestReplyTopicGeneratorFn func(PubSubRequestReplySubscriberContext) (string, error) + +type PubSubRequestReplyConfig struct { + Publisher message.Publisher + SubscriberConstructor PubSubRequestReplySubscriberConstructorFn + GenerateReplyNotificationTopic PubSubRequestReplyTopicGeneratorFn + + Marshaler PubSubRequestReplyMarshaler + + Logger watermill.LoggerAdapter + + ListenForReplyTimeout *time.Duration + + ModifyNotificationMessage func(msg *message.Message, context PubSubRequestReplyOnCommandProcessedContext) error + + OnListenForReplyFinished func(context.Context, PubSubRequestReplySubscriberContext) +} + +func (p *PubSubRequestReplyConfig) setDefaults() { + if p.Logger == nil { + p.Logger = watermill.NopLogger{} + } + + if p.Marshaler == nil { + p.Marshaler = JSONMarshaler{} + } +} + +func (p *PubSubRequestReplyConfig) Validate() error { + if p.Publisher == nil { + return errors.New("publisher cannot be nil") + } + if p.SubscriberConstructor == nil { + return errors.New("subscriber constructor cannot be nil") + } + if p.GenerateReplyNotificationTopic == nil { + return errors.New("GenerateReplyNotificationTopic cannot be nil") + } + + return nil +} + +const notifyWhenExecutedMetadataKey = "_watermill_notify_when_executed" + +func (p PubSubRequestReply) ModifyCommandMessageBeforePublish(cmdMsg *message.Message, command any) error { + cmdMsg.Metadata.Set(notifyWhenExecutedMetadataKey, "1") + + return nil +} + +func (p PubSubRequestReply) ListenForReply( + ctx context.Context, + cmdMsg *message.Message, + command any, +) (<-chan CommandReply, error) { + if !p.isRequestReplyEnabled(cmdMsg) { + return nil, errors.Errorf( + "RequestReply is enabled, but %s metadata is '%s' in command msg", + notifyWhenExecutedMetadataKey, + cmdMsg.Metadata.Get(notifyWhenExecutedMetadataKey), + ) + } + + start := time.Now() + + replyContext := PubSubRequestReplySubscriberContext{ + CommandMessage: cmdMsg, + Command: command, + } + + // this needs to be done before publishing the message to avoid race condition + notificationsSubscriber, err := p.config.SubscriberConstructor(replyContext) + if err != nil { + return nil, errors.Wrap(err, "cannot create request/reply notifications subscriber") + } + + replyNotificationTopic, err := p.config.GenerateReplyNotificationTopic(replyContext) + if err != nil { + return nil, errors.Wrap(err, "cannot generate request/reply notifications topic") + } + + var cancel context.CancelFunc + if p.config.ListenForReplyTimeout != nil { + ctx, cancel = context.WithTimeout(ctx, *p.config.ListenForReplyTimeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + + notifyMsgs, err := notificationsSubscriber.Subscribe(ctx, replyNotificationTopic) + if err != nil { + cancel() + return nil, errors.Wrap(err, "cannot subscribe to request/reply notifications topic") + } + + p.config.Logger.Debug( + "Subscribed to request/reply notifications topic", + watermill.LogFields{ + "request_reply_topic": replyNotificationTopic, + }, + ) + + replyChan := make(chan CommandReply, 1) + + go func() { + defer func() { + if p.config.OnListenForReplyFinished == nil { + return + } + + p.config.OnListenForReplyFinished(ctx, replyContext) + }() + defer close(replyChan) + defer cancel() + + for { + select { + case <-ctx.Done(): + replyChan <- CommandReply{ + Err: ReplyTimeoutError{time.Since(start), ctx.Err()}, + } + return + case notifyMsg, ok := <-notifyMsgs: + if !ok { + // subscriber is closed + replyChan <- CommandReply{ + Err: ReplyTimeoutError{time.Since(start), fmt.Errorf("subscriber closed")}, + } + return + } + + if ok, handlerErr := p.handleNotifyMsg(notifyMsg, cmdMsg.UUID); ok { + reply := CommandReply{ + ReplyMsg: notifyMsg, + } + if handlerErr != nil { + reply.Err = CommandHandlerError{handlerErr} + } + + replyChan <- reply + continue + } + } + } + }() + + return replyChan, nil +} + +const HandledCommandMessageUuidMetadataKey = "_watermill_command_message_uuid" + +func (p PubSubRequestReply) OnCommandProcessed(cmdMsg *message.Message, cmd any, handleErr error) error { + if !p.isRequestReplyEnabled(cmdMsg) { + p.config.Logger.Debug(fmt.Sprintf("RequestReply is enabled, but %s is missing", notifyWhenExecutedMetadataKey), nil) + return nil + } + + p.config.Logger.Debug("Sending request reply", nil) + + // we are using protobuf message, so it will work both with proto and json marshaler + notification := &RequestReplyNotification{} + if handleErr != nil { + notification.Error = handleErr.Error() + notification.HasError = true + } + + notificationMsg, err := p.config.Marshaler.Marshal(notification) + if err != nil { + return errors.Wrap(err, "cannot marshal request reply notification") + } + + notificationMsg.SetContext(cmdMsg.Context()) + notificationMsg.Metadata.Set(HandledCommandMessageUuidMetadataKey, cmdMsg.UUID) + + if p.config.ModifyNotificationMessage != nil { + processedContext := PubSubRequestReplyOnCommandProcessedContext{ + HandleErr: handleErr, + PubSubRequestReplySubscriberContext: PubSubRequestReplySubscriberContext{ + CommandMessage: cmdMsg, + Command: cmd, + }, + } + if err := p.config.ModifyNotificationMessage(notificationMsg, processedContext); err != nil { + return errors.Wrap(err, "cannot modify notification message") + } + } + + replyTopic, err := p.config.GenerateReplyNotificationTopic(PubSubRequestReplySubscriberContext{ + CommandMessage: cmdMsg, + Command: cmd, + }) + if err != nil { + return errors.Wrap(err, "cannot generate request/reply notify topic") + } + + if err := p.config.Publisher.Publish(replyTopic, notificationMsg); err != nil { + return errors.Wrap(err, "cannot publish command executed message") + } + + return nil +} + +func (p PubSubRequestReply) isRequestReplyEnabled(cmdMsg *message.Message) bool { + notificationEnabled := cmdMsg.Metadata.Get(notifyWhenExecutedMetadataKey) + enabled, _ := strconv.ParseBool(notificationEnabled) + + return enabled +} + +func (p PubSubRequestReply) handleNotifyMsg(msg *message.Message, expectedCommandUuid string) (bool, error) { + defer msg.Ack() + + if msg.Metadata.Get(HandledCommandMessageUuidMetadataKey) != expectedCommandUuid { + // todo: test + p.config.Logger.Debug("Received notify message with different command UUID", nil) + return false, nil + } + + // we are using protobuf message, so it will work both with proto and json marshaler + notification := &RequestReplyNotification{} + if err := p.config.Marshaler.Unmarshal(msg, notification); err != nil { + return false, errors.Wrap(err, "cannot unmarshal request reply notification") + } + + if notification.HasError { + return true, errors.New(notification.Error) + } else { + return true, nil + } +} diff --git a/components/cqrs/command_request_reply_pubsub.pb.go b/components/cqrs/command_request_reply_pubsub.pb.go new file mode 100644 index 000000000..ab945b5db --- /dev/null +++ b/components/cqrs/command_request_reply_pubsub.pb.go @@ -0,0 +1,156 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v3.21.12 +// source: command_request_reply_pubsub.proto + +package cqrs + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type RequestReplyNotification struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` + HasError bool `protobuf:"varint,2,opt,name=has_error,json=hasError,proto3" json:"has_error,omitempty"` +} + +func (x *RequestReplyNotification) Reset() { + *x = RequestReplyNotification{} + if protoimpl.UnsafeEnabled { + mi := &file_command_request_reply_pubsub_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RequestReplyNotification) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestReplyNotification) ProtoMessage() {} + +func (x *RequestReplyNotification) ProtoReflect() protoreflect.Message { + mi := &file_command_request_reply_pubsub_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestReplyNotification.ProtoReflect.Descriptor instead. +func (*RequestReplyNotification) Descriptor() ([]byte, []int) { + return file_command_request_reply_pubsub_proto_rawDescGZIP(), []int{0} +} + +func (x *RequestReplyNotification) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +func (x *RequestReplyNotification) GetHasError() bool { + if x != nil { + return x.HasError + } + return false +} + +var File_command_request_reply_pubsub_proto protoreflect.FileDescriptor + +var file_command_request_reply_pubsub_proto_rawDesc = []byte{ + 0x0a, 0x22, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x5f, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x5f, 0x70, 0x75, 0x62, 0x73, 0x75, 0x62, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x71, 0x72, 0x73, 0x22, 0x4d, 0x0a, 0x18, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x69, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x1b, 0x0a, 0x09, + 0x68, 0x61, 0x73, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x08, 0x68, 0x61, 0x73, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x42, 0x34, 0x5a, 0x32, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x54, 0x68, 0x72, 0x65, 0x65, 0x44, 0x6f, 0x74, + 0x73, 0x4c, 0x61, 0x62, 0x73, 0x2f, 0x77, 0x61, 0x74, 0x65, 0x72, 0x6d, 0x69, 0x6c, 0x6c, 0x2f, + 0x63, 0x6f, 0x6d, 0x70, 0x6f, 0x6e, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x63, 0x71, 0x72, 0x73, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_command_request_reply_pubsub_proto_rawDescOnce sync.Once + file_command_request_reply_pubsub_proto_rawDescData = file_command_request_reply_pubsub_proto_rawDesc +) + +func file_command_request_reply_pubsub_proto_rawDescGZIP() []byte { + file_command_request_reply_pubsub_proto_rawDescOnce.Do(func() { + file_command_request_reply_pubsub_proto_rawDescData = protoimpl.X.CompressGZIP(file_command_request_reply_pubsub_proto_rawDescData) + }) + return file_command_request_reply_pubsub_proto_rawDescData +} + +var file_command_request_reply_pubsub_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_command_request_reply_pubsub_proto_goTypes = []interface{}{ + (*RequestReplyNotification)(nil), // 0: cqrs.RequestReplyNotification +} +var file_command_request_reply_pubsub_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_command_request_reply_pubsub_proto_init() } +func file_command_request_reply_pubsub_proto_init() { + if File_command_request_reply_pubsub_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_command_request_reply_pubsub_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RequestReplyNotification); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_command_request_reply_pubsub_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_command_request_reply_pubsub_proto_goTypes, + DependencyIndexes: file_command_request_reply_pubsub_proto_depIdxs, + MessageInfos: file_command_request_reply_pubsub_proto_msgTypes, + }.Build() + File_command_request_reply_pubsub_proto = out.File + file_command_request_reply_pubsub_proto_rawDesc = nil + file_command_request_reply_pubsub_proto_goTypes = nil + file_command_request_reply_pubsub_proto_depIdxs = nil +} diff --git a/components/cqrs/command_request_reply_pubsub.proto b/components/cqrs/command_request_reply_pubsub.proto new file mode 100644 index 000000000..d9c0cb30e --- /dev/null +++ b/components/cqrs/command_request_reply_pubsub.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package cqrs; +option go_package = "github.com/ThreeDotsLabs/watermill/components/cqrs"; + +message RequestReplyNotification { + string error = 1; + bool has_error = 2; +} diff --git a/components/cqrs/command_request_reply_pubsub_test.go b/components/cqrs/command_request_reply_pubsub_test.go new file mode 100644 index 000000000..d93bc01d0 --- /dev/null +++ b/components/cqrs/command_request_reply_pubsub_test.go @@ -0,0 +1,280 @@ +package cqrs_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/ThreeDotsLabs/watermill/components/cqrs" + "github.com/ThreeDotsLabs/watermill/message" + "github.com/ThreeDotsLabs/watermill/pubsub/gochannel" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPubSubRequestReply(t *testing.T) { + ts := NewTestServices() + + requestReplyPubSub := gochannel.NewGoChannel( + gochannel.Config{}, + ts.Logger, + ) + + cmdMsg := message.NewMessage("1", []byte("foo")) + command := &TestCommand{} + + handlerErr := fmt.Errorf("some error") + + onListenForReplyFinishedCalled := make(chan struct{}) + + pubSubBackend, err := cqrs.NewPubSubRequestReply(cqrs.PubSubRequestReplyConfig{ + Publisher: requestReplyPubSub, + SubscriberConstructor: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (message.Subscriber, error) { + assert.EqualValues(t, command, subscriberContext.Command) + assert.True(t, subscriberContext.CommandMessage.Equals(cmdMsg)) + + return requestReplyPubSub, nil + }, + GenerateReplyNotificationTopic: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (string, error) { + assert.EqualValues(t, command, subscriberContext.Command) + assert.True(t, subscriberContext.CommandMessage.Equals(cmdMsg)) + + return "reply", nil + }, + Marshaler: cqrs.JSONMarshaler{}, + Logger: ts.Logger, + ModifyNotificationMessage: func(msg *message.Message, context cqrs.PubSubRequestReplyOnCommandProcessedContext) error { + // to make it deterministic + msg.UUID = "1" + return nil + }, + OnListenForReplyFinished: func(ctx context.Context, subscriberContext cqrs.PubSubRequestReplySubscriberContext) { + assert.EqualValues(t, command, subscriberContext.Command) + assert.True(t, subscriberContext.CommandMessage.Equals(cmdMsg)) + + close(onListenForReplyFinishedCalled) + }, + }) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + + err = pubSubBackend.ModifyCommandMessageBeforePublish(cmdMsg, command) + require.NoError(t, err) + + repliesCh, err := pubSubBackend.ListenForReply( + ctx, + cmdMsg, + command, + ) + require.NoError(t, err) + + err = pubSubBackend.ModifyCommandMessageBeforePublish(cmdMsg, command) + require.NoError(t, err) + + err = pubSubBackend.OnCommandProcessed(cmdMsg, command, handlerErr) + require.NoError(t, err) + + select { + case reply := <-repliesCh: + assert.EqualValues(t, "1", reply.ReplyMsg.UUID) + assert.EqualValues( + t, + `{"error":"some error","has_error":true}`, + string(reply.ReplyMsg.Payload), + ) + assert.EqualError(t, reply.Err, handlerErr.Error()) + case <-time.After(1 * time.Second): + require.Fail(t, "timeout") + } + + cancel() + + select { + case <-onListenForReplyFinishedCalled: + // ok + case <-time.After(1 * time.Second): + require.Fail(t, "timeout waiting for OnListenForReplyFinished") + } +} + +func TestPubSubRequestReply_timeout(t *testing.T) { + ts := NewTestServices() + + requestReplyPubSub := gochannel.NewGoChannel( + gochannel.Config{}, + ts.Logger, + ) + + cmdMsg := message.NewMessage("1", []byte("foo")) + command := &TestCommand{} + + timeout := time.Millisecond * 1 + + pubSubBackend, err := cqrs.NewPubSubRequestReply(cqrs.PubSubRequestReplyConfig{ + Publisher: requestReplyPubSub, + SubscriberConstructor: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (message.Subscriber, error) { + return requestReplyPubSub, nil + }, + GenerateReplyNotificationTopic: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (string, error) { + return "reply", nil + }, + Marshaler: cqrs.JSONMarshaler{}, + Logger: ts.Logger, + ListenForReplyTimeout: &timeout, + }) + require.NoError(t, err) + + err = pubSubBackend.ModifyCommandMessageBeforePublish(cmdMsg, command) + require.NoError(t, err) + + repliesCh, err := pubSubBackend.ListenForReply( + context.Background(), + cmdMsg, + command, + ) + require.NoError(t, err) + + err = pubSubBackend.ModifyCommandMessageBeforePublish(cmdMsg, command) + require.NoError(t, err) + + select { + case reply := <-repliesCh: + assert.ErrorContains(t, reply.Err, "reply timeout after") + assert.ErrorContains(t, reply.Err, "context deadline exceeded") + case <-time.After(1 * time.Second): + require.Fail(t, "timeout") + } +} + +func TestPubSubRequestReply_timeout_context_cancelled(t *testing.T) { + ts := NewTestServices() + + requestReplyPubSub := gochannel.NewGoChannel( + gochannel.Config{}, + ts.Logger, + ) + + cmdMsg := message.NewMessage("1", []byte("foo")) + command := &TestCommand{} + + pubSubBackend, err := cqrs.NewPubSubRequestReply(cqrs.PubSubRequestReplyConfig{ + Publisher: requestReplyPubSub, + SubscriberConstructor: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (message.Subscriber, error) { + return requestReplyPubSub, nil + }, + GenerateReplyNotificationTopic: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (string, error) { + return "reply", nil + }, + Marshaler: cqrs.JSONMarshaler{}, + Logger: ts.Logger, + }) + require.NoError(t, err) + + err = pubSubBackend.ModifyCommandMessageBeforePublish(cmdMsg, command) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + + repliesCh, err := pubSubBackend.ListenForReply( + ctx, + cmdMsg, + command, + ) + require.NoError(t, err) + + err = pubSubBackend.ModifyCommandMessageBeforePublish(cmdMsg, command) + require.NoError(t, err) + + cancel() + + select { + case reply := <-repliesCh: + assert.ErrorContains(t, reply.Err, "reply timeout after") + assert.ErrorContains(t, reply.Err, "context canceled") + case <-time.After(1 * time.Second): + require.Fail(t, "timeout") + } +} + +func TestPubSubRequestReply_ListenForReply_unsupported_message(t *testing.T) { + ts := NewTestServices() + + requestReplyPubSub := gochannel.NewGoChannel( + gochannel.Config{}, + ts.Logger, + ) + + cmdMsg := message.NewMessage("1", []byte("foo")) + command := &TestCommand{} + + pubSubBackend, err := cqrs.NewPubSubRequestReply(cqrs.PubSubRequestReplyConfig{ + Publisher: requestReplyPubSub, + SubscriberConstructor: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (message.Subscriber, error) { + return requestReplyPubSub, nil + }, + GenerateReplyNotificationTopic: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (string, error) { + return "reply", nil + }, + Marshaler: cqrs.JSONMarshaler{}, + Logger: ts.Logger, + }) + require.NoError(t, err) + + repliesCh, err := pubSubBackend.ListenForReply( + context.Background(), + cmdMsg, + command, + ) + assert.Empty(t, repliesCh) + assert.EqualError(t, err, "RequestReply is enabled, but _watermill_notify_when_executed metadata is '' in command msg") +} + +func TestPubSubRequestReply_unsupported_message_received(t *testing.T) { + ts := NewTestServices() + + requestReplyPubSub := gochannel.NewGoChannel( + gochannel.Config{}, + ts.Logger, + ) + + cmdMsg := message.NewMessage("1", []byte("foo")) + command := &TestCommand{} + + pubSubBackend, err := cqrs.NewPubSubRequestReply(cqrs.PubSubRequestReplyConfig{ + Publisher: requestReplyPubSub, + SubscriberConstructor: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (message.Subscriber, error) { + return requestReplyPubSub, nil + }, + GenerateReplyNotificationTopic: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (string, error) { + return "reply", nil + }, + Marshaler: cqrs.JSONMarshaler{}, + Logger: ts.Logger, + }) + require.NoError(t, err) + + err = pubSubBackend.ModifyCommandMessageBeforePublish(cmdMsg, command) + require.NoError(t, err) + + repliesCh, err := pubSubBackend.ListenForReply( + context.Background(), + cmdMsg, + command, + ) + require.NoError(t, err) + + // this msg has no _watermill_notify_when_executed metadata - should be ignored + invalidCommandMsg := message.NewMessage("1", []byte("foo")) + + err = pubSubBackend.OnCommandProcessed(invalidCommandMsg, command, nil) + require.NoError(t, err) + + select { + case reply := <-repliesCh: + t.Fatalf("no reply should be sent, but received %#v", reply) + case <-time.After(time.Millisecond * 10): + // ok + } +} diff --git a/components/cqrs/command_request_reply_test.go b/components/cqrs/command_request_reply_test.go new file mode 100644 index 000000000..2fc865ba7 --- /dev/null +++ b/components/cqrs/command_request_reply_test.go @@ -0,0 +1,334 @@ +package cqrs_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/ThreeDotsLabs/watermill/components/cqrs" + "github.com/ThreeDotsLabs/watermill/message" + "github.com/ThreeDotsLabs/watermill/pubsub/gochannel" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestRequestReply is a functional test for request-reply command handling. +func TestRequestReply_functional(t *testing.T) { + ts := NewTestServices() + + requestReplyPubSub := gochannel.NewGoChannel( + gochannel.Config{BlockPublishUntilSubscriberAck: false}, + ts.Logger, + ) + + pubSubBackend, err := cqrs.NewPubSubRequestReply(cqrs.PubSubRequestReplyConfig{ + Publisher: requestReplyPubSub, + SubscriberConstructor: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (message.Subscriber, error) { + return requestReplyPubSub, nil + }, + GenerateReplyNotificationTopic: func(subscriberContext cqrs.PubSubRequestReplySubscriberContext) (string, error) { + return "reply", nil + }, + Marshaler: cqrs.JSONMarshaler{}, + Logger: ts.Logger, + ModifyNotificationMessage: func(msg *message.Message, context cqrs.PubSubRequestReplyOnCommandProcessedContext) error { + // to make it deterministic + msg.UUID = "1" + return nil + }, + }) + require.NoError(t, err) + + mockBackend := &RequestReplyBackendMock{ + Replies: []cqrs.CommandReply{ + { + Err: nil, + ReplyMsg: message.NewMessage("1", []byte("foo")), + }, + }, + } + + testCases := []struct { + Name string + Backend cqrs.RequestReplyBackend + ExpectedReply cqrs.CommandReply + }{ + { + Name: "mock", + Backend: mockBackend, + ExpectedReply: mockBackend.Replies[0], + }, + { + Name: "pubsub", + Backend: pubSubBackend, + ExpectedReply: cqrs.CommandReply{ + Err: nil, + ReplyMsg: message.NewMessage("1", []byte("{}")), + }, + }, + } + for i := range testCases { + tc := testCases[i] + + t.Run(tc.Name, func(t *testing.T) { + commandBus := runCommandProcessorWithRequestReply(t, ts, tc.Backend) + + repliesCh, err := commandBus.SendAndWait(context.Background(), &TestCommand{}) + require.NoError(t, err) + + select { + case reply := <-repliesCh: + assert.EqualValues(t, tc.ExpectedReply.Err, reply.Err) + assert.EqualValues(t, tc.ExpectedReply.ReplyMsg.UUID, reply.ReplyMsg.UUID) + assert.EqualValues(t, tc.ExpectedReply.ReplyMsg.Payload, reply.ReplyMsg.Payload) + case <-time.After(time.Second): + t.Fatal("timeout") + } + }) + } +} + +func TestRequestReply_modify_command_message_before_publish(t *testing.T) { + ts := NewTestServices() + + replies := []cqrs.CommandReply{ + { + Err: fmt.Errorf("test error"), + ReplyMsg: message.NewMessage("1", nil), + }, + } + mockBackend := &RequestReplyBackendMock{ + Replies: replies, + CustomModifyCommandMessageBeforePublish: func(cmdMsg *message.Message, command any) error { + assert.EqualValues(t, &TestCommand{}, command) + cmdMsg.Metadata.Set("foo", "bar") + return nil + }, + CustomOnCommandProcessed: func(cmdMsg *message.Message, command any, handleErr error) error { + assert.Equal(t, "bar", cmdMsg.Metadata.Get("foo")) + return nil + }, + } + + commandBus := runCommandProcessorWithRequestReply(t, ts, mockBackend) + + repliesCh, err := commandBus.SendAndWait(context.Background(), &TestCommand{}) + require.NoError(t, err) + + select { + case reply := <-repliesCh: + assert.EqualValues(t, replies[0], reply) + + case <-time.After(time.Second): + t.Fatal("timeout on first reply") + } +} + +func TestRequestReply_on_command_processed(t *testing.T) { + ts := NewTestServices() + + onCommandProcessedCalled := false + + expectedHandlerErr := fmt.Errorf("test error") + + replies := []cqrs.CommandReply{ + { + Err: expectedHandlerErr, + ReplyMsg: message.NewMessage("1", nil), + }, + } + mockBackend := &RequestReplyBackendMock{ + Replies: replies, + CustomOnCommandProcessed: func(cmdMsg *message.Message, command any, handleErr error) error { + onCommandProcessedCalled = true + require.NotEmpty(t, cmdMsg) + assert.Equal(t, "cqrs_test.TestCommand", cmdMsg.Metadata.Get("name")) + assert.EqualValues(t, &TestCommand{}, command) + assert.Equal(t, expectedHandlerErr, handleErr) + return nil + }, + } + + commandBus := runCommandProcessorWithRequestReplyWithHandler( + t, + ts, + mockBackend, + func(ctx context.Context, cmd *TestCommand) error { + return expectedHandlerErr + }, + ) + + repliesCh, err := commandBus.SendAndWait(context.Background(), &TestCommand{}) + require.NoError(t, err) + + select { + case reply := <-repliesCh: + assert.EqualValues(t, replies[0], reply) + + case <-time.After(time.Second): + t.Fatal("timeout on first reply") + } + + assert.True(t, onCommandProcessedCalled) +} + +func TestRequestReply_reply_with_error(t *testing.T) { + ts := NewTestServices() + + replies := []cqrs.CommandReply{ + { + Err: fmt.Errorf("test error"), + ReplyMsg: message.NewMessage("1", nil), + }, + } + mockBackend := &RequestReplyBackendMock{ + Replies: replies, + } + + commandBus := runCommandProcessorWithRequestReply(t, ts, mockBackend) + + repliesCh, err := commandBus.SendAndWait(context.Background(), &TestCommand{}) + require.NoError(t, err) + + select { + case reply := <-repliesCh: + assert.EqualValues(t, replies[0], reply) + case <-time.After(time.Second): + t.Fatal("timeout on first reply") + } +} + +func TestRequestReply_multiple_replies(t *testing.T) { + ts := NewTestServices() + + replies := []cqrs.CommandReply{ + { + Err: nil, + ReplyMsg: message.NewMessage("1", nil), + }, + { + Err: nil, + ReplyMsg: message.NewMessage("2", nil), + }, + } + mockBackend := &RequestReplyBackendMock{ + Replies: replies, + } + + commandBus := runCommandProcessorWithRequestReply(t, ts, mockBackend) + + repliesCh, err := commandBus.SendAndWait(context.Background(), &TestCommand{}) + require.NoError(t, err) + + select { + case reply := <-repliesCh: + assert.EqualValues(t, replies[0], reply) + case <-time.After(time.Second): + t.Fatal("timeout on first reply") + } + + select { + case reply := <-repliesCh: + assert.EqualValues(t, replies[1], reply) + case <-time.After(time.Second): + t.Fatal("timeout on second reply") + } +} + +func runCommandProcessorWithRequestReply(t *testing.T, ts TestServices, mockBackend cqrs.RequestReplyBackend) *cqrs.CommandBus { + t.Helper() + + return runCommandProcessorWithRequestReplyWithHandler( + t, + ts, + mockBackend, + func(ctx context.Context, cmd *TestCommand) error { + return nil + }, + ) +} + +func runCommandProcessorWithRequestReplyWithHandler( + t *testing.T, + ts TestServices, + mockBackend cqrs.RequestReplyBackend, + handler func(ctx context.Context, cmd *TestCommand) error, +) *cqrs.CommandBus { + t.Helper() + + router, err := message.NewRouter(message.RouterConfig{}, ts.Logger) + require.NoError(t, err) + + commandConfig := cqrs.CommandConfig{ + GenerateTopic: func(params cqrs.GenerateCommandTopicParams) (string, error) { + return "commands", nil + }, + SubscriberConstructor: func(params cqrs.CommandsSubscriberConstructorParams) (message.Subscriber, error) { + return ts.CommandsPubSub, nil + }, + Marshaler: ts.Marshaler, + Logger: ts.Logger, + RequestReplyEnabled: true, + RequestReplyBackend: mockBackend, + AckCommandHandlingErrors: true, + } + + commandProcessor, err := cqrs.NewCommandProcessorWithConfig(commandConfig) + require.NoError(t, err) + + commandProcessor.AddHandler(cqrs.NewCommandHandler( + "command_handler", + handler, + )) + + err = commandProcessor.AddHandlersToRouter(router) + require.NoError(t, err) + + go func() { + err = router.Run(context.Background()) + assert.NoError(t, err) + }() + + <-router.Running() + + commandBus, err := cqrs.NewCommandBusWithConfig(ts.CommandsPubSub, commandConfig) + require.NoError(t, err) + + return commandBus +} + +type RequestReplyBackendMock struct { + Replies []cqrs.CommandReply + + CustomModifyCommandMessageBeforePublish func(cmdMsg *message.Message, command any) error + CustomOnCommandProcessed func(cmdMsg *message.Message, command any, handleErr error) error +} + +func (r RequestReplyBackendMock) ModifyCommandMessageBeforePublish(cmdMsg *message.Message, command any) error { + if r.CustomModifyCommandMessageBeforePublish != nil { + return r.CustomModifyCommandMessageBeforePublish(cmdMsg, command) + } + + return nil +} + +func (r RequestReplyBackendMock) ListenForReply(ctx context.Context, cmdMsg *message.Message, cmd any) (<-chan cqrs.CommandReply, error) { + out := make(chan cqrs.CommandReply, len(r.Replies)) + + for _, reply := range r.Replies { + out <- reply + } + + close(out) + + return out, nil +} + +func (r RequestReplyBackendMock) OnCommandProcessed(cmdMsg *message.Message, cmd any, handleErr error) error { + if r.CustomOnCommandProcessed != nil { + return r.CustomOnCommandProcessed(cmdMsg, cmd, handleErr) + } + + return nil +}