Commit

Merge pull request #41 from jxsl13/feature/batch-bytes
Feature/batch bytes
jxsl13 authored Jan 3, 2024
2 parents 66dacd4 + 0fc52d5 commit 9f8c0af
Showing 10 changed files with 361 additions and 62 deletions.
1 change: 1 addition & 0 deletions .gitignore

@@ -1,3 +1,4 @@
 *.prof
 *.test
 *trace*
+coverage.txt
8 changes: 4 additions & 4 deletions pool/connection_pool.go

@@ -205,15 +205,15 @@ func (cp *ConnectionPool) ReturnConnection(conn *Connection, flag bool) {

	// close transient connections
	if !conn.IsCached() {
-		conn.Close()
+		_ = conn.Close()
	}

	err := cp.connections.Put(conn)
	if err != nil {
		// queue was disposed of,
		// indicating pool shutdown
		// -> close connection upon pool shutdown
-		conn.Close()
+		_ = conn.Close()
	}
 }

@@ -232,16 +232,16 @@ func (cp *ConnectionPool) Close() {
	for !cp.connections.Empty() {
		items := cp.connections.Dispose()

-		wg.Add(len(items))
		for _, item := range items {
			conn, ok := item.(*Connection)
			if !ok {
				panic("item in connection queue is not a connection")
			}

+			wg.Add(1)
			go func(c *Connection) {
				defer wg.Done()
-				c.Close()
+				_ = c.Close()
			}(conn)
		}
	}
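
Besides discarding the Close errors explicitly, the Close hunk moves the WaitGroup increment from a single wg.Add(len(items)) before the loop to wg.Add(1) right before each goroutine launch, so the counter only ever reflects goroutines that were actually started. A minimal, self-contained sketch of that pattern, using a stand-in closer type rather than the pool's real Connection:

package main

import (
	"fmt"
	"sync"
)

type closer struct{ id int }

func (c *closer) Close() error {
	fmt.Println("closed", c.id)
	return nil
}

func main() {
	items := []*closer{{1}, {2}, {3}}

	var wg sync.WaitGroup
	for _, c := range items {
		// Add(1) immediately before the launch keeps the counter in
		// lockstep with the goroutines that really exist; a bulk
		// Add(len(items)) would overcount if an iteration bailed out
		// before starting its goroutine.
		wg.Add(1)
		go func(c *closer) {
			defer wg.Done()
			_ = c.Close()
		}(c)
	}
	wg.Wait() // returns once every launched closer has finished
}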
27 changes: 9 additions & 18 deletions pool/helpers_context.go

@@ -292,12 +292,7 @@ func (sc *stateContext) Resume(ctx context.Context) error {
 }

 func (sc *stateContext) IsActive(ctx context.Context) (active bool, err error) {
-	closed := func() bool {
-		sc.mu.RLock()
-		defer sc.mu.RUnlock()
-		return sc.closed
-	}()
-	if closed {
+	if sc.isClosed() {
		return false, nil
	}

@@ -312,12 +307,7 @@ func (sc *stateContext) IsActive(ctx context.Context) (active bool, err error) {
 }

 func (sc *stateContext) AwaitResumed(ctx context.Context) (err error) {
-	closed := func() bool {
-		sc.mu.RLock()
-		defer sc.mu.RUnlock()
-		return sc.closed
-	}()
-	if closed {
+	if sc.isClosed() {
		return ErrClosed
	}

@@ -330,12 +320,7 @@ func (sc *stateContext) AwaitResumed(ctx context.Context) (err error) {
 }

 func (sc *stateContext) AwaitPaused(ctx context.Context) (err error) {
-	closed := func() bool {
-		sc.mu.RLock()
-		defer sc.mu.RUnlock()
-		return sc.closed
-	}()
-	if closed {
+	if sc.isClosed() {
		return ErrClosed
	}

@@ -347,6 +332,12 @@ func (sc *stateContext) AwaitPaused(ctx context.Context) (err error) {
	}
 }

+func (sc *stateContext) isClosed() bool {
+	sc.mu.RLock()
+	defer sc.mu.RUnlock()
+	return sc.closed
+}
+
 // close closes all active contexts
 // in order to prevent dangling goroutines
 // When closing you may want to use pause first and then close for the final cleanup
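
The four identical inline closures collapse into the single isClosed accessor added above. A standalone sketch of that read-locked accessor pattern, with illustrative names rather than the pool's real stateContext:

package main

import (
	"fmt"
	"sync"
)

type state struct {
	mu     sync.RWMutex
	closed bool
}

// isClosed only takes the read lock, so concurrent readers never
// block one another; writers such as close() still serialize.
func (s *state) isClosed() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.closed
}

func (s *state) close() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.closed = true
}

func main() {
	s := &state{}
	fmt.Println(s.isClosed()) // false
	s.close()
	fmt.Println(s.isClosed()) // true
}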
3 changes: 3 additions & 0 deletions pool/session.go

@@ -31,6 +31,9 @@ type Session struct {

	consumers map[string]bool // saves consumer names in order to cancel them upon session closure

+	// a session should not be used in a multithreaded context
+	// but only one session per goroutine. That is why we keep this
+	// as a Mutex and not a RWMutex.
	mu     sync.Mutex
	ctx    context.Context
	cancel context.CancelFunc
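
The new doc comment pins down the concurrency contract: each Session is owned by one goroutine at a time, so the mutex guards against accidental sharing rather than serving a read-mostly fast path, and a plain sync.Mutex suffices. A rough sketch of that ownership model, using a stand-in type rather than the library's API:

package main

import (
	"fmt"
	"sync"
)

// session stands in for a connection-scoped resource that is
// meant to be owned by a single goroutine at a time.
type session struct {
	mu    sync.Mutex // guards against accidental concurrent use only
	count int
}

func (s *session) publish(msg string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.count++
	fmt.Printf("published %q (#%d)\n", msg, s.count)
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			s := &session{} // one session per goroutine, never shared
			s.publish(fmt.Sprintf("message from worker %d", id))
		}(i)
	}
	wg.Wait()
}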
11 changes: 7 additions & 4 deletions pool/session_pool.go

@@ -147,7 +147,7 @@ func (sp *SessionPool) ReturnSession(session *Session, erred bool) {

	// don't ass non-managed sessions back to the channel
	if !session.IsCached() {
-		session.Close()
+		_ = session.Close()
		return
	}

@@ -168,7 +168,7 @@ func (sp *SessionPool) ReturnSession(session *Session, erred bool) {

	select {
	case <-sp.catchShutdown():
-		session.Close()
+		_ = session.Close()
	case sp.sessions <- session:
	}
 }

@@ -190,11 +190,14 @@ SessionClose:
	for {
		select {
		// flush sessions channel
-		case session := <-sp.sessions:
+		case session, ok := <-sp.sessions:
+			if !ok {
+				break SessionClose
+			}
			wg.Add(1)
			go func(*Session) {
				defer wg.Done()
-				session.Close()
+				_ = session.Close()
			}(session)

		default:
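
The comma-ok receive lets the shutdown loop tell an empty channel (the default case) apart from a closed one (ok is false); a plain receive on a closed channel would keep yielding zero-value sessions forever. A self-contained sketch of the drain pattern, with int standing in for *Session:

package main

import (
	"fmt"
	"sync"
)

func main() {
	sessions := make(chan int, 3)
	sessions <- 1
	sessions <- 2
	close(sessions) // buffered values can still be received after close

	var wg sync.WaitGroup
drain:
	for {
		select {
		case s, ok := <-sessions:
			if !ok {
				// channel closed and fully drained: without the ok
				// check this case would spin on zero values forever.
				break drain
			}
			wg.Add(1)
			go func(s int) {
				defer wg.Done()
				fmt.Println("closing session", s)
			}(s)
		default:
			// channel still open but currently empty: nothing to flush
			break drain
		}
	}
	wg.Wait()
}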
90 changes: 73 additions & 17 deletions pool/subscriber.go

@@ -163,10 +163,11 @@ func (s *Subscriber) RegisterBatchHandler(handler *BatchHandler) {

	s.log.WithFields(withConsumerIfSet(handler.ConsumeOptions().ConsumerTag,
		map[string]any{
-			"subscriber":   s.pool.Name(),
-			"queue":        opts.Queue,
-			"maxBatchSize": opts.MaxBatchSize, // TODO: optimize so that we don't call getters multiple times (mutex contention)
-			"flushTimeout": handler.FlushTimeout,
+			"subscriber":    s.pool.Name(),
+			"queue":         opts.Queue,
+			"maxBatchBytes": opts.MaxBatchBytes,
+			"maxBatchSize":  opts.MaxBatchSize,
+			"flushTimeout":  opts.FlushTimeout,
		})).Info("registered batch message handler")
 }

@@ -445,7 +446,14 @@ func (s *Subscriber) batchConsume(h *BatchHandler) (err error) {
			// There is no way to recover form this state in case an error is returned from the Nack call.
			nackErr := batch[len(batch)-1].Nack(true, true)
			if nackErr != nil {
-				s.warnBatchHandler(opts.ConsumerTag, opts.Queue, opts.MaxBatchSize, err, "failed to nack and requeue batch upon shutdown")
+				s.warnBatchHandler(
+					opts.ConsumerTag,
+					opts.Queue,
+					opts.MaxBatchSize,
+					opts.MaxBatchBytes,
+					err,
+					"failed to nack and requeue batch upon shutdown",
+				)
			}
		}
	}()

@@ -456,10 +464,12 @@
	)
	defer closeTimer(timer, &drained)

+	var batchBytes = 0
	for {
		// reset batch slice
		// reuse memory
		batch = batch[:0]
+		batchBytes = 0

	collectBatch:
		for {

@@ -474,8 +484,14 @@
				if !ok {
					return ErrDeliveryClosed
				}
+
+				batchBytes += len(msg.Body)
				batch = append(batch, msg)
-				if len(batch) == opts.MaxBatchSize {
+				if opts.MaxBatchSize > 0 && len(batch) == opts.MaxBatchSize {
+					break collectBatch
+				}
+
+				if opts.MaxBatchBytes > 0 && batchBytes >= opts.MaxBatchBytes {
					break collectBatch
				}

@@ -485,6 +501,13 @@
				if len(batch) > 0 {
					// timeout reached, process batch that might not contain
					// a full batch, yet.
+					s.infoBatchHandler(
+						opts.ConsumerTag,
+						opts.Queue,
+						len(batch),
+						batchBytes,
+						"flush timeout reached",
+					)
					break collectBatch
				}

@@ -498,26 +521,37 @@
			lastDeliveryTag = batch[len(batch)-1].DeliveryTag
		)

-		s.infoBatchHandler(opts.ConsumerTag, opts.Queue, batchSize, "received batch")
+		s.infoBatchHandler(opts.ConsumerTag, opts.Queue, batchSize, batchBytes, "received batch")
		err = opts.HandlerFunc(batch)
		// no acks required
		if opts.AutoAck {
			if err != nil {
				// we cannot really do anything to recover from a processing error in this case
-				s.errorBatchHandler(opts.ConsumerTag, opts.Queue, batchSize, fmt.Errorf("processing failed: dropping batch: %w", err))
+				s.errorBatchHandler(opts.ConsumerTag,
+					opts.Queue,
+					batchSize,
+					batchBytes,
+					fmt.Errorf("processing failed: dropping batch: %w", err),
+				)
			} else {
-				s.infoBatchHandler(opts.ConsumerTag, opts.Queue, batchSize, "processed batch")
+				s.infoBatchHandler(
+					opts.ConsumerTag,
+					opts.Queue,
+					batchSize,
+					batchBytes,
+					"processed batch",
+				)
			}
		} else {
-			poolErr := s.ackBatchPostHandle(opts, lastDeliveryTag, batchSize, session, err)
+			poolErr := s.ackBatchPostHandle(opts, lastDeliveryTag, batchSize, batchBytes, session, err)
			if poolErr != nil {
				return poolErr
			}
		}
	}
 }

-func (s *Subscriber) ackBatchPostHandle(opts BatchHandlerConfig, lastDeliveryTag uint64, currentBatchSize int, session *Session, handlerErr error) (err error) {
+func (s *Subscriber) ackBatchPostHandle(opts BatchHandlerConfig, lastDeliveryTag uint64, currentBatchSize, currentBatchBytes int, session *Session, handlerErr error) (err error) {
	var ackErr error
	// processing failed
	if handlerErr != nil {

@@ -531,7 +565,14 @@ func (s *Subscriber) ackBatchPostHandle(opts BatchHandlerConfig, lastDeliveryTag
	// if (n)ack fails, we know that the connection died
	// potentially before processing already.
	if ackErr != nil {
-		s.warnBatchHandler(opts.ConsumerTag, opts.Queue, currentBatchSize, ackErr, "batch (n)ack failed")
+		s.warnBatchHandler(
+			opts.ConsumerTag,
+			opts.Queue,
+			currentBatchSize,
+			currentBatchBytes,
+			ackErr,
+			"batch (n)ack failed",
+		)
		poolErr := session.Recover()
		if poolErr != nil {
			// only returns an error upon shutdown

@@ -545,9 +586,21 @@

	// (n)acked successfully
	if handlerErr != nil {
-		s.infoBatchHandler(opts.ConsumerTag, opts.Queue, currentBatchSize, "nacked batch")
+		s.infoBatchHandler(
+			opts.ConsumerTag,
+			opts.Queue,
+			currentBatchSize,
+			currentBatchBytes,
+			"nacked batch",
+		)
	} else {
-		s.infoBatchHandler(opts.ConsumerTag, opts.Queue, currentBatchSize, "acked batch")
+		s.infoBatchHandler(
+			opts.ConsumerTag,
+			opts.Queue,
+			currentBatchSize,
+			currentBatchBytes,
+			"acked batch",
+		)
	}
	// successfully handled message
	return nil

@@ -586,28 +639,31 @@ func (s *Subscriber) catchShutdown() <-chan struct{} {
	return s.ctx.Done()
 }

-func (s *Subscriber) infoBatchHandler(consumer, queue string, batchSize int, a ...any) {
+func (s *Subscriber) infoBatchHandler(consumer, queue string, batchSize, batchBytes int, a ...any) {
	s.log.WithFields(withConsumerIfSet(consumer,
		map[string]any{
			"batchSize":  batchSize,
+			"batchBytes": batchBytes,
			"subscriber": s.pool.Name(),
			"queue":      queue,
		})).Info(a...)
 }

-func (s *Subscriber) warnBatchHandler(consumer, queue string, batchSize int, err error, a ...any) {
+func (s *Subscriber) warnBatchHandler(consumer, queue string, batchSize, batchBytes int, err error, a ...any) {
	s.log.WithFields(withConsumerIfSet(consumer, map[string]any{
		"batchSize":  batchSize,
+		"batchBytes": batchBytes,
		"subscriber": s.pool.Name(),
		"queue":      queue,
		"error":      err,
	})).Warn(a...)
 }

-func (s *Subscriber) errorBatchHandler(consumer, queue string, batchSize int, err error, a ...any) {
+func (s *Subscriber) errorBatchHandler(consumer, queue string, batchSize, batchBytes int, err error, a ...any) {
	s.log.WithFields(withConsumerIfSet(consumer,
		map[string]any{
			"batchSize":  batchSize,
+			"batchBytes": batchBytes,
			"subscriber": s.pool.Name(),
			"queue":      queue,
			"error":      err,
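
Taken together, the subscriber changes make collectBatch flush a batch on any of three triggers: MaxBatchSize messages collected, MaxBatchBytes of accumulated body bytes reached, or the flush timer firing; a zero value disables the corresponding limit. A self-contained sketch of that dual-threshold loop, with a plain channel standing in for the AMQP delivery channel and the ack/recovery machinery omitted:

package main

import (
	"fmt"
	"time"
)

type delivery struct{ Body []byte }

// collect gathers deliveries until maxSize messages or maxBytes of
// accumulated body bytes are reached (a zero limit disables that
// check, mirroring the diff above) or the flush timeout expires.
func collect(in <-chan delivery, maxSize, maxBytes int, flush time.Duration) []delivery {
	var (
		batch      []delivery
		batchBytes int
	)
	timer := time.NewTimer(flush)
	defer timer.Stop()

	for {
		select {
		case msg, ok := <-in:
			if !ok {
				return batch // delivery channel closed
			}
			batchBytes += len(msg.Body)
			batch = append(batch, msg)
			if maxSize > 0 && len(batch) == maxSize {
				return batch // message-count threshold hit
			}
			if maxBytes > 0 && batchBytes >= maxBytes {
				return batch // byte-size threshold hit
			}
		case <-timer.C:
			return batch // flush timeout: possibly a partial batch
		}
	}
}

func main() {
	in := make(chan delivery, 10)
	for i := 0; i < 5; i++ {
		in <- delivery{Body: make([]byte, 300)}
	}
	// 1000 bytes trip the byte threshold after four 300-byte messages
	batch := collect(in, 50, 1000, time.Second)
	fmt.Println(len(batch), "messages collected")
}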
(4 of the 10 changed files are not shown.)
