diff --git a/main.go b/main.go index 1138da8..076bb83 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,7 @@ import ( func NewProvider(providerName string, logger *slog.Logger, sigs chan os.Signal) (provider.Provider, error) { switch providerName { case file.ProviderName: - config := file.NewConfig(logger) + config := file.NewConfig() provider, err := file.NewProvider(config) if err != nil { return nil, fmt.Errorf("failed to create file provider: %w", err) @@ -47,12 +47,12 @@ func NewProvider(providerName string, logger *slog.Logger, sigs chan os.Signal) return provider, nil case vault.ProviderName: - config, err := vault.NewConfig(logger, sigs) + config, err := vault.NewConfig() if err != nil { return nil, fmt.Errorf("failed to create vault config: %w", err) } - provider, err := vault.NewProvider(config) + provider, err := vault.NewProvider(config, logger, sigs) if err != nil { return nil, fmt.Errorf("failed to create vault provider: %w", err) } diff --git a/provider/file/config.go b/provider/file/config.go index b0db5ea..e2e4e32 100644 --- a/provider/file/config.go +++ b/provider/file/config.go @@ -15,7 +15,7 @@ package file import ( - "log/slog" + "fmt" "os" ) @@ -28,10 +28,11 @@ type Config struct { MountPath string `json:"mountPath"` } -func NewConfig(logger *slog.Logger) *Config { +func NewConfig() *Config { mountPath, ok := os.LookupEnv(EnvPrefix + "MOUNT_PATH") if !ok { - logger.Warn("Mount path not provided. Using default.", slog.String("Default Mount Path", DefaultMountPath)) + fmt.Printf("Mount path not provided. Using default. Default Mount Path: %s\n", DefaultMountPath) + mountPath = DefaultMountPath } diff --git a/provider/file/config_test.go b/provider/file/config_test.go index 02a4f80..737ccdb 100644 --- a/provider/file/config_test.go +++ b/provider/file/config_test.go @@ -15,7 +15,6 @@ package file import ( - "log/slog" "os" "testing" @@ -48,7 +47,7 @@ func TestConfig(t *testing.T) { for envKey, envVal := range ttp.env { os.Setenv(envKey, envVal) } - config := NewConfig(slog.Default()) + config := NewConfig() assert.Equal(t, ttp.wantMountPath, config.MountPath, "Unexpected mount path") }) diff --git a/provider/file/file_test.go b/provider/file/file_test.go index 35446d4..f61b388 100644 --- a/provider/file/file_test.go +++ b/provider/file/file_test.go @@ -54,10 +54,10 @@ func TestNewProvider(t *testing.T) { func TestLoadSecrets(t *testing.T) { tests := []struct { - name string - paths []string - wantErr bool - wantData []provider.Secret + name string + paths []string + wantErr bool + wantSecrets []provider.Secret }{ { name: "Load secrets successfully", @@ -67,7 +67,7 @@ func TestLoadSecrets(t *testing.T) { "test/secrets/awsid.txt", }, wantErr: false, - wantData: []provider.Secret{ + wantSecrets: []provider.Secret{ {Path: "test/secrets/sqlpass.txt", Value: "3xtr3ms3cr3t"}, {Path: "test/secrets/awsaccess.txt", Value: "s3cr3t"}, {Path: "test/secrets/awsid.txt", Value: "secretId"}, @@ -80,8 +80,8 @@ func TestLoadSecrets(t *testing.T) { "test/secrets/mistake/awsaccess.txt", "test/secrets/mistake/awsid.txt", }, - wantErr: true, - wantData: nil, + wantErr: true, + wantSecrets: nil, }, } @@ -97,7 +97,7 @@ func TestLoadSecrets(t *testing.T) { secrets, err := provider.LoadSecrets(context.Background(), ttp.paths) assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") - assert.Equal(t, ttp.wantData, secrets, "Unexpected secrets") + assert.Equal(t, ttp.wantSecrets, secrets, "Unexpected secrets") }) } } diff --git a/provider/vault/config.go b/provider/vault/config.go index a9ade06..cb2dafa 100644 --- a/provider/vault/config.go +++ b/provider/vault/config.go @@ -15,8 +15,7 @@ package vault import ( - "errors" - "log/slog" + "fmt" "os" "strings" @@ -24,7 +23,8 @@ import ( ) const ( - EnvPrefix = "VAULT_" + EnvPrefix = "VAULT_" + SecretInitDaemonEnv = "SECRET_INIT_DAEMON" // The special value for SECRET_INIT which marks that the login token needs to be passed through to the application // which was acquired during the vault client initialization. vaultLogin = "vault:login" @@ -44,8 +44,6 @@ type Config struct { IgnoreMissingSecrets bool `json:"ignoreMissingSecrets"` FromPath string `json:"fromPath"` RevokeToken bool `json:"revokeToken"` - Logger *slog.Logger - Sigs chan os.Signal } type envType struct { @@ -85,7 +83,7 @@ var sanitizeEnvmap = map[string]envType{ "SECRET_INIT_DAEMON": {login: false}, } -func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) { +func NewConfig() (*Config, error) { var ( role, authPath, authMethod string hasRole, hasPath, hasAuthMethod bool @@ -102,9 +100,7 @@ func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) { if b, err := os.ReadFile(tokenFile); err == nil { vaultToken = string(b) } else { - logger.Error("could not read vault token file", slog.String("file", tokenFile)) - - return nil, err + return nil, fmt.Errorf("failed to read token file: %w", err) } } else { if isLogin { @@ -115,9 +111,8 @@ func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) { authPath, hasPath = os.LookupEnv(EnvPrefix + "PATH") authMethod, hasAuthMethod = os.LookupEnv(EnvPrefix + "AUTH_METHOD") if !hasRole || !hasPath || !hasAuthMethod { - logger.Error("Incomplete authentication configuration. Make sure VAULT_ROLE, VAULT_PATH, and VAULT_AUTH_METHOD are set.") - - return nil, errors.New("incomplete authentication configuration") + return nil, fmt.Errorf("incomplete authentication configuration %s, %s, and %s", + "VAULT_ROLE", "VAULT_PATH", "VAULT_AUTH_METHOD") } } @@ -138,7 +133,7 @@ func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) { transitKeyID := os.Getenv(EnvPrefix + "TRANSIT_KEY_ID") transitPath := os.Getenv(EnvPrefix + "TRANSIT_PATH") transitBatchSize := cast.ToInt(os.Getenv(EnvPrefix + "TRANSIT_BATCH_SIZE")) - daemonMode := cast.ToBool(os.Getenv("SECRET_INIT_DAEMON_MODE")) + daemonMode := cast.ToBool(os.Getenv(SecretInitDaemonEnv)) // Used both for reading secrets and transit encryption ignoreMissingSecrets := cast.ToBool(os.Getenv(EnvPrefix + "IGNORE_MISSING_SECRETS")) @@ -159,7 +154,5 @@ func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) { IgnoreMissingSecrets: ignoreMissingSecrets, FromPath: fromPath, RevokeToken: revokeToken, - Logger: logger, - Sigs: sigs, }, nil } diff --git a/provider/vault/config_test.go b/provider/vault/config_test.go index e17763c..c1595e8 100644 --- a/provider/vault/config_test.go +++ b/provider/vault/config_test.go @@ -1,7 +1,6 @@ package vault import ( - "log/slog" "os" "path/filepath" "testing" @@ -10,9 +9,6 @@ import ( ) func TestConfig(t *testing.T) { - // mock logger, sigs, and tokenfile - logger := slog.Default() - sigs := make(chan os.Signal, 1) tokenFile := newTokenFile(t) defer os.Remove(tokenFile) @@ -25,15 +21,28 @@ func TestConfig(t *testing.T) { { name: "Valid login configuration with Token", env: map[string]string{ - "VAULT_TOKEN": vaultLogin, - "VAULT_TOKEN_FILE": tokenFile, + "VAULT_TOKEN": vaultLogin, + "VAULT_TOKEN_FILE": tokenFile, + "VAULT_PASSTHROUGH": "VAULT_AGENT_ADDR, VAULT_CLI_NO_COLOR", + "VAULT_TRANSIT_KEY_ID": "test-key", + "VAULT_TRANSIT_PATH": "transit", + "VAULT_TRANSIT_BATCH_SIZE": "10", + "SECRET_INIT_DAEMON": "true", + "VAULT_IGNORE_MISSING_SECRETS": "true", + "VAULT_REVOKE_TOKEN": "true", + "VAULT_FROM_PATH": "secret/data/test", }, wantConfig: &Config{ - Islogin: true, - Token: "root", - TokenFile: tokenFile, - Logger: logger, - Sigs: sigs, + Islogin: true, + Token: "root", + TokenFile: tokenFile, + TransitKeyID: "test-key", + TransitPath: "transit", + TransitBatchSize: 10, + DaemonMode: true, + IgnoreMissingSecrets: true, + FromPath: "secret/data/test", + RevokeToken: true, }, wantErr: false, }, @@ -51,8 +60,6 @@ func TestConfig(t *testing.T) { Role: "test-app-role", AuthPath: "auth/approle/test/login", AuthMethod: "test-approle", - Logger: logger, - Sigs: sigs, }, wantErr: false, }, @@ -65,7 +72,7 @@ func TestConfig(t *testing.T) { wantErr: true, }, { - name: "Invalid login configuration missing role/path credentials", + name: "Invalid login configuration missing role - path credentials", env: map[string]string{ "VAULT_PATH": "auth/approle/test/login", "VAULT_AUTH_METHOD": "test-approle", @@ -82,7 +89,7 @@ func TestConfig(t *testing.T) { os.Setenv(envKey, envVal) } - config, err := NewConfig(logger, sigs) + config, err := NewConfig() assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") assert.Equal(t, ttp.wantConfig, config, "Unexpected config") @@ -106,5 +113,6 @@ func newTokenFile(t *testing.T) string { if err != nil { t.Fatalf("Failed to write to a temporary token file: %v", err) } + return tokenFile.Name() } diff --git a/provider/vault/vault.go b/provider/vault/vault.go index d8f6ca1..f42f781 100644 --- a/provider/vault/vault.go +++ b/provider/vault/vault.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "log/slog" + "os" "strings" "github.com/bank-vaults/internal/injector" @@ -57,8 +58,8 @@ func (s *sanitized) append(key string, value string) { } } -func NewProvider(config *Config) (provider.Provider, error) { - clientOptions := []vault.ClientOption{vault.ClientLogger(clientLogger{config.Logger})} +func NewProvider(config *Config, logger *slog.Logger, sigs chan os.Signal) (provider.Provider, error) { + clientOptions := []vault.ClientOption{vault.ClientLogger(clientLogger{logger})} if config.TokenFile != "" { clientOptions = append(clientOptions, vault.ClientToken(config.Token)) } else { @@ -72,7 +73,7 @@ func NewProvider(config *Config) (provider.Provider, error) { client, err := vault.NewClientWithOptions(clientOptions...) if err != nil { - config.Logger.Error(fmt.Errorf("failed to create vault client: %w", err).Error()) + logger.Error(fmt.Errorf("failed to create vault client: %w", err).Error()) return nil, err } @@ -88,8 +89,8 @@ func NewProvider(config *Config) (provider.Provider, error) { var secretRenewer injector.SecretRenewer if config.DaemonMode { - secretRenewer = daemonSecretRenewer{client: client, sigs: config.Sigs, logger: config.Logger} - config.Logger.Info("Daemon mode enabled. Will renew secrets in the background.") + secretRenewer = daemonSecretRenewer{client: client, sigs: sigs, logger: logger} + logger.Info("Daemon mode enabled. Will renew secrets in the background.") } return &Provider{ @@ -99,7 +100,7 @@ func NewProvider(config *Config) (provider.Provider, error) { secretRenewer: secretRenewer, fromPath: config.FromPath, revokeToken: config.RevokeToken, - logger: config.Logger, + logger: logger, }, nil } diff --git a/provider/vault/vault_test.go b/provider/vault/vault_test.go new file mode 100644 index 0000000..f82cde7 --- /dev/null +++ b/provider/vault/vault_test.go @@ -0,0 +1,64 @@ +package vault + +import ( + "bytes" + "log/slog" + "os" + "testing" + + "github.com/bank-vaults/internal/injector" + "github.com/stretchr/testify/assert" +) + +func TestNewProvider(t *testing.T) { + tests := []struct { + name string + config *Config + wantInjectorConfig injector.Config + wantErr bool + wantType bool + }{ + { + name: "Valid Provider with Token", + config: &Config{ + Islogin: true, + TokenFile: "root", + Token: "root", + TransitKeyID: "test-key", + TransitPath: "transit", + TransitBatchSize: 10, + DaemonMode: true, + IgnoreMissingSecrets: true, + FromPath: "secret/data/test", + RevokeToken: true, + }, + wantErr: false, + wantType: true, + }, + { + name: "Fail to create vault client", + config: &Config{}, + wantErr: true, + wantType: false, + }, + } + + for _, tt := range tests { + ttp := tt + + // Redirect logs to avoid polluting the test output + var buf bytes.Buffer + handler := slog.NewTextHandler(&buf, nil) + logger := slog.New(handler) + + t.Run(ttp.name, func(t *testing.T) { + provider, err := NewProvider(ttp.config, logger, make(chan os.Signal)) + + assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") + assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type") + }) + + buf.Truncate(0) + } + +}