From ed952626e6b9e92f87076e02096dacc6b23bae56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Snorre=20Magnus=20Dav=C3=B8en?= Date: Wed, 8 Nov 2023 21:24:02 +0100 Subject: [PATCH] fix: Better context handling and graceful shutdown (#9) Use Go contexts to handle shutting down goroutines and closing channels. This largely makes the receiver pattern with the subscriber and firehose structs unnecessary. Instead we pass the necessary arguments to the functions together with the context. The functions listen for the context to signal the process should close, so no separate shutdown function is necessary. Fiber is handled as before. The main.go main function now sets up a cancelable context and passes this to the urfave cli run command. A goroutine listens for interrupts and, on receiving one, gracefully signals shutdown by calling cancel on the context. --- cmd/serve.go | 46 ++++++++-------------------- cmd/subscribe.go | 45 +++++++++------------------ db/writer.go | 38 +++++++++-------------- firehose/firehose.go | 72 ++++++++++++++------------------------------ main.go | 17 ++++++++++- 5 files changed, 81 insertions(+), 137 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 883e1c5..a0a0bff 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -10,10 +10,6 @@ import ( "norsky/firehose" "norsky/models" "norsky/server" - "os" - "os/signal" - "sync" - "time" log "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" @@ -85,7 +81,6 @@ func serveCmd() *cli.Command { broadcaster := server.NewBroadcaster() // SSE broadcaster dbReader := db.NewReader(database) - seq, err := dbReader.GetSequence() if err != nil { @@ -99,26 +94,6 @@ func serveCmd() *cli.Command { Broadcaster: broadcaster, }) - fh := firehose.New(postChan, ctx.Context, seq) - - dbwriter := db.NewWriter(database, dbPostChan) - - // Graceful shutdown via wait group - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - var wg sync.WaitGroup - - // Graceful shutdown logic - go func() { - <-c - 
fmt.Println("Gracefully shutting down...") - app.ShutdownWithTimeout(5 * time.Second) // Wait 5 seconds for all connections to close - fh.Shutdown() - broadcaster.Shutdown() - defer wg.Add(-4) // Decrement the waitgroup counter by 4 after shutdown of all processes - - }() - // Some glue code to pass posts from the firehose to the database and/or broadcaster // Ideally one might want to do this in a more elegant way // TODO: Move broadcaster into server package, i.e. make server a receiver and handle broadcaster and fiber together @@ -136,7 +111,7 @@ func serveCmd() *cli.Command { go func() { fmt.Println("Subscribing to firehose...") - fh.Subscribe() + firehose.Subscribe(ctx.Context, postChan, seq) }() go func() { @@ -144,22 +119,25 @@ func serveCmd() *cli.Command { if err := app.Listen(fmt.Sprintf("%s:%d", host, port)); err != nil { log.Error(err) - c <- os.Interrupt } }() go func() { fmt.Println("Starting database writer...") - dbwriter.Subscribe() + db.Subscribe(ctx.Context, database, dbPostChan) }() - // Wait for both the server and firehose to shutdown - wg.Add(4) - wg.Wait() - - log.Info("Norsky feed generator stopped") + // Wait for SIGINT (Ctrl+C) or SIGTERM (docker stop) to stop the server - return nil + select { + case <-ctx.Context.Done(): + log.Info("Stopping server") + if err := app.ShutdownWithContext(ctx.Context); err != nil { + log.Error(err) + } + log.Info("Norsky feed generator stopped") + return nil + } }, } } diff --git a/cmd/subscribe.go b/cmd/subscribe.go index e592f45..9f88d07 100644 --- a/cmd/subscribe.go +++ b/cmd/subscribe.go @@ -4,14 +4,11 @@ Copyright © 2023 NAME HERE package cmd import ( - "context" "encoding/json" "fmt" "norsky/firehose" "norsky/models" "os" - "os/signal" - "sync" log "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" @@ -34,7 +31,6 @@ the output. 
Prints all other log messages to stderr.`, Action: func(ctx *cli.Context) error { // Get the context for this process to pass to firehose - context := context.Background() // Disable logging to stdout log.SetOutput(os.Stderr) @@ -42,43 +38,32 @@ Prints all other log messages to stderr.`, // Channel for subscribing to bluesky posts postChan := make(chan interface{}) - // Setup the server and firehose - fh := firehose.New(postChan, context, -1) - - // Graceful shutdown - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - var wg sync.WaitGroup - - go func() { - <-c - defer wg.Add(-1) // Decrement the waitgroup counter by 2 after shutdown of server and firehose - fh.Shutdown() - }() - go func() { fmt.Println("Subscribing to firehose...") - fh.Subscribe() + firehose.Subscribe(ctx.Context, postChan, -1) }() go func() { // Subscribe to the post channel and log the posts + // Stop if the context is cancelled for message := range postChan { - switch message := message.(type) { - case models.CreatePostEvent: - printStdout(&message.Post) - case models.UpdatePostEvent: - printStdout(&message.Post) - case models.DeletePostEvent: - printStdout(&message.Post) + select { + case <-ctx.Context.Done(): + fmt.Println("Stopping subscription") + return + default: + switch message := message.(type) { + case models.CreatePostEvent: + printStdout(&message.Post) + case models.UpdatePostEvent: + printStdout(&message.Post) + case models.DeletePostEvent: + printStdout(&message.Post) + } } } }() - // Wait for both the server and firehose to shutdown - wg.Add(1) - wg.Wait() - return nil }, } diff --git a/db/writer.go b/db/writer.go index 46f0146..f1224e6 100644 --- a/db/writer.go +++ b/db/writer.go @@ -1,6 +1,7 @@ package db import ( + "context" "norsky/models" "time" @@ -10,47 +11,39 @@ import ( log "github.com/sirupsen/logrus" ) -type Writer struct { - db *sql.DB - postChan chan interface{} - tidyChan *time.Ticker -} +func Subscribe(ctx context.Context, database string, postChan 
chan interface{}) { + tidyChan := time.NewTicker(5 * time.Minute) -func NewWriter(database string, postChan chan interface{}) *Writer { db, err := connection(database) if err != nil { - panic("failed to connect database") - } - return &Writer{ - db: db, - postChan: postChan, - // Create new tidy channel that is pinged every 5 minutes - tidyChan: time.NewTicker(5 * time.Minute), + log.Error("Error connecting to database", err) + ctx.Done() } -} -func (writer *Writer) Subscribe() { // Tidy database immediately - if err := tidy(writer.db); err != nil { + if err := tidy(db); err != nil { log.Error("Error tidying database", err) } for { select { - case <-writer.tidyChan.C: + case <-ctx.Done(): + log.Info("Stopping database writer") + return + case <-tidyChan.C: log.Info("Tidying database") - if err := tidy(writer.db); err != nil { + if err := tidy(db); err != nil { log.Error("Error tidying database", err) } - case post := <-writer.postChan: + case post := <-postChan: switch event := post.(type) { case models.ProcessSeqEvent: - processSeq(writer.db, event) + processSeq(db, event) case models.CreatePostEvent: - createPost(writer.db, event.Post) + createPost(db, event.Post) case models.DeletePostEvent: - deletePost(writer.db, event.Post) + deletePost(db, event.Post) default: log.Info("Unknown post type") } @@ -60,7 +53,6 @@ func (writer *Writer) Subscribe() { } func processSeq(db *sql.DB, evt models.ProcessSeqEvent) error { - log.Info("Processing sequence") // Update sequence row with new seq number updateSeq := sqlbuilder.NewUpdateBuilder() sql, args := updateSeq.Update("sequence").Set(updateSeq.Assign("seq", evt.Seq)).Where(updateSeq.Equal("id", 0)).Build() diff --git a/firehose/firehose.go b/firehose/firehose.go index 3a18122..6ebe4b2 100644 --- a/firehose/firehose.go +++ b/firehose/firehose.go @@ -22,70 +22,44 @@ import ( log "github.com/sirupsen/logrus" ) -// Add a firehose model to use as a receiver pattern for the firehose - -type Firehose struct { - address string 
// The address of the firehose - dialer *websocket.Dialer // The websocket dialer to use for the firehose - conn *websocket.Conn // The websocket connection to the firehose - scheduler *sequential.Scheduler // The scheduler to use for the firehose - // A channel to write feed posts to - postChan chan interface{} - // The context for this process - context context.Context -} - -func New(postChan chan interface{}, context context.Context, seq int64) *Firehose { +// Subscribe to the firehose using the Firehose struct as a receiver +func Subscribe(ctx context.Context, postChan chan interface{}, seq int64) { address := "wss://bsky.network/xrpc/com.atproto.sync.subscribeRepos" if seq >= 0 { log.Info("Starting from sequence: ", seq) address = fmt.Sprintf("%s?cursor=%d", address, seq) } dialer := websocket.DefaultDialer - firehose := &Firehose{ - address: address, - dialer: dialer, - postChan: postChan, - context: context, - } - - return firehose -} - -// Subscribe to the firehose using the Firehose struct as a receiver -func (firehose *Firehose) Subscribe() { - backoff := backoff.NewExponentialBackOff() + // Check if context is cancelled, if so exit the connection loop for { - conn, _, err := firehose.dialer.Dial(firehose.address, nil) - if err != nil { - log.Errorf("Error connecting to firehose: %s", err) - time.Sleep(backoff.NextBackOff()) - // Increase backoff by factor of 1.3, rounded to nearest whole number - continue - } + select { + case <-ctx.Done(): + log.Info("Stopping firehose connect loop") + return + default: + conn, _, err := dialer.Dial(address, nil) + if err != nil { + log.Errorf("Error connecting to firehose: %s", err) + time.Sleep(backoff.NextBackOff()) + // Increase backoff by factor of 1.3, rounded to nearest whole number + continue + } - firehose.conn = conn - firehose.scheduler = sequential.NewScheduler(conn.RemoteAddr().String(), eventProcessor(firehose.postChan, firehose.context).EventHandler) - err = events.HandleRepoStream(context.Background(), 
conn, firehose.scheduler) + scheduler := sequential.NewScheduler(conn.RemoteAddr().String(), eventProcessor(postChan, ctx).EventHandler) + err = events.HandleRepoStream(ctx, conn, scheduler) - // If error sleep - if err != nil { - log.Errorf("Error handling repo stream: %s", err) - time.Sleep(backoff.NextBackOff()) - continue + // If error sleep + if err != nil { + log.Errorf("Error handling repo stream: %s", err) + time.Sleep(backoff.NextBackOff()) + continue + } } } } -func (firehose *Firehose) Shutdown() { - // TODO: Graceful shutdown here as "Error handling repo stream: read tcp use of closed network connection " - firehose.scheduler.Shutdown() - firehose.conn.Close() - log.Info("Firehose shutdown") -} - func eventProcessor(postChan chan interface{}, context context.Context) *events.RepoStreamCallbacks { streamCallbacks := &events.RepoStreamCallbacks{ RepoCommit: func(evt *atproto.SyncSubscribeRepos_Commit) error { diff --git a/main.go b/main.go index 5d07070..7607ee4 100644 --- a/main.go +++ b/main.go @@ -1,16 +1,31 @@ package main import ( + "context" "fmt" "norsky/cmd" "os" + "os/signal" + "syscall" _ "golang.org/x/crypto/x509roots/fallback" ) func main() { + // Check if a signal interrupts the process and if so call Done on the context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Listen for interrupt signals + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + cancel() + }() + app := cmd.RootApp() - if err := app.Run(os.Args); err != nil { + if err := app.RunContext(ctx, os.Args); err != nil { fmt.Println(err) os.Exit(1) }