diff --git a/cmd/serve.go b/cmd/serve.go index 883e1c5..a0a0bff 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -10,10 +10,6 @@ import ( "norsky/firehose" "norsky/models" "norsky/server" - "os" - "os/signal" - "sync" - "time" log "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" @@ -85,7 +81,6 @@ func serveCmd() *cli.Command { broadcaster := server.NewBroadcaster() // SSE broadcaster dbReader := db.NewReader(database) - seq, err := dbReader.GetSequence() if err != nil { @@ -99,26 +94,6 @@ func serveCmd() *cli.Command { Broadcaster: broadcaster, }) - fh := firehose.New(postChan, ctx.Context, seq) - - dbwriter := db.NewWriter(database, dbPostChan) - - // Graceful shutdown via wait group - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - var wg sync.WaitGroup - - // Graceful shutdown logic - go func() { - <-c - fmt.Println("Gracefully shutting down...") - app.ShutdownWithTimeout(5 * time.Second) // Wait 5 seconds for all connections to close - fh.Shutdown() - broadcaster.Shutdown() - defer wg.Add(-4) // Decrement the waitgroup counter by 4 after shutdown of all processes - - }() - // Some glue code to pass posts from the firehose to the database and/or broadcaster // Ideally one might want to do this in a more elegant way // TODO: Move broadcaster into server package, i.e. make server a receiver and handle broadcaster and fiber together @@ -136,7 +111,7 @@ func serveCmd() *cli.Command { go func() { fmt.Println("Subscribing to firehose...") - fh.Subscribe() + firehose.Subscribe(ctx.Context, postChan, seq) }() go func() { @@ -144,22 +119,25 @@ func serveCmd() *cli.Command { if err := app.Listen(fmt.Sprintf("%s:%d", host, port)); err != nil { log.Error(err) - c <- os.Interrupt } }() go func() { fmt.Println("Starting database writer...") - dbwriter.Subscribe() + db.Subscribe(ctx.Context, database, dbPostChan) }() - // Wait for both the server and firehose to shutdown - wg.Add(4) - wg.Wait() - - log.Info("Norsky feed generator stopped") + // Wait for SIGINT (Ctrl+C) or SIGTERM (docker stop) to stop the server - return nil + select { + case <-ctx.Context.Done(): + log.Info("Stopping server") + if err := app.ShutdownWithContext(ctx.Context); err != nil { + log.Error(err) + } + log.Info("Norsky feed generator stopped") + return nil + } }, } } diff --git a/cmd/subscribe.go b/cmd/subscribe.go index e592f45..9f88d07 100644 --- a/cmd/subscribe.go +++ b/cmd/subscribe.go @@ -4,14 +4,11 @@ Copyright © 2023 NAME HERE package cmd import ( - "context" "encoding/json" "fmt" "norsky/firehose" "norsky/models" "os" - "os/signal" - "sync" log "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" @@ -34,7 +31,6 @@ the output. Prints all other log messages to stderr.`, Action: func(ctx *cli.Context) error { // Get the context for this process to pass to firehose - context := context.Background() // Disable logging to stdout log.SetOutput(os.Stderr) @@ -42,43 +38,32 @@ Prints all other log messages to stderr.`, // Channel for subscribing to bluesky posts postChan := make(chan interface{}) - // Setup the server and firehose - fh := firehose.New(postChan, context, -1) - - // Graceful shutdown - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - var wg sync.WaitGroup - - go func() { - <-c - defer wg.Add(-1) // Decrement the waitgroup counter by 2 after shutdown of server and firehose - fh.Shutdown() - }() - go func() { fmt.Println("Subscribing to firehose...") - fh.Subscribe() + firehose.Subscribe(ctx.Context, postChan, -1) }() go func() { // Subscribe to the post channel and log the posts + // Stop if the context is cancelled for message := range postChan { - switch message := message.(type) { - case models.CreatePostEvent: - printStdout(&message.Post) - case models.UpdatePostEvent: - printStdout(&message.Post) - case models.DeletePostEvent: - printStdout(&message.Post) + select { + case <-ctx.Context.Done(): + fmt.Println("Stopping subscription") + return + default: + switch message := message.(type) { + case models.CreatePostEvent: + printStdout(&message.Post) + case models.UpdatePostEvent: + printStdout(&message.Post) + case models.DeletePostEvent: + printStdout(&message.Post) + } } } }() - // Wait for both the server and firehose to shutdown - wg.Add(1) - wg.Wait() - return nil }, } diff --git a/db/writer.go b/db/writer.go index 46f0146..f1224e6 100644 --- a/db/writer.go +++ b/db/writer.go @@ -1,6 +1,7 @@ package db import ( + "context" "norsky/models" "time" @@ -10,47 +11,39 @@ import ( log "github.com/sirupsen/logrus" ) -type Writer struct { - db *sql.DB - postChan chan interface{} - tidyChan *time.Ticker -} +func Subscribe(ctx context.Context, database string, postChan chan interface{}) { + tidyChan := time.NewTicker(5 * time.Minute) -func NewWriter(database string, postChan chan interface{}) *Writer { db, err := connection(database) if err != nil { - panic("failed to connect database") - } - return &Writer{ - db: db, - postChan: postChan, - // Create new tidy channel that is pinged every 5 minutes - tidyChan: time.NewTicker(5 * time.Minute), + log.Error("Error connecting to database", err) + ctx.Done() } -} -func (writer *Writer) Subscribe() { // Tidy database immediately - if err := tidy(writer.db); err != nil { + if err := tidy(db); err != nil { log.Error("Error tidying database", err) } for { select { - case <-writer.tidyChan.C: + case <-ctx.Done(): + log.Info("Stopping database writer") + return + case <-tidyChan.C: log.Info("Tidying database") - if err := tidy(writer.db); err != nil { + if err := tidy(db); err != nil { log.Error("Error tidying database", err) } - case post := <-writer.postChan: + case post := <-postChan: switch event := post.(type) { case models.ProcessSeqEvent: - processSeq(writer.db, event) + processSeq(db, event) case models.CreatePostEvent: - createPost(writer.db, event.Post) + createPost(db, event.Post) case models.DeletePostEvent: - deletePost(writer.db, event.Post) + deletePost(db, event.Post) default: log.Info("Unknown post type") } @@ -60,7 +53,6 @@ func (writer *Writer) Subscribe() { } func processSeq(db *sql.DB, evt models.ProcessSeqEvent) error { - log.Info("Processing sequence") // Update sequence row with new seq number updateSeq := sqlbuilder.NewUpdateBuilder() sql, args := updateSeq.Update("sequence").Set(updateSeq.Assign("seq", evt.Seq)).Where(updateSeq.Equal("id", 0)).Build() diff --git a/firehose/firehose.go b/firehose/firehose.go index 3a18122..6ebe4b2 100644 --- a/firehose/firehose.go +++ b/firehose/firehose.go @@ -22,70 +22,44 @@ import ( log "github.com/sirupsen/logrus" ) -// Add a firehose model to use as a receiver pattern for the firehose - -type Firehose struct { - address string // The address of the firehose - dialer *websocket.Dialer // The websocket dialer to use for the firehose - conn *websocket.Conn // The websocket connection to the firehose - scheduler *sequential.Scheduler // The scheduler to use for the firehose - // A channel to write feed posts to - postChan chan interface{} - // The context for this process - context context.Context -} - -func New(postChan chan interface{}, context context.Context, seq int64) *Firehose { +// Subscribe to the firehose using the Firehose struct as a receiver +func Subscribe(ctx context.Context, postChan chan interface{}, seq int64) { address := "wss://bsky.network/xrpc/com.atproto.sync.subscribeRepos" if seq >= 0 { log.Info("Starting from sequence: ", seq) address = fmt.Sprintf("%s?cursor=%d", address, seq) } dialer := websocket.DefaultDialer - firehose := &Firehose{ - address: address, - dialer: dialer, - postChan: postChan, - context: context, - } - - return firehose -} - -// Subscribe to the firehose using the Firehose struct as a receiver -func (firehose *Firehose) Subscribe() { - backoff := backoff.NewExponentialBackOff() + // Check if context is cancelled, if so exit the connection loop for { - conn, _, err := firehose.dialer.Dial(firehose.address, nil) - if err != nil { - log.Errorf("Error connecting to firehose: %s", err) - time.Sleep(backoff.NextBackOff()) - // Increase backoff by factor of 1.3, rounded to nearest whole number - continue - } + select { + case <-ctx.Done(): + log.Info("Stopping firehose connect loop") + return + default: + conn, _, err := dialer.Dial(address, nil) + if err != nil { + log.Errorf("Error connecting to firehose: %s", err) + time.Sleep(backoff.NextBackOff()) + // Increase backoff by factor of 1.3, rounded to nearest whole number + continue + } - firehose.conn = conn - firehose.scheduler = sequential.NewScheduler(conn.RemoteAddr().String(), eventProcessor(firehose.postChan, firehose.context).EventHandler) - err = events.HandleRepoStream(context.Background(), conn, firehose.scheduler) + scheduler := sequential.NewScheduler(conn.RemoteAddr().String(), eventProcessor(postChan, ctx).EventHandler) + err = events.HandleRepoStream(ctx, conn, scheduler) - // If error sleep - if err != nil { - log.Errorf("Error handling repo stream: %s", err) - time.Sleep(backoff.NextBackOff()) - continue + // If error sleep + if err != nil { + log.Errorf("Error handling repo stream: %s", err) + time.Sleep(backoff.NextBackOff()) + continue + } } } } -func (firehose *Firehose) Shutdown() { - // TODO: Graceful shutdown here as "Error handling repo stream: read tcp use of closed network connection " - firehose.scheduler.Shutdown() - firehose.conn.Close() - log.Info("Firehose shutdown") -} - func eventProcessor(postChan chan interface{}, context context.Context) *events.RepoStreamCallbacks { streamCallbacks := &events.RepoStreamCallbacks{ RepoCommit: func(evt *atproto.SyncSubscribeRepos_Commit) error { diff --git a/main.go b/main.go index 5d07070..7607ee4 100644 --- a/main.go +++ b/main.go @@ -1,16 +1,31 @@ package main import ( + "context" "fmt" "norsky/cmd" "os" + "os/signal" + "syscall" _ "golang.org/x/crypto/x509roots/fallback" ) func main() { + // Check if a signal interrupts the process and if so call Done on the context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Listen for interrupt signals + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + cancel() + }() + app := cmd.RootApp() - if err := app.Run(os.Args); err != nil { + if err := app.RunContext(ctx, os.Args); err != nil { fmt.Println(err) os.Exit(1) }