diff --git a/crowdsec/caddyfile.go b/crowdsec/caddyfile.go index 472079b8..53d34aaf 100644 --- a/crowdsec/caddyfile.go +++ b/crowdsec/caddyfile.go @@ -1,27 +1,33 @@ package crowdsec import ( + "fmt" "net/url" + "strings" "time" + "github.com/caddyserver/caddy/v2/caddyconfig" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" ) -func parseCaddyfileGlobalOption(d *caddyfile.Dispenser, existingVal interface{}) (interface{}, error) { - - // TODO: make this work similar to the handler? Or doesn't that work for this - // app level module, because of shared config etc. - - cfg = &config{ - TickerInterval: defaultTickerInterval, - EnableStreaming: defaultStreamingEnabled, - EnableHardFails: defaultHardFailsEnabled, +func parseCrowdSec(d *caddyfile.Dispenser, existingVal any) (any, error) { + tv := true + fv := false + cs := &CrowdSec{ + TickerInterval: "60s", + EnableStreaming: &tv, + EnableHardFails: &fv, } if !d.Next() { return nil, d.Err("expected tokens") } + if d.Val() != "crowdsec" { + return nil, d.Err(fmt.Sprintf(`expected "crowdsec"; got %q`, d.Val())) + } + for d.NextBlock(0) { switch d.Val() { case "api_url": @@ -32,12 +38,19 @@ func parseCaddyfileGlobalOption(d *caddyfile.Dispenser, existingVal interface{}) if err != nil { return nil, d.Errf("invalid URL %s: %v", d.Val(), err) } - cfg.APIUrl = u.String() + if u.Scheme == "" { + return nil, d.Errf("URL %q does not have a scheme (i.e https)", u.String()) + } + s := u.String() + if !strings.HasSuffix(s, "/") { + s = s + "/" + } + cs.APIUrl = s case "api_key": if !d.NextArg() { return nil, d.ArgErr() } - cfg.APIKey = d.Val() + cs.APIKey = d.Val() case "ticker_interval": if !d.NextArg() { return nil, d.ArgErr() @@ -46,21 +59,24 @@ func parseCaddyfileGlobalOption(d *caddyfile.Dispenser, existingVal interface{}) if err != nil { return nil, d.Errf("invalid duration %s: %v", d.Val(), err) } - cfg.TickerInterval = interval.String() + cs.TickerInterval = interval.String() case "disable_streaming": if d.NextArg() { return nil, d.ArgErr() } - cfg.EnableStreaming = false + cs.EnableStreaming = &fv case "enable_hard_fails": if d.NextArg() { return nil, d.ArgErr() } - cfg.EnableHardFails = true + cs.EnableHardFails = &tv default: return nil, d.Errf("invalid configuration token provided: %s", d.Val()) } } - return nil, nil + return httpcaddyfile.App{ + Name: "crowdsec", + Value: caddyconfig.JSON(cs, nil), + }, nil } diff --git a/crowdsec/caddyfile_test.go b/crowdsec/caddyfile_test.go index d6ef49f3..482ffdd8 100644 --- a/crowdsec/caddyfile_test.go +++ b/crowdsec/caddyfile_test.go @@ -1,47 +1,144 @@ package crowdsec import ( + "encoding/json" "testing" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshalCaddyfile(t *testing.T) { - trueValue := true - falseValue := false - type args struct { - d *caddyfile.Dispenser - } + tv := true + fv := false tests := []struct { - name string - expected *CrowdSec - args args - wantParseErr bool - wantConfigureErr bool + name string + input string + env map[string]string + expected *CrowdSec + wantParseErr bool }{ { - name: "fail/no-args", + name: "fail/missing tokens", + expected: &CrowdSec{}, + input: ``, + wantParseErr: true, + }, + { + name: "fail/not-crowdsec", + expected: &CrowdSec{}, + input: `not-crowdsec`, + wantParseErr: true, + }, + { + name: "fail/invalid-duration", expected: &CrowdSec{}, - args: args{ - d: caddyfile.NewTestDispenser(`crowdsec`), - }, - wantParseErr: false, - wantConfigureErr: true, + input: `crowdsec { + api_url http://127.0.0.1:8080 + api_key some_random_key + ticker_interval 30x + }`, + wantParseErr: true, + }, + { + name: "fail/no-api-url", + expected: &CrowdSec{}, + input: ` + crowdsec { + api_url + api_key some_random_key + ticker_interval 30x + }`, + wantParseErr: true, + }, + { + name: "fail/invalid-api-url", + expected: &CrowdSec{}, + input: `crowdsec { + api_url http://\x00/ + api_key some_random_key + ticker_interval 30x + }`, + wantParseErr: true, + }, + { + name: "fail/invalid-api-url-no-scheme", + expected: &CrowdSec{}, + input: `crowdsec { + api_url example.com + api_key some_random_key + ticker_interval 30x + }`, + wantParseErr: true, + }, + { + name: "fail/missing-api-key", + expected: &CrowdSec{}, + input: `crowdsec { + api_url http://127.0.0.1:8080 + api_key + }`, + wantParseErr: true, + }, + { + name: "fail/missing-ticker-interval", + expected: &CrowdSec{}, + input: `crowdsec { + api_url http://127.0.0.1:8080 + api_key test-key + ticker_interval + }`, + wantParseErr: true, + }, + { + name: "fail/invalid-streaming", + expected: &CrowdSec{}, + input: `crowdsec { + api_url http://127.0.0.1:8080 + api_key test-key + ticker_interval 30s + disable_streaming absolutely + }`, + wantParseErr: true, + }, + { + name: "fail/invalid-streaming", + expected: &CrowdSec{}, + input: `crowdsec { + api_url http://127.0.0.1:8080 + api_key test-key + ticker_interval 30s + disable_streaming + enable_hard_fails yo + }`, + wantParseErr: true, + }, + { + name: "fail/unknown-token", + expected: &CrowdSec{}, + input: `crowdsec { + api_url http://127.0.0.1:8080 + api_key some_random_key + unknown_token 42 + }`, + wantParseErr: true, }, { name: "ok/basic", expected: &CrowdSec{ - APIUrl: "http://127.0.0.1:8080/", - APIKey: "some_random_key", + APIUrl: "http://127.0.0.1:8080/", + APIKey: "some_random_key", + TickerInterval: "60s", + EnableStreaming: &tv, + EnableHardFails: &fv, }, - args: args{ - d: caddyfile.NewTestDispenser(`crowdsec { + input: `crowdsec { api_url http://127.0.0.1:8080 api_key some_random_key - }`), - }, - wantParseErr: false, - wantConfigureErr: false, + }`, + wantParseErr: false, }, { name: "ok/full", @@ -49,85 +146,66 @@ func TestUnmarshalCaddyfile(t *testing.T) { APIUrl: "http://127.0.0.1:8080/", APIKey: "some_random_key", TickerInterval: "33s", - EnableStreaming: &falseValue, - EnableHardFails: &trueValue, + EnableStreaming: &fv, + EnableHardFails: &tv, }, - args: args{ - d: caddyfile.NewTestDispenser(`crowdsec { + input: `crowdsec { api_url http://127.0.0.1:8080 api_key some_random_key ticker_interval 33s disable_streaming enable_hard_fails - }`), - }, - wantParseErr: false, - wantConfigureErr: false, + }`, + wantParseErr: false, }, { - name: "fail/invalid-duration", - expected: &CrowdSec{}, - args: args{ - d: caddyfile.NewTestDispenser(`crowdsec { - api_url http://127.0.0.1:8080 - api_key some_random_key - ticker_interval 30x - }`), + name: "ok/env-vars", + expected: &CrowdSec{ + APIUrl: "http://127.0.0.2:8080/", + APIKey: "env-test-key", + TickerInterval: "25s", + EnableStreaming: &tv, + EnableHardFails: &fv, }, - wantParseErr: true, - wantConfigureErr: false, - }, - { - name: "fail/unknown-token", - expected: &CrowdSec{}, - args: args{ - d: caddyfile.NewTestDispenser(`crowdsec { - api_url http://127.0.0.1:8080 - api_key some_random_key - unknown_token 42 - }`), + env: map[string]string{ + "CROWDSEC_TEST_API_URL": "http://127.0.0.2:8080/", + "CROWDSEC_TEST_API_KEY": "env-test-key", + "CROWDSEC_TEST_TICKER_INTERVAL": "25s", }, - wantParseErr: true, - wantConfigureErr: false, + input: `crowdsec { + api_url {$CROWDSEC_TEST_API_URL} + api_key {$CROWDSEC_TEST_API_KEY} + ticker_interval {$CROWDSEC_TEST_TICKER_INTERVAL} + }`, + wantParseErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &CrowdSec{} - if _, err := parseCaddyfileGlobalOption(tt.args.d, nil); (err != nil) != tt.wantParseErr { - t.Errorf("CrowdSec.parseCaddyfileGlobalOption() error = %v, wantParseErr %v", err, tt.wantParseErr) - return + for k, v := range tt.env { + t.Setenv(k, v) } - if err := c.configure(); (err != nil) != tt.wantConfigureErr { - t.Errorf("CrowdSec.configure) error = %v, wantConfigureErr %v", err, tt.wantConfigureErr) + dispenser := caddyfile.NewTestDispenser(tt.input) + jsonApp, err := parseCrowdSec(dispenser, nil) + if tt.wantParseErr { + assert.Error(t, err) return } - // TODO: properly use go-cmp and get unexported fields to work - if tt.expected.APIUrl != "" { - if tt.expected.APIUrl != c.APIUrl { - t.Errorf("got: %s, want: %s", c.APIUrl, tt.expected.APIUrl) - } - } - if tt.expected.APIKey != "" { - if tt.expected.APIKey != c.APIKey { - t.Errorf("got: %s, want: %s", c.APIKey, tt.expected.APIKey) - } - } - if tt.expected.TickerInterval != "" { - if tt.expected.TickerInterval != c.TickerInterval { - t.Errorf("got: %s, want: %s", c.TickerInterval, tt.expected.TickerInterval) - } - } - if tt.expected.EnableStreaming != nil { - if *tt.expected.EnableStreaming != *c.EnableStreaming { - t.Errorf("got: %t, want: %t", *c.EnableStreaming, *tt.expected.EnableStreaming) - } - } - if tt.expected.EnableHardFails != nil { - if *tt.expected.EnableHardFails != *c.EnableHardFails { - t.Errorf("got: %t, want: %t", *c.EnableHardFails, *tt.expected.EnableHardFails) - } - } + assert.NoError(t, err) + + app, ok := jsonApp.(httpcaddyfile.App) + require.True(t, ok) + assert.Equal(t, "crowdsec", app.Name) + + var c CrowdSec + err = json.Unmarshal(app.Value, &c) + require.NoError(t, err) + + assert.Equal(t, tt.expected.APIUrl, c.APIUrl) + assert.Equal(t, tt.expected.APIKey, c.APIKey) + assert.Equal(t, tt.expected.TickerInterval, c.TickerInterval) + assert.Equal(t, tt.expected.isStreamingEnabled(), c.isStreamingEnabled()) + assert.Equal(t, tt.expected.shouldFailHard(), c.shouldFailHard()) }) } } diff --git a/crowdsec/crowdsec.go b/crowdsec/crowdsec.go index d845706c..4df2cc61 100644 --- a/crowdsec/crowdsec.go +++ b/crowdsec/crowdsec.go @@ -18,8 +18,6 @@ import ( "errors" "fmt" "net" - "net/url" - "strings" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" @@ -29,19 +27,9 @@ import ( "github.com/hslatman/caddy-crowdsec-bouncer/internal/bouncer" ) -var ( - cfg *config -) - -const ( - defaultTickerInterval string = "60s" - defaultStreamingEnabled bool = true - defaultHardFailsEnabled bool = false -) - func init() { caddy.RegisterModule(CrowdSec{}) - httpcaddyfile.RegisterGlobalOption("crowdsec", parseCaddyfileGlobalOption) + httpcaddyfile.RegisterGlobalOption("crowdsec", parseCrowdSec) } // CaddyModule returns the Caddy module information. @@ -52,14 +40,6 @@ func (CrowdSec) CaddyModule() caddy.ModuleInfo { } } -type config struct { - APIUrl string - APIKey string - TickerInterval string - EnableStreaming bool - EnableHardFails bool -} - // CrowdSec is a Caddy App that functions as a CrowdSec bouncer. It acts // as a CrowdSec API client as well as a local cache for CrowdSec decisions, // which can be used by the HTTP handler and Layer4 matcher to decide if @@ -70,7 +50,7 @@ type CrowdSec struct { // APIKey for the CrowdSec Local API APIKey string `json:"api_key"` // TickerInterval is the interval the StreamBouncer uses for querying - // the CrowdSec Local API. Defaults to "10s". + // the CrowdSec Local API. Defaults to "60s". TickerInterval string `json:"ticker_interval,omitempty"` // EnableStreaming indicates whether the StreamBouncer should be used. // If it's false, the LiveBouncer is used. The StreamBouncer keeps @@ -91,14 +71,20 @@ type CrowdSec struct { // Provision sets up the CrowdSec app. func (c *CrowdSec) Provision(ctx caddy.Context) error { - c.ctx = ctx c.logger = ctx.Logger(c) defer c.logger.Sync() // nolint - err := c.configure() - if err != nil { - return err + repl := caddy.NewReplacer() // create replacer with the default, global replacement functions, including ".env" env var reading + c.APIUrl = repl.ReplaceKnown(c.APIUrl, "") + c.APIKey = repl.ReplaceKnown(c.APIKey, "") + c.TickerInterval = repl.ReplaceKnown(c.TickerInterval, "") + + if c.APIUrl == "" { + c.APIUrl = "http://127.0.0.1:8080/" + } + if c.TickerInterval == "" { + c.TickerInterval = "60s" } bouncer, err := bouncer.New(c.APIKey, c.APIUrl, c.TickerInterval, c.logger) @@ -114,68 +100,16 @@ func (c *CrowdSec) Provision(ctx caddy.Context) error { bouncer.EnableHardFails() } - if err := bouncer.Init(); err != nil { - return err - } - c.bouncer = bouncer return nil } -func (c *CrowdSec) configure() error { - if cfg != nil { - // A global config is provided through the Caddyfile; always use it - // TODO: combine this with the Unmarshaler approach? - c.APIUrl = cfg.APIUrl - c.APIKey = cfg.APIKey - c.TickerInterval = cfg.TickerInterval - c.EnableStreaming = &cfg.EnableStreaming - c.EnableHardFails = &cfg.EnableHardFails - } - - repl := caddy.NewReplacer() // create replacer with the default, global replacement functions, including ".env" env var reading - c.APIUrl = repl.ReplaceKnown(c.APIUrl, "") - c.APIKey = repl.ReplaceKnown(c.APIKey, "") - - s := c.APIUrl - u, err := url.Parse(s) - if err != nil { - return fmt.Errorf("invalid CrowdSec API URL: %e", err) - } - if u.Scheme == "" { - return fmt.Errorf("URL %s does not have a scheme (i.e https)", u.String()) - } - if !strings.HasSuffix(s, "/") { - s = s + "/" - } - c.APIUrl = s - if c.APIKey == "" { - return errors.New("crowdsec API Key is missing") - } - if c.TickerInterval == "" { - c.TickerInterval = defaultTickerInterval - } - if c.EnableStreaming == nil { - value := defaultStreamingEnabled - c.EnableStreaming = &value - } - if c.EnableHardFails == nil { - value := defaultHardFailsEnabled - c.EnableHardFails = &value - } - return nil -} - // Validate ensures the app's configuration is valid. func (c *CrowdSec) Validate() error { - - // TODO: fail hard after provisioning is not correct? Or do it in provisioning already? - if c.APIKey == "" { - return errors.New("crowdsec API Key must not be empty") + return errors.New("crowdsec API key must not be empty") } - if c.bouncer == nil { return errors.New("bouncer instance not available due to (potential) misconfiguration") } @@ -183,9 +117,22 @@ func (c *CrowdSec) Validate() error { return nil } +func (c *CrowdSec) Cleanup() error { + if err := c.bouncer.Shutdown(); err != nil { + return fmt.Errorf("failed cleaning up: %w", err) + } + + return nil +} + // Start starts the CrowdSec Caddy app func (c *CrowdSec) Start() error { + if err := c.bouncer.Init(); err != nil { + return err + } + c.bouncer.Run() + return nil } @@ -202,18 +149,18 @@ func (c *CrowdSec) IsAllowed(ip net.IP) (bool, *models.Decision, error) { } func (c *CrowdSec) isStreamingEnabled() bool { - return *c.EnableStreaming + return c.EnableStreaming == nil || *c.EnableStreaming } func (c *CrowdSec) shouldFailHard() bool { - return *c.EnableHardFails + return c.EnableHardFails != nil && *c.EnableHardFails } // Interface guards var ( - _ caddy.Module = (*CrowdSec)(nil) - _ caddy.App = (*CrowdSec)(nil) - _ caddy.Provisioner = (*CrowdSec)(nil) - _ caddy.Validator = (*CrowdSec)(nil) - //_ caddyfile.Unmarshaler = (*CrowdSec)(nil) + _ caddy.Module = (*CrowdSec)(nil) + _ caddy.App = (*CrowdSec)(nil) + _ caddy.Provisioner = (*CrowdSec)(nil) + _ caddy.Validator = (*CrowdSec)(nil) + _ caddy.CleanerUpper = (*CrowdSec)(nil) ) diff --git a/crowdsec/crowdsec_test.go b/crowdsec/crowdsec_test.go new file mode 100644 index 00000000..67b72395 --- /dev/null +++ b/crowdsec/crowdsec_test.go @@ -0,0 +1,269 @@ +// Copyright 2020 Herman Slatman +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crowdsec + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/caddyserver/caddy/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" +) + +func TestCrowdSec_Provision(t *testing.T) { + tests := []struct { + name string + config string + env map[string]string + assertion func(tt assert.TestingT, c *CrowdSec) + wantErr bool + }{ + { + name: "ok", + config: `{ + "api_url": "http://localhost:8080", + "api_key": "test-key", + "ticker_interval": "10s", + "enable_streaming": false, + "enable_hard_fails": true + }`, + assertion: func(tt assert.TestingT, c *CrowdSec) { + assert.Equal(tt, "http://localhost:8080", c.APIUrl) + assert.Equal(tt, "test-key", c.APIKey) + assert.Equal(tt, "10s", c.TickerInterval) + assert.False(tt, c.isStreamingEnabled()) + assert.True(tt, c.shouldFailHard()) + }, + wantErr: false, + }, + { + name: "defaults", + config: `{}`, + assertion: func(tt assert.TestingT, c *CrowdSec) { + assert.Equal(tt, "http://127.0.0.1:8080/", c.APIUrl) + assert.Equal(tt, "", c.APIKey) + assert.Equal(tt, "60s", c.TickerInterval) + assert.True(tt, c.isStreamingEnabled()) + assert.False(tt, c.shouldFailHard()) + }, + wantErr: false, + }, + { + name: "json-env-vars", + config: `{ + "api_url": "{env.CROWDSEC_TEST_API_URL}", + "api_key": "{env.CROWDSEC_TEST_API_KEY}", + "ticker_interval": "{env.CROWDSEC_TEST_TICKER_INTERVAL}" + }`, + env: map[string]string{ + "CROWDSEC_TEST_API_URL": "http://127.0.0.2:8080/", + "CROWDSEC_TEST_API_KEY": "env-test-key", + "CROWDSEC_TEST_TICKER_INTERVAL": "25s", + }, + assertion: func(tt assert.TestingT, c *CrowdSec) { + assert.Equal(tt, "http://127.0.0.2:8080/", c.APIUrl) + assert.Equal(tt, "env-test-key", c.APIKey) + assert.Equal(tt, "25s", c.TickerInterval) + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var c CrowdSec + err := json.Unmarshal([]byte(tt.config), &c) + require.NoError(t, err) + + for k, v := range tt.env { + t.Setenv(k, v) + } + + ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()}) + err = c.Provision(ctx) + require.NoError(t, err) + + if tt.assertion != nil { + tt.assertion(t, &c) + } + }) + } +} + +func TestCrowdSec_Validate(t *testing.T) { + tests := []struct { + name string + config string + wantErr bool + }{ + { + name: "ok", + config: `{ + "api_url": "http://localhost:8080", + "api_key": "test-key", + "ticker_interval": "10s", + "enable_streaming": false, + "enable_hard_fails": true + }`, + wantErr: false, + }, + { + name: "fail/missing-api-key", + config: `{ + "api_url": "http://localhost:8080", + "api_key": "" + }`, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var c CrowdSec + err := json.Unmarshal([]byte(tt.config), &c) + require.NoError(t, err) + + ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()}) + err = c.Provision(ctx) + require.NoError(t, err) + + err = c.Validate() + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + }) + } +} + +func TestCrowdSec_streamingBouncerRuntime(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) // ignore current ones; they're deep in the Caddy stack + requestCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount += 1 + w.WriteHeader(200) // just accept any request + w.Write(nil) // nolint + })) + defer srv.Close() + + config := fmt.Sprintf(`{ + "api_url": %q, + "api_key": "test-key" + }`, srv.URL) // set test server URL as API URL + + var c CrowdSec + err := json.Unmarshal([]byte(config), &c) + require.NoError(t, err) + + caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + + err = c.Provision(caddyCtx) + require.NoError(t, err) + require.True(t, c.isStreamingEnabled()) + + err = c.Validate() + require.NoError(t, err) + + err = c.Start() + require.NoError(t, err) + + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + // simulate request coming in and stopping the server from another goroutine + defer wg.Done() + + // wait a little bit of time to let the go-cs-bouncer do _some_ work, + // before it properly returns; seems to hang otherwise on b.wg.Wait(). + time.Sleep(100 * time.Millisecond) + + // simulate a lookup + allowed, decision, err := c.IsAllowed(net.ParseIP("127.0.0.1")) + assert.NoError(t, err) + assert.Nil(t, decision) + assert.True(t, allowed) + + err = c.Stop() + require.NoError(t, err) + + err = c.Cleanup() + require.NoError(t, err) + }() + + // wait for the stop and cleanup process + wg.Wait() + + // expect a single request to have been performed + assert.Equal(t, 1, requestCount) +} + +func TestCrowdSec_liveBouncerRuntime(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) // ignore current ones; they're deep in the Caddy stack + requestCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount += 1 + w.WriteHeader(200) // just accept any request + w.Write(nil) // nolint + })) + defer srv.Close() + + config := fmt.Sprintf(`{ + "api_url": %q, + "api_key": "test-key", + "enable_streaming": false + }`, srv.URL) // set test server URL as API URL + + var c CrowdSec + err := json.Unmarshal([]byte(config), &c) + require.NoError(t, err) + + caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + + err = c.Provision(caddyCtx) + require.NoError(t, err) + require.False(t, c.isStreamingEnabled()) + + err = c.Validate() + require.NoError(t, err) + + err = c.Start() + require.NoError(t, err) + + // simulate a lookup + allowed, decision, err := c.IsAllowed(net.ParseIP("127.0.0.1")) + assert.NoError(t, err) + assert.Nil(t, decision) + assert.True(t, allowed) + + err = c.Stop() + require.NoError(t, err) + + err = c.Cleanup() + require.NoError(t, err) + + // expect a single request to have been performed + assert.Equal(t, 1, requestCount) +} diff --git a/go.mod b/go.mod index 79bfb0fa..9f209614 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,8 @@ require ( github.com/jarcoal/httpmock v1.3.1 github.com/mholt/caddy-l4 v0.0.0-20231016112149-a362a1fbf652 github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.4 + go.uber.org/goleak v1.2.1 go.uber.org/zap v1.26.0 ) @@ -35,6 +37,7 @@ require ( github.com/chzyer/readline v1.5.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect github.com/crowdsecurity/go-cs-lib v0.0.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/badger v1.6.2 // indirect github.com/dgraph-io/badger/v2 v2.2007.4 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect @@ -97,6 +100,7 @@ require ( github.com/oklog/ulid v1.3.1 // indirect github.com/onsi/ginkgo/v2 v2.13.0 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.17.0 // indirect github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.45.0 // indirect @@ -124,7 +128,6 @@ require ( go.step.sm/cli-utils v0.8.0 // indirect go.step.sm/crypto v0.36.1 // indirect go.step.sm/linkedca v0.20.1 // indirect - go.uber.org/goleak v1.2.1 // indirect go.uber.org/mock v0.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.14.0 // indirect diff --git a/http/http.go b/http/http.go index 39c8f431..11681f04 100644 --- a/http/http.go +++ b/http/http.go @@ -50,7 +50,6 @@ func (Handler) CaddyModule() caddy.ModuleInfo { // Provision sets up the CrowdSec handler. func (h *Handler) Provision(ctx caddy.Context) error { - crowdsecAppIface, err := ctx.App("crowdsec") if err != nil { return fmt.Errorf("getting crowdsec app: %v", err) diff --git a/internal/bouncer/bouncer.go b/internal/bouncer/bouncer.go index 635ac3ba..c866c9e4 100644 --- a/internal/bouncer/bouncer.go +++ b/internal/bouncer/bouncer.go @@ -16,9 +16,12 @@ package bouncer import ( "context" + "encoding/hex" "fmt" + "math/rand" "net" "sync" + "time" "github.com/crowdsecurity/crowdsec/pkg/models" csbouncer "github.com/crowdsecurity/go-cs-bouncer" @@ -27,10 +30,13 @@ import ( "go.uber.org/zap/zapcore" ) -const version = "v0.5.4" +const version = "v0.6.0" const maxNumberOfDecisionsToLog = 10 -// Bouncer is a custom CrowdSec bouncer backed by an immutable radix tree +// Bouncer is a wrapper for a CrowdSec bouncer. It supports both the the +// streaming and live bouncer implementations. The streaming bouncer is +// backed by an immutable radix tree storing known bad IPs and IP ranges. +// The live bouncer will reach out to the CrowdSec agent on every check. type Bouncer struct { streamingBouncer *csbouncer.StreamBouncer liveBouncer *csbouncer.LiveBouncer @@ -38,10 +44,15 @@ type Bouncer struct { logger *zap.Logger useStreamingBouncer bool shouldFailHard bool - - ctx context.Context - cancel context.CancelFunc - wg *sync.WaitGroup + instantiatedAt time.Time + instanceID string + + ctx context.Context + started bool + stopped bool + startMu sync.Mutex + cancel context.CancelFunc + wg *sync.WaitGroup } // New creates a new (streaming) Bouncer with a storage based on immutable radix tree @@ -49,6 +60,12 @@ type Bouncer struct { func New(apiKey, apiURL, tickerInterval string, logger *zap.Logger) (*Bouncer, error) { userAgent := fmt.Sprintf("caddy-cs-bouncer/%s", version) insecureSkipVerify := false + instantiatedAt := time.Now() + instanceID, err := generateInstanceID(instantiatedAt) + if err != nil { + return nil, fmt.Errorf("failed generating instance ID: %w", err) + } + return &Bouncer{ streamingBouncer: &csbouncer.StreamBouncer{ APIKey: apiKey, @@ -64,11 +81,22 @@ func New(apiKey, apiURL, tickerInterval string, logger *zap.Logger) (*Bouncer, e InsecureSkipVerify: &insecureSkipVerify, UserAgent: userAgent, }, - store: newStore(), - logger: logger, + store: newStore(), + logger: logger, + instantiatedAt: instantiatedAt, + instanceID: instanceID, }, nil } +func generateInstanceID(t time.Time) (string, error) { + r := rand.New(rand.NewSource(t.Unix())) + b := [4]byte{} + if _, err := r.Read(b[:]); err != nil { + return "", err + } + return hex.EncodeToString(b[:]), nil +} + // EnableStreaming enables usage of the StreamBouncer (instead of the LiveBouncer). func (b *Bouncer) EnableStreaming() { b.useStreamingBouncer = true @@ -81,22 +109,36 @@ func (b *Bouncer) EnableHardFails() { b.streamingBouncer.RetryInitialConnect = false } +func (b *Bouncer) zapField() zapcore.Field { + return zap.String("instance_id", b.instanceID) +} + // Init initializes the Bouncer func (b *Bouncer) Init() error { // override CrowdSec's default logrus logging b.overrideLogrusLogger() - // initialize the CrowdSec streaming bouncer - if b.useStreamingBouncer { - return b.streamingBouncer.Init() + // initialize the CrowdSec live bouncer + if !b.useStreamingBouncer { + b.logger.Info("initializing live bouncer", b.zapField()) + return b.liveBouncer.Init() } - // initialize the CrowdSec live bouncer - return b.liveBouncer.Init() + // initialize the CrowdSec streaming bouncer + b.logger.Info("initializing streaming bouncer", b.zapField()) + return b.streamingBouncer.Init() } // Run starts the Bouncer processes func (b *Bouncer) Run() { + b.startMu.Lock() + defer b.startMu.Unlock() + if b.started { + return + } + b.started = true + b.logger.Info("started", b.zapField()) + // the LiveBouncer has nothing to run in the background; return early if !b.useStreamingBouncer { return @@ -124,7 +166,7 @@ func (b *Bouncer) Run() { for { select { case <-b.ctx.Done(): - b.logger.Info("processing new and deleted decisions stopped") + b.logger.Info("processing new and deleted decisions stopped", b.zapField()) return case decisions := <-b.streamingBouncer.Stream: if decisions == nil { @@ -133,38 +175,38 @@ func (b *Bouncer) Run() { // TODO: deletions seem to include all old decisions that had already expired; CrowdSec bug or intended behavior? // TODO: process in separate goroutines/waitgroup? if numberOfDeletedDecisions := len(decisions.Deleted); numberOfDeletedDecisions > 0 { - b.logger.Debug(fmt.Sprintf("processing %d deleted decisions", numberOfDeletedDecisions)) + b.logger.Debug(fmt.Sprintf("processing %d deleted decisions", numberOfDeletedDecisions), b.zapField()) for _, decision := range decisions.Deleted { if err := b.delete(decision); err != nil { - b.logger.Error(fmt.Sprintf("unable to delete decision for %q: %s", *decision.Value, err)) + b.logger.Error(fmt.Sprintf("unable to delete decision for %q: %s", *decision.Value, err), b.zapField()) } else { if numberOfDeletedDecisions <= maxNumberOfDecisionsToLog { - b.logger.Debug(fmt.Sprintf("deleted %q (scope: %s)", *decision.Value, *decision.Scope)) + b.logger.Debug(fmt.Sprintf("deleted %q (scope: %s)", *decision.Value, *decision.Scope), b.zapField()) } } } if numberOfDeletedDecisions > maxNumberOfDecisionsToLog { - b.logger.Debug(fmt.Sprintf("skipped logging for %d deleted decisions", numberOfDeletedDecisions)) + b.logger.Debug(fmt.Sprintf("skipped logging for %d deleted decisions", numberOfDeletedDecisions), b.zapField()) } - b.logger.Debug(fmt.Sprintf("finished processing %d deleted decisions", numberOfDeletedDecisions)) + b.logger.Debug(fmt.Sprintf("finished processing %d deleted decisions", numberOfDeletedDecisions), b.zapField()) } // TODO: process in separate goroutines/waitgroup? if numberOfNewDecisions := len(decisions.New); numberOfNewDecisions > 0 { - b.logger.Debug(fmt.Sprintf("processing %d new decisions", numberOfNewDecisions)) + b.logger.Debug(fmt.Sprintf("processing %d new decisions", numberOfNewDecisions), b.zapField()) for _, decision := range decisions.New { if err := b.add(decision); err != nil { - b.logger.Error(fmt.Sprintf("unable to insert decision for %q: %s", *decision.Value, err)) + b.logger.Error(fmt.Sprintf("unable to insert decision for %q: %s", *decision.Value, err), b.zapField()) } else { if numberOfNewDecisions <= maxNumberOfDecisionsToLog { - b.logger.Debug(fmt.Sprintf("adding %q (scope: %s) for %q", *decision.Value, *decision.Scope, *decision.Duration)) + b.logger.Debug(fmt.Sprintf("adding %q (scope: %s) for %q", *decision.Value, *decision.Scope, *decision.Duration), b.zapField()) } } } if numberOfNewDecisions > maxNumberOfDecisionsToLog { - b.logger.Debug(fmt.Sprintf("skipped logging for %d new decisions", numberOfNewDecisions)) + b.logger.Debug(fmt.Sprintf("skipped logging for %d new decisions", numberOfNewDecisions), b.zapField()) } - b.logger.Debug(fmt.Sprintf("finished processing %d new decisions", numberOfNewDecisions)) + b.logger.Debug(fmt.Sprintf("finished processing %d new decisions", numberOfNewDecisions), b.zapField()) } } } @@ -173,6 +215,18 @@ func (b *Bouncer) Run() { // Shutdown stops the Bouncer func (b *Bouncer) Shutdown() error { + b.startMu.Lock() + defer b.startMu.Unlock() + if !b.started || b.stopped { + return nil + } + b.logger.Info("stopping", b.zapField()) + defer func() { + b.stopped = true + b.logger.Info("finished", b.zapField()) + b.logger.Sync() // nolint + }() + // the LiveBouncer has nothing to do on shutdown if !b.useStreamingBouncer { return nil @@ -233,6 +287,7 @@ func (b *Bouncer) retrieveDecision(ip net.IP) (*models.Decision, error) { decision, err := b.liveBouncer.Get(ip.String()) if err != nil { fields := []zapcore.Field{ + b.zapField(), zap.String("address", b.liveBouncer.APIUrl), zap.Error(err), } diff --git a/internal/bouncer/bouncer_test.go b/internal/bouncer/bouncer_test.go index 39190a15..015de3ba 100644 --- a/internal/bouncer/bouncer_test.go +++ b/internal/bouncer/bouncer_test.go @@ -10,10 +10,11 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/google/go-cmp/cmp" "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" - - "github.com/google/go-cmp/cmp" ) func new(t *testing.T) (*Bouncer, error) { @@ -229,3 +230,9 @@ func TestStreamingBouncer(t *testing.T) { } } } + +func Test_generateInstanceID(t *testing.T) { + id, err := generateInstanceID(time.Now()) + require.NoError(t, err) + assert.Len(t, id, 8) +} diff --git a/internal/bouncer/logging.go b/internal/bouncer/logging.go index d2edbd9d..a28e9255 100644 --- a/internal/bouncer/logging.go +++ b/internal/bouncer/logging.go @@ -3,28 +3,46 @@ package bouncer import ( "errors" "io" + "unicode" + "unicode/utf8" "github.com/sirupsen/logrus" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) +// overrideLogrusLogger overrides the (default) settings of the standard +// logrus logger. The logrus logger is used by the `go-cs-bouncer` package, +// whereas Caddy uses zap. The output of the standard logger is discarded, +// and a hook is used to send messages to Caddy's zap logger instead. +// +// Note that this method changes global state, but only after a new Bouncer +// is provisioned, validated and has just been started. This should thus +// generally not be a problem. func (b *Bouncer) overrideLogrusLogger() { + // the CrowdSec go-cs-bouncer uses the standard logrus logger + std := logrus.StandardLogger() + // silence the default CrowdSec logrus logging - logrus.SetOutput(io.Discard) + std.SetOutput(io.Discard) - // catch log entries and log them using the *zap.Logger instead - logrus.AddHook(&zapAdapterHook{ + // replace all hooks on the standard logrus logger + hooks := logrus.LevelHooks{} + hooks.Add(&zapAdapterHook{ logger: b.logger, shouldFailHard: b.shouldFailHard, address: b.streamingBouncer.APIUrl, + instanceID: b.instanceID, }) + + std.ReplaceHooks(hooks) } type zapAdapterHook struct { logger *zap.Logger shouldFailHard bool address string + instanceID string } func (zh *zapAdapterHook) Levels() []logrus.Level { @@ -43,28 +61,40 @@ func (zh *zapAdapterHook) Fire(entry *logrus.Entry) error { // TODO: extract details from entry.Data? But doesn't seem to be used by CrowdSec today. msg := entry.Message - fields := []zapcore.Field{zap.String("address", zh.address)} + fields := []zapcore.Field{zap.String("instance_id", zh.instanceID), zap.String("address", zh.address)} switch { case entry.Level <= logrus.ErrorLevel: // error, fatal, panic fields = append(fields, zap.Error(errors.New(msg))) if zh.shouldFailHard { // TODO: if we keep this Fatal and the "shouldFailhard" around, ensure we // shut the bouncer down nicely - zh.logger.Fatal(msg, fields...) + zh.logger.Fatal(firstToLower(msg), fields...) } else { - zh.logger.Error(msg, fields...) + zh.logger.Error(firstToLower(msg), fields...) } default: level := zapcore.DebugLevel if l, ok := levelAdapter[entry.Level]; ok { level = l } - zh.logger.Log(level, msg, fields...) + zh.logger.Log(level, firstToLower(msg), fields...) } return nil } +func firstToLower(s string) string { + r, size := utf8.DecodeRuneInString(s) + if r == utf8.RuneError && size <= 1 { + return s + } + lc := unicode.ToLower(r) + if r == lc { + return s + } + return string(lc) + s[size:] +} + var levelAdapter = map[logrus.Level]zapcore.Level{ logrus.TraceLevel: zapcore.DebugLevel, // no trace level in zap logrus.DebugLevel: zapcore.DebugLevel, diff --git a/internal/bouncer/store.go b/internal/bouncer/store.go index d4c8e612..90a45166 100644 --- a/internal/bouncer/store.go +++ b/internal/bouncer/store.go @@ -33,7 +33,6 @@ func newStore() *crowdSecStore { } func (s *crowdSecStore) add(decision *models.Decision) error { - if isInvalid(decision) { return nil } diff --git a/logs/.gitignore b/logs/.gitignore deleted file mode 100644 index 76d22faa..00000000 --- a/logs/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -*.log -!dummy.log \ No newline at end of file diff --git a/logs/dummy.log b/logs/dummy.log deleted file mode 100644 index e69de29b..00000000