diff --git a/pool/subscriber.go b/pool/subscriber.go
index 50e7d98..6a357f4 100644
--- a/pool/subscriber.go
+++ b/pool/subscriber.go
@@ -163,10 +163,11 @@ func (s *Subscriber) RegisterBatchHandler(handler *BatchHandler) {
     s.log.WithFields(withConsumerIfSet(handler.ConsumeOptions().ConsumerTag, map[string]any{
-        "subscriber":   s.pool.Name(),
-        "queue":        opts.Queue,
-        "maxBatchSize": opts.MaxBatchSize, // TODO: optimize so that we don't call getters multiple times (mutex contention)
-        "flushTimeout": handler.FlushTimeout,
+        "subscriber":    s.pool.Name(),
+        "queue":         opts.Queue,
+        "maxBatchBytes": opts.MaxBatchBytes,
+        "maxBatchSize":  opts.MaxBatchSize,
+        "flushTimeout":  opts.FlushTimeout,
     })).Info("registered batch message handler")
 }
@@ -445,7 +446,14 @@ func (s *Subscriber) batchConsume(h *BatchHandler) (err error) {
             // There is no way to recover form this state in case an error is returned from the Nack call.
             nackErr := batch[len(batch)-1].Nack(true, true)
             if nackErr != nil {
-                s.warnBatchHandler(opts.ConsumerTag, opts.Queue, opts.MaxBatchSize, err, "failed to nack and requeue batch upon shutdown")
+                s.warnBatchHandler(
+                    opts.ConsumerTag,
+                    opts.Queue,
+                    opts.MaxBatchSize,
+                    opts.MaxBatchBytes,
+                    err,
+                    "failed to nack and requeue batch upon shutdown",
+                )
             }
         }
     }()
@@ -456,10 +464,12 @@ func (s *Subscriber) batchConsume(h *BatchHandler) (err error) {
     )
     defer closeTimer(timer, &drained)
 
+    var batchBytes = 0
     for {
         // reset batch slice
         // reuse memory
         batch = batch[:0]
+        batchBytes = 0
 
     collectBatch:
         for {
@@ -474,8 +484,14 @@ func (s *Subscriber) batchConsume(h *BatchHandler) (err error) {
                 if !ok {
                     return ErrDeliveryClosed
                 }
+
+                batchBytes += len(msg.Body)
                 batch = append(batch, msg)
-                if len(batch) == opts.MaxBatchSize {
+                if opts.MaxBatchSize > 0 && len(batch) == opts.MaxBatchSize {
+                    break collectBatch
+                }
+
+                if opts.MaxBatchBytes > 0 && batchBytes >= opts.MaxBatchBytes {
                     break collectBatch
                 }
@@ -485,6 +501,13 @@ func (s *Subscriber) batchConsume(h *BatchHandler) (err error) {
             if len(batch) > 0 {
                 // timeout reached, process batch that might not contain
                 // a full batch, yet.
+                s.infoBatchHandler(
+                    opts.ConsumerTag,
+                    opts.Queue,
+                    len(batch),
+                    batchBytes,
+                    "flush timeout reached",
+                )
                 break collectBatch
             }
@@ -498,18 +521,29 @@ func (s *Subscriber) batchConsume(h *BatchHandler) (err error) {
             lastDeliveryTag = batch[len(batch)-1].DeliveryTag
         )
 
-        s.infoBatchHandler(opts.ConsumerTag, opts.Queue, batchSize, "received batch")
+        s.infoBatchHandler(opts.ConsumerTag, opts.Queue, batchSize, batchBytes, "received batch")
         err = opts.HandlerFunc(batch)
 
         // no acks required
         if opts.AutoAck {
             if err != nil {
                 // we cannot really do anything to recover from a processing error in this case
-                s.errorBatchHandler(opts.ConsumerTag, opts.Queue, batchSize, fmt.Errorf("processing failed: dropping batch: %w", err))
+                s.errorBatchHandler(opts.ConsumerTag,
+                    opts.Queue,
+                    batchSize,
+                    batchBytes,
+                    fmt.Errorf("processing failed: dropping batch: %w", err),
+                )
             } else {
-                s.infoBatchHandler(opts.ConsumerTag, opts.Queue, batchSize, "processed batch")
+                s.infoBatchHandler(
+                    opts.ConsumerTag,
+                    opts.Queue,
+                    batchSize,
+                    batchBytes,
+                    "processed batch",
+                )
             }
         } else {
-            poolErr := s.ackBatchPostHandle(opts, lastDeliveryTag, batchSize, session, err)
+            poolErr := s.ackBatchPostHandle(opts, lastDeliveryTag, batchSize, batchBytes, session, err)
             if poolErr != nil {
                 return poolErr
             }
@@ -517,7 +551,7 @@ func (s *Subscriber) batchConsume(h *BatchHandler) (err error) {
         }
     }
 }
 
-func (s *Subscriber) ackBatchPostHandle(opts BatchHandlerConfig, lastDeliveryTag uint64, currentBatchSize int, session *Session, handlerErr error) (err error) {
+func (s *Subscriber) ackBatchPostHandle(opts BatchHandlerConfig, lastDeliveryTag uint64, currentBatchSize, currentBatchBytes int, session *Session, handlerErr error) (err error) {
     var ackErr error
     // processing failed
     if handlerErr != nil {
@@ -531,7 +565,14 @@ func (s *Subscriber) ackBatchPostHandle(opts BatchHandlerConfig, lastDeliveryTag
     // if (n)ack fails, we know that the connection died
     // potentially before processing already.
     if ackErr != nil {
-        s.warnBatchHandler(opts.ConsumerTag, opts.Queue, currentBatchSize, ackErr, "batch (n)ack failed")
+        s.warnBatchHandler(
+            opts.ConsumerTag,
+            opts.Queue,
+            currentBatchSize,
+            currentBatchBytes,
+            ackErr,
+            "batch (n)ack failed",
+        )
         poolErr := session.Recover()
         if poolErr != nil {
             // only returns an error upon shutdown
@@ -545,9 +586,21 @@ func (s *Subscriber) ackBatchPostHandle(opts BatchHandlerConfig, lastDeliveryTag
     // (n)acked successfully
     if handlerErr != nil {
-        s.infoBatchHandler(opts.ConsumerTag, opts.Queue, currentBatchSize, "nacked batch")
+        s.infoBatchHandler(
+            opts.ConsumerTag,
+            opts.Queue,
+            currentBatchSize,
+            currentBatchBytes,
+            "nacked batch",
+        )
     } else {
-        s.infoBatchHandler(opts.ConsumerTag, opts.Queue, currentBatchSize, "acked batch")
+        s.infoBatchHandler(
+            opts.ConsumerTag,
+            opts.Queue,
+            currentBatchSize,
+            currentBatchBytes,
+            "acked batch",
+        )
     }
     // successfully handled message
     return nil
@@ -586,28 +639,31 @@ func (s *Subscriber) catchShutdown() <-chan struct{} {
     return s.ctx.Done()
 }
 
-func (s *Subscriber) infoBatchHandler(consumer, queue string, batchSize int, a ...any) {
+func (s *Subscriber) infoBatchHandler(consumer, queue string, batchSize, batchBytes int, a ...any) {
     s.log.WithFields(withConsumerIfSet(consumer, map[string]any{
         "batchSize":  batchSize,
+        "batchBytes": batchBytes,
         "subscriber": s.pool.Name(),
         "queue":      queue,
     })).Info(a...)
 }
 
-func (s *Subscriber) warnBatchHandler(consumer, queue string, batchSize int, err error, a ...any) {
+func (s *Subscriber) warnBatchHandler(consumer, queue string, batchSize, batchBytes int, err error, a ...any) {
     s.log.WithFields(withConsumerIfSet(consumer, map[string]any{
         "batchSize":  batchSize,
+        "batchBytes": batchBytes,
         "subscriber": s.pool.Name(),
         "queue":      queue,
         "error":      err,
     })).Warn(a...)
 }
 
-func (s *Subscriber) errorBatchHandler(consumer, queue string, batchSize int, err error, a ...any) {
+func (s *Subscriber) errorBatchHandler(consumer, queue string, batchSize, batchBytes int, err error, a ...any) {
     s.log.WithFields(withConsumerIfSet(consumer, map[string]any{
         "batchSize":  batchSize,
+        "batchBytes": batchBytes,
         "subscriber": s.pool.Name(),
         "queue":      queue,
         "error":      err,
diff --git a/pool/subscriber_batch_handler.go b/pool/subscriber_batch_handler.go
index 4c09749..b8bcc58 100644
--- a/pool/subscriber_batch_handler.go
+++ b/pool/subscriber_batch_handler.go
@@ -21,11 +21,12 @@ func NewBatchHandler(queue string, hf BatchHandlerFunc, options ...BatchHandlerO
     // sane defaults
     h := &BatchHandler{
-        sc:           newStateContext(context.Background()),
-        queue:        queue,
-        handlerFunc:  hf,
-        maxBatchSize: defaultMaxBatchSize,
-        flushTimeout: defaultFlushTimeout,
+        sc:            newStateContext(context.Background()),
+        queue:         queue,
+        handlerFunc:   hf,
+        maxBatchSize:  defaultMaxBatchSize,
+        maxBatchBytes: 0, // unlimited by default
+        flushTimeout:  defaultFlushTimeout,
         consumeOpts: ConsumeOptions{
             ConsumerTag: "",
             AutoAck:     false,
@@ -55,6 +56,11 @@ type BatchHandler struct {
     // before processing is triggered
     maxBatchSize int
 
+    // In case the accumulated byte size of a batch exceeds this limit, the batch is passed to the handler function.
+    // This means that a batch always contains at least one message for processing.
+    // If the value is set to 0, the batch byte size is unlimited.
+    maxBatchBytes int
+
     // FlushTimeout is the duration that is waited for the next message from a queue before
     // the batch is closed and passed for processing.
     // This value should be less than 30m (which is the (n)ack timeout of RabbitMQ)
@@ -69,8 +75,15 @@ type BatchHandlerConfig struct {
     Queue string
     ConsumeOptions
 
-    HandlerFunc  BatchHandlerFunc
+    HandlerFunc BatchHandlerFunc
+
+    // Maximum number of messages
     MaxBatchSize int
+
+    // Maximum size of a batch in bytes (soft limit which triggers processing of a batch);
+    // it does not guarantee that the byte limit is never exceeded.
+    MaxBatchBytes int
+
     FlushTimeout time.Duration
 }
 
@@ -110,6 +123,7 @@ func (h *BatchHandler) configUnguarded() BatchHandlerConfig {
         Queue:          h.queue,
         HandlerFunc:    h.handlerFunc,
         MaxBatchSize:   h.maxBatchSize,
+        MaxBatchBytes:  h.maxBatchBytes,
         FlushTimeout:   h.flushTimeout,
         ConsumeOptions: h.consumeOpts,
     }
@@ -186,22 +200,31 @@ func (h *BatchHandler) ConsumeOptions() ConsumeOptions {
 func (h *BatchHandler) SetConsumeOptions(consumeOpts ConsumeOptions) {
     h.mu.Lock()
     defer h.mu.Unlock()
-    h.consumeOpts = consumeOpts
+    WithBatchConsumeOptions(consumeOpts)(h)
 }
 
 func (h *BatchHandler) MaxBatchSize() int {
-    h.mu.Lock()
-    defer h.mu.Unlock()
+    h.mu.RLock()
+    defer h.mu.RUnlock()
     return h.maxBatchSize
 }
 
 func (h *BatchHandler) SetMaxBatchSize(maxBatchSize int) {
     h.mu.Lock()
     defer h.mu.Unlock()
-    if maxBatchSize <= 0 {
-        maxBatchSize = defaultMaxBatchSize
-    }
-    h.maxBatchSize = maxBatchSize
+    WithMaxBatchSize(maxBatchSize)(h)
+}
+
+func (h *BatchHandler) MaxBatchBytes() int {
+    h.mu.RLock()
+    defer h.mu.RUnlock()
+    return h.maxBatchBytes
+}
+
+func (h *BatchHandler) SetMaxBatchBytes(maxBatchBytes int) {
+    h.mu.Lock()
+    defer h.mu.Unlock()
+    WithMaxBatchBytes(maxBatchBytes)(h)
 }
 
 func (h *BatchHandler) FlushTimeout() time.Duration {
@@ -213,8 +236,5 @@ func (h *BatchHandler) FlushTimeout() time.Duration {
 func (h *BatchHandler) SetFlushTimeout(flushTimeout time.Duration) {
     h.mu.Lock()
     defer h.mu.Unlock()
-    if flushTimeout <= 0 {
-        flushTimeout = defaultFlushTimeout
-    }
-    h.flushTimeout = flushTimeout
+    WithBatchFlushTimeout(flushTimeout)(h)
 }
diff --git a/pool/subscriber_handler_options.go b/pool/subscriber_handler_options.go
index eb6ce22..21635b8 100644
--- a/pool/subscriber_handler_options.go
+++ b/pool/subscriber_handler_options.go
@@ -4,16 +4,47 @@ import "time"
 
 type BatchHandlerOption func(*BatchHandler)
 
+// WithMaxBatchSize sets the maximum size of a batch.
+// If set to 0, the batch size is not limited by message count.
+// This means that the batch size is then only limited by the maximum batch size in bytes.
 func WithMaxBatchSize(size int) BatchHandlerOption {
     return func(bh *BatchHandler) {
-        if size <= 0 {
+
+        switch {
+        case and(size <= 0, bh.maxBatchBytes == 0):
+            // we need to set a sane default
             bh.maxBatchSize = defaultMaxBatchSize
-        } else {
+        case and(size <= 0, bh.maxBatchBytes > 0):
+            // we need to set the batch size to unlimited
+            // because the batch is limited by bytes
+            bh.maxBatchSize = 0
+        case size > 0:
+            // we need to set the batch size to the provided value
             bh.maxBatchSize = size
         }
     }
 }
 
+// WithMaxBatchBytes sets the maximum size of a batch in bytes.
+// If the accumulated byte size of a batch exceeds this limit, the batch is passed to the handler function.
+// If the value is set to 0, the batch size is not limited by bytes.
+func WithMaxBatchBytes(size int) BatchHandlerOption {
+    return func(bh *BatchHandler) {
+        switch {
+        case and(size <= 0, bh.maxBatchSize == 0):
+            // do not change the current value
+            return
+        case and(size <= 0, bh.maxBatchSize > 0):
+            // we need to set the byte limit to unlimited
+            // because the batch is limited by the number of messages
+            bh.maxBatchBytes = 0
+        case size > 0:
+            // we need to set the byte limit to the provided value
+            bh.maxBatchBytes = size
+        }
+    }
+}
+
 func WithBatchFlushTimeout(d time.Duration) BatchHandlerOption {
     return func(bh *BatchHandler) {
         if d <= 0 {
@@ -29,3 +60,12 @@ func WithBatchConsumeOptions(opts ConsumeOptions) BatchHandlerOption {
         bh.consumeOpts = opts
     }
 }
+
+func and(b ...bool) bool {
+    for _, v := range b {
+        if !v {
+            return false
+        }
+    }
+    return true
+}
diff --git a/pool/subscriber_handler_options_test.go b/pool/subscriber_handler_options_test.go
new file mode 100644
index 0000000..5c2c79d
--- /dev/null
+++ b/pool/subscriber_handler_options_test.go
@@ -0,0 +1,29 @@
+package pool
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+)
+
+func TestWithMaxBatchSize(t *testing.T) {
+    dummyHandler := func([]Delivery) error { return nil }
+    bh := NewBatchHandler("test", dummyHandler, WithMaxBatchSize(0), WithMaxBatchBytes(0))
+
+    assert.Equal(t, defaultMaxBatchSize, bh.MaxBatchSize())
+    assert.Equal(t, 0, bh.MaxBatchBytes())
+
+    bh = NewBatchHandler("test", dummyHandler, WithMaxBatchBytes(0), WithMaxBatchSize(0))
+    assert.Equal(t, defaultMaxBatchSize, bh.MaxBatchSize())
+    assert.Equal(t, 0, bh.MaxBatchBytes())
+
+    bh = NewBatchHandler("test", dummyHandler, WithMaxBatchBytes(1), WithMaxBatchSize(1))
+    assert.Equal(t, 1, bh.MaxBatchSize())
+    assert.Equal(t, 1, bh.MaxBatchBytes())
+
+    // If you want to set specific limits to unlimited, first apply all non-zero options
+    // and then set the remaining options to 0.
+ bh = NewBatchHandler("test", dummyHandler, WithMaxBatchBytes(50), WithMaxBatchSize(0)) + assert.Equal(t, 0, bh.MaxBatchSize()) + assert.Equal(t, 50, bh.MaxBatchBytes()) +} diff --git a/pool/subscriber_test.go b/pool/subscriber_test.go index 986ad4c..5d28607 100644 --- a/pool/subscriber_test.go +++ b/pool/subscriber_test.go @@ -241,3 +241,159 @@ func TestBatchSubscriber(t *testing.T) { wg.Wait() } + +func TestBatchSubscriberMaxBytes(t *testing.T) { + + for i := 1; i <= 2048; i = i*2 + 1 { + testBatchSubscriberMaxBytes(t, i) + } +} + +func testBatchSubscriberMaxBytes(t *testing.T, maxBatchBytes int) { + t.Helper() + + var ( + sessions = 2 // publisher sessions + consumer sessions + numMessages = 50 + batchTimeout = 5 * time.Second // keep this at a higher number for slow machines + ) + p, err := pool.New(connectURL, 1, sessions, pool.WithConfirms(true), pool.WithLogger(logging.NewTestLogger(t))) + if err != nil { + assert.NoError(t, err) + return + } + defer p.Close() + + var wg sync.WaitGroup + + channels := sessions / 2 // one sessions for consumer and one for publisher + wg.Add(channels) + for id := 0; id < channels; id++ { + go func(id int64) { + defer wg.Done() + + ts, err := p.GetTransientSession(p.Context()) + if err != nil { + assert.NoError(t, err) + return + } + defer p.ReturnSession(ts, false) + + queueName := fmt.Sprintf("TestBatchSubscriberMaxBytes-Queue-%d", id) + _, err = ts.QueueDeclare(queueName) + if err != nil { + assert.NoError(t, err) + return + } + defer func() { + i, err := ts.QueueDelete(queueName) + assert.NoError(t, err) + assert.Equal(t, 0, i) + }() + + exchangeName := fmt.Sprintf("TestBatchSubscriberMaxBytes-Exchange-%d", id) + err = ts.ExchangeDeclare(exchangeName, "topic") + if err != nil { + assert.NoError(t, err) + return + } + defer func() { + err := ts.ExchangeDelete(exchangeName) + assert.NoError(t, err) + }() + + err = ts.QueueBind(queueName, "#", exchangeName) + if err != nil { + assert.NoError(t, err) + return + } + defer func() { + err := ts.QueueUnbind(queueName, "#", exchangeName, nil) + assert.NoError(t, err) + }() + + // publish all messages + pub := pool.NewPublisher(p) + defer pub.Close() + + log := logging.NewTestLogger(t) + + maxMsgLen := 0 + for i := 0; i < numMessages; i++ { + message := fmt.Sprintf("Message-%s-%06d", queueName, i) // max 6 digits + mlen := len(message) + if mlen > maxMsgLen { + maxMsgLen = mlen + } + + pub.Publish(exchangeName, "", pool.Publishing{ + Mandatory: true, + ContentType: "application/json", + Body: []byte(message), + }) + } + log.Debugf("max message length: %d", maxMsgLen) + log.Debugf("max batch bytes: %d", maxBatchBytes) + expectedMessagesPerBatch := maxBatchBytes / maxMsgLen + if maxBatchBytes%maxMsgLen > 0 { + expectedMessagesPerBatch += 1 + } + log.Debugf("expected messages per batch: %d", expectedMessagesPerBatch) + expectedBatches := numMessages / expectedMessagesPerBatch + if numMessages%expectedMessagesPerBatch > 0 { + expectedBatches += 1 + } + log.Debugf("expected batches: %d", expectedBatches) + + ctx, cancel := context.WithCancel(p.Context()) + + sub := pool.NewSubscriber(p, pool.SubscriberWithContext(ctx)) + defer sub.Close() + + batchCount := 0 + messageCount := 0 + sub.RegisterBatchHandlerFunc(queueName, + func(msgs []pool.Delivery) error { + + for idx, msg := range msgs { + assert.Truef(t, len(msg.Body) > 0, "msg body is empty: message index: %d", idx) + log.Debugf("batch: %d message: %d: body: %q", batchCount, idx, string(msg.Body)) + } + + messageCount += len(msgs) + batchCount += 1 + + 
+                    expectedMessages := expectedMessagesPerBatch
+                    if len(msgs)%expectedMessagesPerBatch > 0 {
+                        expectedMessages = len(msgs) % expectedMessagesPerBatch
+                    }
+                    assert.Equal(t, expectedMessages, len(msgs))
+
+                    if messageCount == numMessages {
+                        // close subscriber from within handler
+                        cancel()
+                    }
+                    return nil
+                },
+                pool.WithMaxBatchBytes(maxBatchBytes),
+                pool.WithMaxBatchSize(0), // disable the message-count limit
+                pool.WithBatchFlushTimeout(batchTimeout),
+                pool.WithBatchConsumeOptions(pool.ConsumeOptions{
+                    ConsumerTag: fmt.Sprintf("Consumer-%s", queueName),
+                    Exclusive:   true,
+                }),
+            )
+            sub.Start()
+
+            // this should be canceled upon context cancellation from within the
+            // subscriber handler.
+            sub.Wait()
+
+            assert.Equalf(t, numMessages, messageCount, "expected message counter to equal the number of published messages")
+            assert.Equalf(t, expectedBatches, batchCount, "required to have %d batches received", expectedBatches)
+
+        }(int64(id))
+    }
+
+    wg.Wait()
+}
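
Usage sketch (illustrative, not part of the patch above): the snippet mirrors the options exercised in TestBatchSubscriberMaxBytes and shows how the new byte-based limit can be combined with the existing ones. The queue name, the 64 KiB limit, the import path, and the *pool.Pool parameter type are assumptions for the example, not taken from this change set.

package main

import (
    "time"

    "github.com/jxsl13/amqpx/pool" // assumed import path of the pool package
)

// runByteLimitedSubscriber registers a batch handler that flushes by accumulated
// body size rather than by message count. p is assumed to be the pool returned by pool.New.
func runByteLimitedSubscriber(p *pool.Pool) {
    sub := pool.NewSubscriber(p)
    defer sub.Close()

    sub.RegisterBatchHandlerFunc(
        "example-queue", // hypothetical queue name
        func(msgs []pool.Delivery) error {
            // Each call receives messages whose accumulated body size is roughly
            // bounded by the configured byte limit (soft limit), or whatever
            // arrived before the flush timeout.
            return nil
        },
        pool.WithMaxBatchBytes(64*1024),           // flush once accumulated body bytes reach ~64 KiB
        pool.WithMaxBatchSize(0),                  // disable the message-count limit
        pool.WithBatchFlushTimeout(5*time.Second), // flush a partial batch after 5s without new messages
    )

    sub.Start()
    sub.Wait() // blocks until the subscriber's context is canceled
}

Note that option order matters here: because WithMaxBatchBytes(64*1024) is applied before WithMaxBatchSize(0), the zero size disables the count limit instead of falling back to defaultMaxBatchSize (see TestWithMaxBatchSize above).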