From c6d8d523bc2bedf4c3da82190225fc8e51090ccd Mon Sep 17 00:00:00 2001 From: Aitor Perez Cedres Date: Tue, 7 May 2024 13:09:36 +0100 Subject: [PATCH] Fix data races in the example Some users rely on this example as a starting point to their applications. This commit fixes a data race that could cause issues in any code that relied on the example as base. Related to #72 Signed-off-by: Aitor Perez Cedres --- example_client_test.go | 65 +++++++++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/example_client_test.go b/example_client_test.go index 34c7fc3..a93c509 100644 --- a/example_client_test.go +++ b/example_client_test.go @@ -8,9 +8,9 @@ package amqp091_test import ( "context" "errors" - "fmt" "log" "os" + "sync" "time" amqp "github.com/rabbitmq/amqp091-go" @@ -23,9 +23,10 @@ import ( // It doesn't automatically ack each message, but leaves that // to the parent process, since it is usage-dependent. // -// Try running this in one terminal, and `rabbitmq-server` in another. +// Try running this in one terminal, and rabbitmq-server in another. +// // Stop & restart RabbitMQ to see how the queue reacts. -func Example() { +func Example_publish() { queueName := "job_queue" addr := "amqp://guest:guest@localhost:5672/" queue := New(queueName, addr) @@ -39,12 +40,14 @@ loop: // Attempt to push a message every 2 seconds case <-time.After(time.Second * 2): if err := queue.Push(message); err != nil { - fmt.Printf("Push failed: %s\n", err) + log.Printf("Push failed: %s\n", err) } else { - fmt.Println("Push succeeded!") + log.Println("Push succeeded!") } case <-ctx.Done(): - queue.Close() + if err := queue.Close(); err != nil { + log.Printf("Close failed: %s\n", err) + } break loop } } @@ -55,7 +58,7 @@ func Example_consume() { addr := "amqp://guest:guest@localhost:5672/" queue := New(queueName, addr) - // Give the connection sometime to setup + // Give the connection sometime to set up <-time.After(time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) @@ -63,7 +66,7 @@ func Example_consume() { deliveries, err := queue.Consume() if err != nil { - fmt.Printf("Could not start consuming: %s\n", err) + log.Printf("Could not start consuming: %s\n", err) return } @@ -78,19 +81,22 @@ func Example_consume() { for { select { case <-ctx.Done(): - queue.Close() + err := queue.Close() + if err != nil { + log.Printf("Close failed: %s\n", err) + } return case amqErr := <-chClosedCh: // This case handles the event of closed channel e.g. abnormal shutdown - fmt.Printf("AMQP Channel closed due to: %s\n", amqErr) + log.Printf("AMQP Channel closed due to: %s\n", amqErr) deliveries, err = queue.Consume() if err != nil { // If the AMQP channel is not ready, it will continue the loop. Next // iteration will enter this case because chClosedCh is closed by the // library - fmt.Println("Error trying to consume, will try again") + log.Println("Error trying to consume, will try again") continue } @@ -101,16 +107,21 @@ func Example_consume() { case delivery := <-deliveries: // Ack a message every 2 seconds - fmt.Printf("Received message: %s\n", delivery.Body) + log.Printf("Received message: %s\n", delivery.Body) if err := delivery.Ack(false); err != nil { - fmt.Printf("Error acknowledging message: %s\n", err) + log.Printf("Error acknowledging message: %s\n", err) } <-time.After(time.Second * 2) } } } +// Client is the base struct for handling connection recovery, consumption and +// publishing. Note that this struct has an internal mutex to safeguard against +// data races. As you develop and iterate over this example, you may need to add +// further locks, or safeguards, to keep your application safe from data races type Client struct { + m *sync.Mutex queueName string logger *log.Logger connection *amqp.Connection @@ -143,6 +154,7 @@ var ( // attempts to connect to the server. func New(queueName, addr string) *Client { client := Client{ + m: &sync.Mutex{}, logger: log.New(os.Stdout, "", log.LstdFlags), queueName: queueName, done: make(chan bool), @@ -155,7 +167,10 @@ func New(queueName, addr string) *Client { // notifyConnClose, and then continuously attempt to reconnect. func (client *Client) handleReconnect(addr string) { for { + client.m.Lock() client.isReady = false + client.m.Unlock() + client.logger.Println("Attempting to connect") conn, err := client.connect(addr) @@ -194,7 +209,9 @@ func (client *Client) connect(addr string) (*amqp.Connection, error) { // and then continuously attempt to re-initialize both channels func (client *Client) handleReInit(conn *amqp.Connection) bool { for { + client.m.Lock() client.isReady = false + client.m.Unlock() err := client.init(conn) @@ -251,7 +268,9 @@ func (client *Client) init(conn *amqp.Connection) error { } client.changeChannel(ch) + client.m.Lock() client.isReady = true + client.m.Unlock() client.logger.Println("Setup!") return nil @@ -275,13 +294,16 @@ func (client *Client) changeChannel(channel *amqp.Channel) { client.channel.NotifyPublish(client.notifyConfirm) } -// Push will push data onto the queue, and wait for a confirm. -// This will block until the server sends a confirm. Errors are +// Push will push data onto the queue, and wait for a confirmation. +// This will block until the server sends a confirmation. Errors are // only returned if the push action itself fails, see UnsafePush. func (client *Client) Push(data []byte) error { + client.m.Lock() if !client.isReady { + client.m.Unlock() return errors.New("failed to push: not connected") } + client.m.Unlock() for { err := client.UnsafePush(data) if err != nil { @@ -306,9 +328,12 @@ func (client *Client) Push(data []byte) error { // No guarantees are provided for whether the server will // receive the message. func (client *Client) UnsafePush(data []byte) error { + client.m.Lock() if !client.isReady { + client.m.Unlock() return errNotConnected } + client.m.Unlock() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -331,13 +356,16 @@ func (client *Client) UnsafePush(data []byte) error { // successfully processed, or delivery.Nack when it fails. // Ignoring this will cause data to build up on the server. func (client *Client) Consume() (<-chan amqp.Delivery, error) { + client.m.Lock() if !client.isReady { + client.m.Unlock() return nil, errNotConnected } + client.m.Unlock() if err := client.channel.Qos( 1, // prefetchCount - 0, // prefrechSize + 0, // prefetchSize false, // global ); err != nil { return nil, err @@ -356,6 +384,11 @@ func (client *Client) Consume() (<-chan amqp.Delivery, error) { // Close will cleanly shut down the channel and connection. func (client *Client) Close() error { + client.m.Lock() + // we read and write isReady in two locations, so we grab the lock and hold onto + // it until we are finished + defer client.m.Unlock() + if !client.isReady { return errAlreadyClosed }