diff --git a/cmd/root.go b/cmd/root.go index 73f007a6..011d0c3a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -31,29 +31,23 @@ const ( name = "crowdsec-firewall-bouncer" ) -func termHandler(sig os.Signal, backend *backend.BackendCTX) error { - if err := backend.ShutDown(); err != nil { - return err - } - return nil -} - func backendCleanup(backend *backend.BackendCTX) { + log.Info("Shutting down backend") if err := backend.ShutDown(); err != nil { log.Errorf("unable to shutdown backend: %s", err) } } -func HandleSignals(backend *backend.BackendCTX) { +func HandleSignals(ctx context.Context) error { signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, syscall.SIGTERM) - s := <-signalChan - if err := termHandler(s, backend); err != nil { - log.Fatalf("shutdown fail: %s", err) + select { + case <-signalChan: + return fmt.Errorf("received SIGTERM") + case <-ctx.Done(): + return ctx.Err() } - log.Infof("Shutting down firewall-bouncer service") - os.Exit(0) } func deleteDecisions(backend *backend.BackendCTX, decisions []*models.Decision, config *cfg.BouncerConfig) { @@ -119,7 +113,7 @@ func addDecisions(backend *backend.BackendCTX, decisions []*models.Decision, con } } -func Execute() { +func Execute() error { var err error configPath := flag.String("c", "", "path to crowdsec-firewall-bouncer.yaml") verbose := flag.Bool("v", false, "set verbose mode") @@ -136,17 +130,17 @@ func Execute() { log.Infof("crowdsec-firewall-bouncer %s", version.VersionStr()) if configPath == nil || *configPath == "" { - log.Fatalf("configuration file is required") + return fmt.Errorf("configuration file is required") } configBytes, err := cfg.MergedConfig(*configPath) if err != nil { - log.Fatalf("unable to read config file: %s", err) + return fmt.Errorf("unable to read config file: %w", err) } config, err := cfg.NewConfig(bytes.NewReader(configBytes)) if err != nil { - log.Fatalf("unable to load configuration: %s", err) + return fmt.Errorf("unable to load configuration: %w", err) } if *verbose { @@ -155,7 +149,7 @@ func Execute() { backend, err := backend.NewBackend(config) if err != nil { - log.Fatal(err) + return err } if *testConfig { @@ -164,21 +158,19 @@ func Execute() { } if err = backend.Init(); err != nil { - log.Fatal(err) + return err } - // No call to fatalf after this point + defer backendCleanup(backend) bouncer := &csbouncer.StreamBouncer{} err = bouncer.ConfigReader(bytes.NewReader(configBytes)) if err != nil { - log.Errorf("unable to configure bouncer: %s", err) - return + return fmt.Errorf("unable to configure bouncer: %w", err) } bouncer.UserAgent = fmt.Sprintf("%s/%s", name, version.VersionStr()) if err := bouncer.Init(); err != nil { - log.Error(err) - return + return err } if bouncer.InsecureSkipVerify != nil { @@ -213,7 +205,6 @@ func Execute() { for { select { case <-ctx.Done(): - log.Info("terminating bouncer process") return nil case decisions := <-bouncer.Stream: if decisions == nil { @@ -230,10 +221,14 @@ func Execute() { if !sent && err != nil { log.Errorf("Failed to notify: %v", err) } - go HandleSignals(backend) + g.Go(func() error { + return HandleSignals(ctx) + }) } if err := g.Wait(); err != nil { - log.Errorf("process return with error: %s", err) + return fmt.Errorf("process terminated with error: %w", err) } + + return nil } diff --git a/main.go b/main.go index c648ee0f..c5fdf35c 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,14 @@ package main -import "github.com/crowdsecurity/cs-firewall-bouncer/cmd" +import ( + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/cs-firewall-bouncer/cmd" +) func main() { - cmd.Execute() + err := cmd.Execute() + if err != nil { + log.Fatal(err) + } }