diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index c6d37cb03e..401600cc01 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -7,6 +7,7 @@ import ( "github.com/ghodss/yaml" "github.com/kylelemons/godebug/pretty" + "github.com/stretchr/testify/require" "github.com/dexidp/dex/connector/mock" "github.com/dexidp/dex/connector/oidc" @@ -447,3 +448,23 @@ logger: t.Errorf("got!=want: %s", diff) } } + +func TestUnmarshalConfigWithRetry(t *testing.T) { + rawConfig := []byte(` +storage: + type: postgres + config: + host: 10.0.0.1 + port: 65432 + retryAttempts: 10 + retryDelay: "1s" +`) + + var c Config + err := yaml.Unmarshal(rawConfig, &c) + require.NoError(t, err) + + require.Equal(t, "postgres", c.Storage.Type) + require.Equal(t, 10, c.Storage.RetryAttempts) + require.Equal(t, "1s", c.Storage.RetryDelay) +} diff --git a/cmd/dex/serve_test.go b/cmd/dex/serve_test.go index 9e214480d3..7e64edaa4f 100644 --- a/cmd/dex/serve_test.go +++ b/cmd/dex/serve_test.go @@ -1,10 +1,15 @@ package main import ( + "context" + "errors" "log/slog" "testing" "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/memory" ) func TestNewLogger(t *testing.T) { @@ -27,3 +32,50 @@ func TestNewLogger(t *testing.T) { require.Equal(t, (*slog.Logger)(nil), logger) }) } + +func TestStorageInitializationRetry(t *testing.T) { + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create a mock storage that fails a certain number of times before succeeding + mockStorage := &mockRetryStorage{ + failuresLeft: 3, + } + + config := Config{ + Issuer: "http://127.0.0.1:5556/dex", + Storage: Storage{ + Type: "mock", + Config: mockStorage, + RetryAttempts: 5, + RetryDelay: "1s", + }, + Web: Web{ + HTTP: "127.0.0.1:5556", + }, + Logger: Logger{ + Level: slog.LevelInfo, + Format: "json", + }, + } + + logger, _ := newLogger(config.Logger.Level, config.Logger.Format) + + s, err := initializeStorageWithRetry(config.Storage, logger) + require.NoError(t, err) + require.NotNil(t, s) + + require.Equal(t, 0, mockStorage.failuresLeft) +} + +type mockRetryStorage struct { + failuresLeft int +} + +func (m *mockRetryStorage) Open(logger *slog.Logger) (storage.Storage, error) { + if m.failuresLeft > 0 { + m.failuresLeft-- + return nil, errors.New("mock storage failure") + } + return memory.New(logger), nil +}