diff --git a/main.go b/main.go index 5c42d0d..23c0e79 100644 --- a/main.go +++ b/main.go @@ -38,7 +38,8 @@ import ( func NewProvider(providerName string) (provider.Provider, error) { switch providerName { case file.ProviderName: - provider, err := file.NewProvider(os.DirFS("/")) + config := file.NewConfig() + provider, err := file.NewProvider(config) if err != nil { return nil, err } @@ -55,7 +56,7 @@ func main() { { var level slog.Level - err := level.UnmarshalText([]byte(os.Getenv("VAULT_LOG_LEVEL"))) + err := level.UnmarshalText([]byte(os.Getenv("SECRET_INIT_LOG_LEVEL"))) if err != nil { // Silently fall back to info level level = slog.LevelInfo } @@ -68,7 +69,7 @@ func main() { router := slogmulti.Router() - if cast.ToBool(os.Getenv("VAULT_JSON_LOG")) { + if cast.ToBool(os.Getenv("SECRET_INIT_JSON_LOG")) { // Send logs with level higher than warning to stderr router = router.Add( slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn}), @@ -94,7 +95,7 @@ func main() { ) } - if logServerAddr := os.Getenv("VAULT_ENV_LOG_SERVER"); logServerAddr != "" { + if logServerAddr := os.Getenv("SECRET_INIT_LOG_SERVER"); logServerAddr != "" { writer, err := net.Dial("udp", logServerAddr) // We silently ignore syslog connection errors for the lack of a better solution @@ -123,8 +124,8 @@ func main() { os.Exit(1) } - daemonMode := cast.ToBool(os.Getenv("VAULT_ENV_DAEMON")) - delayExec := cast.ToDuration(os.Getenv("VAULT_ENV_DELAY")) + daemonMode := cast.ToBool(os.Getenv("SECRET_INIT_DAEMON")) + delayExec := cast.ToDuration(os.Getenv("SECRET_INIT_DELAY")) entrypointCmd := os.Args[1:] diff --git a/provider/file/config.go b/provider/file/config.go new file mode 100644 index 0000000..1af3919 --- /dev/null +++ b/provider/file/config.go @@ -0,0 +1,21 @@ +package file + +import "os" + +const ( + EnvPrefix = "FILE_" + DefaultMountPath = "/" +) + +type Config struct { + MountPath string `json:"mountPath"` +} + +func NewConfig() *Config { + mountPath, ok := os.LookupEnv(EnvPrefix + "MOUNT_PATH") + if !ok { + mountPath = DefaultMountPath + } + + return &Config{MountPath: mountPath} +} diff --git a/provider/file/config_test.go b/provider/file/config_test.go new file mode 100644 index 0000000..e96b1ec --- /dev/null +++ b/provider/file/config_test.go @@ -0,0 +1,41 @@ +package file + +import ( + "os" + "testing" +) + +func TestConfig(t *testing.T) { + tests := []struct { + name string + env map[string]string + wantMountPath string + }{ + { + name: "Default mount path", + env: map[string]string{}, + wantMountPath: "/", + }, + { + name: "Custom mount path", + env: map[string]string{ + "FILE_MOUNT_PATH": "test/secrets", + }, + wantMountPath: "test/secrets", + }, + } + + for _, tt := range tests { + ttp := tt + t.Run(ttp.name, func(t *testing.T) { + for envKey, envVal := range ttp.env { + os.Setenv(envKey, envVal) + } + + config := NewConfig() + if config.MountPath != ttp.wantMountPath { + t.Errorf("NewConfig() = %v, wantMountPath %v", config.MountPath, ttp.wantMountPath) + } + }) + } +} diff --git a/provider/file/file.go b/provider/file/file.go index ba6a080..8952de0 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "io/fs" + "os" "strings" "github.com/bank-vaults/secret-init/provider" @@ -29,10 +30,8 @@ type Provider struct { fs fs.FS } -func NewProvider(fs fs.FS) (provider.Provider, error) { - if fs == nil { - return nil, fmt.Errorf("file system is nil") - } +func NewProvider(config *Config) (provider.Provider, error) { + fs := os.DirFS(config.MountPath) return &Provider{fs: fs}, nil } @@ -55,9 +54,10 @@ func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Se return secrets, nil } -func (p *Provider) getSecretFromFile(filepath string) (string, error) { - filepath = strings.TrimLeft(filepath, "/") - content, err := fs.ReadFile(p.fs, filepath) +func (p *Provider) getSecretFromFile(path string) (string, error) { + path = strings.TrimLeft(path, "/") + fmt.Println("file path:", path, "fs:", p.fs) + content, err := fs.ReadFile(p.fs, path) if err != nil { return "", fmt.Errorf("failed to read file: %w", err) } diff --git a/provider/file/file_test.go b/provider/file/file_test.go index 50ddddb..2f47302 100644 --- a/provider/file/file_test.go +++ b/provider/file/file_test.go @@ -16,7 +16,6 @@ package file import ( "context" - "io/fs" "testing" "testing/fstest" @@ -28,41 +27,26 @@ import ( func TestNewProvider(t *testing.T) { tests := []struct { name string - fs fs.FS + config *Config wantErr bool wantType bool }{ { - name: "Valid file system", - fs: fstest.MapFS{ - "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, - "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, - "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, + name: "Valid config", + config: &Config{ + MountPath: "test/secrets", }, wantErr: false, wantType: true, }, - { - name: "Nil file system", - fs: nil, - wantErr: true, - wantType: false, - }, } for _, tt := range tests { ttp := tt t.Run(ttp.name, func(t *testing.T) { - prov, err := NewProvider(ttp.fs) - if (err != nil) != ttp.wantErr { - t.Fatalf("NewProvider() error = %v, wantErr %v", err, ttp.wantErr) - return - } - // Use type assertion to check if the provider is of the correct type - _, ok := prov.(*Provider) - if ok != ttp.wantType { - t.Fatalf("NewProvider() = %v, wantType %v", ok, ttp.wantType) - } + provider, err := NewProvider(ttp.config) + assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") + assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type") }) } } @@ -70,18 +54,12 @@ func TestNewProvider(t *testing.T) { func TestLoadSecrets(t *testing.T) { tests := []struct { name string - fs fs.FS paths []string wantErr bool wantData []provider.Secret }{ { name: "Load secrets successfully", - fs: fstest.MapFS{ - "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, - "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, - "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, - }, paths: []string{ "test/secrets/sqlpass.txt", "test/secrets/awsaccess.txt", @@ -96,11 +74,6 @@ func TestLoadSecrets(t *testing.T) { }, { name: "Fail to load secrets due to invalid path", - fs: fstest.MapFS{ - "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, - "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, - "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, - }, paths: []string{ "test/secrets/mistake/sqlpass.txt", "test/secrets/mistake/awsaccess.txt", @@ -114,12 +87,15 @@ func TestLoadSecrets(t *testing.T) { for _, tt := range tests { ttp := tt t.Run(ttp.name, func(t *testing.T) { - provider, err := NewProvider(ttp.fs) - if assert.NoError(t, err, "Unexpected error") { - secrets, err := provider.LoadSecrets(context.Background(), ttp.paths) - assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") - assert.ElementsMatch(t, ttp.wantData, secrets, "Unexpected secrets loaded") + fs := fstest.MapFS{ + "test/secrets/sqlpass.txt": {Data: []byte("3xtr3ms3cr3t")}, + "test/secrets/awsaccess.txt": {Data: []byte("s3cr3t")}, + "test/secrets/awsid.txt": {Data: []byte("secretId")}, } + provider := Provider{fs: fs} + secrets, err := provider.LoadSecrets(context.Background(), ttp.paths) + assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") + assert.Equal(t, ttp.wantData, secrets, "Unexpected secrets") }) } }