From 4b05b8ca2318fb5e6a4a6d831295acfadd94f8b9 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Thu, 9 Nov 2023 08:40:45 +0100 Subject: [PATCH 01/10] feat:add-selection-logic-and-file-begin Signed-off-by: Bence Csati --- main.go | 69 +++++++++++++++++++++++++++++++++++++++---- provider/file/file.go | 35 ++++++++++++++++++++++ provider/provider.go | 2 +- 3 files changed, 100 insertions(+), 6 deletions(-) create mode 100644 provider/file/file.go diff --git a/main.go b/main.go index 216ebab..2721bcb 100644 --- a/main.go +++ b/main.go @@ -24,6 +24,7 @@ import ( "os/exec" "os/signal" "slices" + "strings" "syscall" "time" @@ -32,8 +33,29 @@ import ( "github.com/spf13/cast" "github.com/bank-vaults/secret-init/provider" + "github.com/bank-vaults/secret-init/provider/file" ) +type sanitizedEnviron struct { + env []string +} + +var sanitizeEnv = []string{ + "VAULT_JSON_LOG", + "VAULT_LOG_LEVEL", + "VAULT_ENV_DAEMON", + "VAULT_ENV_DELAY", + "VAULT_ENV_PASSTHROUGH", +} + +// func (e *sanitizedEnviron) append(name string, value string) { +// for _, env := range sanitizeEnv { +// if name == env { +// e.env = append(e.env, fmt.Sprintf("%s=%s", name, value)) +// } +// } +// } + func main() { var logger *slog.Logger { @@ -94,8 +116,17 @@ func main() { slog.SetDefault(logger) } - // TODO: enable providers - var provider provider.Provider + providers := map[string]provider.Provider{ + "file": file.NewFileProvider(os.Getenv("SECRETS_FILE_PATH")), + } + + providerName := os.Getenv("PROVIDER") + provider, found := providers[providerName] + if !found { + logger.Error("invalid provider specified.", slog.String("provider name", providerName)) + + os.Exit(1) + } if len(os.Args) == 1 { logger.Error("no command is given, vault-env can't determine the entrypoint (command), please specify it explicitly or let the webhook query it (see documentation)") @@ -115,14 +146,42 @@ func main() { os.Exit(1) } + passthroughEnvVars := strings.Split(os.Getenv("VAULT_ENV_PASSTHROUGH"), ",") + + // do not sanitize env vars specified in VAULT_ENV_PASSTHROUGH + for _, envVar := range passthroughEnvVars { + if trimmed := strings.TrimSpace(envVar); trimmed != "" { + for i, sanEnv := range sanitizeEnv { + if trimmed == sanEnv { + sanitizeEnv[i] = sanitizeEnv[len(sanitizeEnv)-1] + sanitizeEnv[len(sanitizeEnv)-1] = "" + sanitizeEnv = sanitizeEnv[:len(sanitizeEnv)-1] + } + } + } + } + + environ := make(map[string]string, len(os.Environ())) + sanitized := sanitizedEnviron{} + + for _, env := range os.Environ() { + split := strings.SplitN(env, "=", 2) + name := split[0] + value := split[1] + environ[name] = value + } + ctx := context.Background() - envs, err := provider.LoadSecrets(ctx, os.Environ()) + envs, err := provider.LoadSecrets(ctx, &environ) if err != nil { logger.Error("could not retrieve secrets from the provider.", err) os.Exit(1) } + // passthroughEnvs + loaded secrets + sanitized.env = append(sanitized.env, envs...) + sigs := make(chan os.Signal, 1) if delayExec > 0 { @@ -135,7 +194,7 @@ func main() { if daemonMode { logger.Info("in daemon mode...") cmd := exec.Command(binary, entrypointCmd[1:]...) - cmd.Env = append(os.Environ(), envs...) + cmd.Env = append(os.Environ(), sanitized.env...) cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout @@ -184,7 +243,7 @@ func main() { os.Exit(cmd.ProcessState.ExitCode()) } - err = syscall.Exec(binary, entrypointCmd, envs) + err = syscall.Exec(binary, entrypointCmd, sanitized.env) if err != nil { logger.Error(fmt.Errorf("failed to exec process: %w", err).Error(), slog.String("entrypoint", fmt.Sprint(entrypointCmd))) diff --git a/provider/file/file.go b/provider/file/file.go new file mode 100644 index 0000000..5a3a777 --- /dev/null +++ b/provider/file/file.go @@ -0,0 +1,35 @@ +// Copyright © 2023 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package file + +import ( + "context" + + "github.com/bank-vaults/secret-init/provider" +) + +type Provider struct { + SecretsFilePath string +} + +func NewFileProvider(secretsFilePath string) provider.Provider { + + return &Provider{SecretsFilePath: secretsFilePath} +} + +func (provider *Provider) LoadSecrets(_ context.Context, _ *map[string]string) ([]string, error) { + + return make([]string, 0), nil +} diff --git a/provider/provider.go b/provider/provider.go index c690a0c..9531321 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -18,5 +18,5 @@ import "context" // Provider is an interface for securely loading secrets based on environment variables. type Provider interface { - LoadSecrets(ctx context.Context, paths []string) ([]string, error) + LoadSecrets(ctx context.Context, paths *map[string]string) ([]string, error) } From 94f3872dbb39693136d0b3446989345a9d0a3816 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Fri, 10 Nov 2023 17:53:22 +0100 Subject: [PATCH 02/10] feat: add file provider Signed-off-by: Bence Csati --- main.go | 62 +++++++++---------------------------------- provider/file/file.go | 41 +++++++++++++++++++++++++--- provider/provider.go | 2 +- 3 files changed, 51 insertions(+), 54 deletions(-) diff --git a/main.go b/main.go index 2721bcb..2d07fa7 100644 --- a/main.go +++ b/main.go @@ -36,26 +36,6 @@ import ( "github.com/bank-vaults/secret-init/provider/file" ) -type sanitizedEnviron struct { - env []string -} - -var sanitizeEnv = []string{ - "VAULT_JSON_LOG", - "VAULT_LOG_LEVEL", - "VAULT_ENV_DAEMON", - "VAULT_ENV_DELAY", - "VAULT_ENV_PASSTHROUGH", -} - -// func (e *sanitizedEnviron) append(name string, value string) { -// for _, env := range sanitizeEnv { -// if name == env { -// e.env = append(e.env, fmt.Sprintf("%s=%s", name, value)) -// } -// } -// } - func main() { var logger *slog.Logger { @@ -116,13 +96,16 @@ func main() { slog.SetDefault(logger) } - providers := map[string]provider.Provider{ - "file": file.NewFileProvider(os.Getenv("SECRETS_FILE_PATH")), - } - + var provider provider.Provider providerName := os.Getenv("PROVIDER") - provider, found := providers[providerName] - if !found { + switch providerName { + case "file": + newProvider, err := file.NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + if err != nil { + logger.Error(fmt.Errorf("failed to create provider: %w", err).Error()) + } + provider = newProvider + default: logger.Error("invalid provider specified.", slog.String("provider name", providerName)) os.Exit(1) @@ -146,23 +129,7 @@ func main() { os.Exit(1) } - passthroughEnvVars := strings.Split(os.Getenv("VAULT_ENV_PASSTHROUGH"), ",") - - // do not sanitize env vars specified in VAULT_ENV_PASSTHROUGH - for _, envVar := range passthroughEnvVars { - if trimmed := strings.TrimSpace(envVar); trimmed != "" { - for i, sanEnv := range sanitizeEnv { - if trimmed == sanEnv { - sanitizeEnv[i] = sanitizeEnv[len(sanitizeEnv)-1] - sanitizeEnv[len(sanitizeEnv)-1] = "" - sanitizeEnv = sanitizeEnv[:len(sanitizeEnv)-1] - } - } - } - } - environ := make(map[string]string, len(os.Environ())) - sanitized := sanitizedEnviron{} for _, env := range os.Environ() { split := strings.SplitN(env, "=", 2) @@ -172,16 +139,13 @@ func main() { } ctx := context.Background() - envs, err := provider.LoadSecrets(ctx, &environ) + envs, err := provider.LoadSecrets(ctx, environ) if err != nil { - logger.Error("could not retrieve secrets from the provider.", err) + logger.Error(fmt.Errorf("failed to load secrets from provider %w", err).Error()) os.Exit(1) } - // passthroughEnvs + loaded secrets - sanitized.env = append(sanitized.env, envs...) - sigs := make(chan os.Signal, 1) if delayExec > 0 { @@ -194,7 +158,7 @@ func main() { if daemonMode { logger.Info("in daemon mode...") cmd := exec.Command(binary, entrypointCmd[1:]...) - cmd.Env = append(os.Environ(), sanitized.env...) + cmd.Env = append(os.Environ(), envs...) cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout @@ -243,7 +207,7 @@ func main() { os.Exit(cmd.ProcessState.ExitCode()) } - err = syscall.Exec(binary, entrypointCmd, sanitized.env) + err = syscall.Exec(binary, entrypointCmd, envs) if err != nil { logger.Error(fmt.Errorf("failed to exec process: %w", err).Error(), slog.String("entrypoint", fmt.Sprint(entrypointCmd))) diff --git a/provider/file/file.go b/provider/file/file.go index 5a3a777..c103344 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -16,20 +16,53 @@ package file import ( "context" + "fmt" + "os" + "strings" "github.com/bank-vaults/secret-init/provider" ) type Provider struct { SecretsFilePath string + SecretData []byte } -func NewFileProvider(secretsFilePath string) provider.Provider { +func NewFileProvider(secretsFilePath string) (provider.Provider, error) { + data, err := os.ReadFile(secretsFilePath) + if err != nil { + return nil, fmt.Errorf("failed to read from file: %w", err) + } - return &Provider{SecretsFilePath: secretsFilePath} + return &Provider{SecretsFilePath: secretsFilePath, SecretData: data}, nil } -func (provider *Provider) LoadSecrets(_ context.Context, _ *map[string]string) ([]string, error) { +func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) ([]string, error) { - return make([]string, 0), nil + //envs that has a value with "file:" prefix needs to be loaded + var secrets []string + for key, value := range envs { + if strings.HasPrefix(value, "file:") { + secret, err := provider.getSecretFromFile(key) + if err != nil { + return nil, fmt.Errorf("failed to load secret: %w", err) + } + secrets = append(secrets, secret) + } + } + + return secrets, nil +} + +func (provider *Provider) getSecretFromFile(key string) (string, error) { + lines := strings.Split(string(provider.SecretData), "\n") + for _, line := range lines { + split := strings.SplitN(line, "=", 2) + + if split[0] == key { + return split[1], nil + } + } + + return "", fmt.Errorf("key: '%s' not found in file", key) } diff --git a/provider/provider.go b/provider/provider.go index 9531321..09954bd 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -18,5 +18,5 @@ import "context" // Provider is an interface for securely loading secrets based on environment variables. type Provider interface { - LoadSecrets(ctx context.Context, paths *map[string]string) ([]string, error) + LoadSecrets(ctx context.Context, envs map[string]string) ([]string, error) } From b0e43942a35cfcbaa10b3cbd5d9cde27a985daeb Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Fri, 10 Nov 2023 18:05:31 +0100 Subject: [PATCH 03/10] feat: format load secret output Signed-off-by: Bence Csati --- provider/file/file.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/provider/file/file.go b/provider/file/file.go index c103344..52818be 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -47,7 +47,7 @@ func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) if err != nil { return nil, fmt.Errorf("failed to load secret: %w", err) } - secrets = append(secrets, secret) + secrets = append(secrets, fmt.Sprintf("%s=%s", key, secret)) } } From 6a3f369d97ce342ace9d6f06ba012155d6d3f795 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Fri, 10 Nov 2023 21:59:19 +0100 Subject: [PATCH 04/10] feat(file_test): add file tests Signed-off-by: Bence Csati --- main.go | 2 + provider/file/file_test.go | 161 +++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+) create mode 100644 provider/file/file_test.go diff --git a/main.go b/main.go index 2d07fa7..f698398 100644 --- a/main.go +++ b/main.go @@ -103,6 +103,8 @@ func main() { newProvider, err := file.NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) if err != nil { logger.Error(fmt.Errorf("failed to create provider: %w", err).Error()) + + os.Exit(1) } provider = newProvider default: diff --git a/provider/file/file_test.go b/provider/file/file_test.go new file mode 100644 index 0000000..1e340cb --- /dev/null +++ b/provider/file/file_test.go @@ -0,0 +1,161 @@ +// Copyright © 2023 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package file + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/bank-vaults/secret-init/provider" +) + +func TestNewFileProvider(t *testing.T) { + //create a new secret-file and write secrets into it + tmpfile := createTempFileWithContent(t) + + defer os.Remove(tmpfile.Name()) + + //create new environment variables + //for file-path and secrets to get + setupEnvs(t, tmpfile) + + providerName := os.Getenv("PROVIDER") + if providerName == "file" { + _, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + if err != nil { + t.Fatal(err) + } + } else { + t.Fatalf("invalid provider specified: %s", providerName) + } +} + +func TestFileLoadSecrets(t *testing.T) { + //create a new secret-file and write secrets into it + tmpfile := createTempFileWithContent(t) + + defer os.Remove(tmpfile.Name()) + + //create new environment variables + //for file-path and secrets to get + setupEnvs(t, tmpfile) + + var provider provider.Provider + providerName := os.Getenv("PROVIDER") + if providerName == "file" { + newProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + if err != nil { + t.Fatal(err) + } + provider = newProvider + } else { + t.Fatalf("invalid provider specified: %s", providerName) + } + + environ := make(map[string]string, len(os.Environ())) + + for _, env := range os.Environ() { + split := strings.SplitN(env, "=", 2) + name := split[0] + value := split[1] + environ[name] = value + } + + ctx := context.Background() + envs, err := provider.LoadSecrets(ctx, environ) + if err != nil { + t.Fatal(err) + } + + test := []string{ + "MYSQL_PASSWORD=3xtr3ms3cr3t", + "AWS_SECRET_ACCESS_KEY=s3cr3t", + "AWS_ACCESS_KEY_ID=secretId", + } + //check if secrets have been correctly loaded + areEqual(t, envs, test) +} + +func areEqual(t *testing.T, actual, expected []string) { + actualMap := make(map[string]string, len(expected)) + expectedMap := make(map[string]string, len(expected)) + + for _, env := range actual { + split := strings.SplitN(env, "=", 2) + key := split[0] + value := split[1] + actualMap[key] = value + } + + for _, env := range expected { + split := strings.SplitN(env, "=", 2) + key := split[0] + value := split[1] + expectedMap[key] = value + } + + for key, actualValue := range actualMap { + expectedValue, ok := expectedMap[key] + if !ok || actualValue != expectedValue { + t.Fatalf("Mismatch for key %s: actual: %s, expected: %s", key, actualValue, expectedValue) + } + } +} + +func createTempFileWithContent(t *testing.T) *os.File { + content := []byte("MYSQL_PASSWORD=3xtr3ms3cr3t\nAWS_SECRET_ACCESS_KEY=s3cr3t\nAWS_ACCESS_KEY_ID=secretId\n") + tmpfile, err := os.CreateTemp("", "secrets-*.txt") + if err != nil { + t.Fatal(err) + } + + _, err = tmpfile.Write(content) + if err != nil { + t.Fatal(err) + } + + err = tmpfile.Close() + if err != nil { + t.Fatal(err) + } + + return tmpfile +} + +func setupEnvs(t *testing.T, tmpfile *os.File) { + err := os.Setenv("PROVIDER", "file") + if err != nil { + t.Fatal(err) + } + err = os.Setenv("SECRETS_FILE_PATH", tmpfile.Name()) + if err != nil { + t.Fatal(err) + } + + err = os.Setenv("MYSQL_PASSWORD", "file:secret") + if err != nil { + t.Fatal(err) + } + err = os.Setenv("AWS_SECRET_ACCESS_KEY", "file:secret") + if err != nil { + t.Fatal(err) + } + err = os.Setenv("AWS_ACCESS_KEY_ID", "file:secret") + if err != nil { + t.Fatal(err) + } +} From 370dae2e84aabdf4b6a4c9f14c9783eefe94a841 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Sat, 11 Nov 2023 11:33:47 +0100 Subject: [PATCH 05/10] feat(file_test): adjust tests Signed-off-by: Bence Csati --- main.go | 1 - provider/file/file.go | 4 +-- provider/file/file_test.go | 72 +++++++++++++++++++++++++++----------- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/main.go b/main.go index f698398..d9ce1d3 100644 --- a/main.go +++ b/main.go @@ -132,7 +132,6 @@ func main() { } environ := make(map[string]string, len(os.Environ())) - for _, env := range os.Environ() { split := strings.SplitN(env, "=", 2) name := split[0] diff --git a/provider/file/file.go b/provider/file/file.go index 52818be..636285f 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -38,8 +38,7 @@ func NewFileProvider(secretsFilePath string) (provider.Provider, error) { } func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) ([]string, error) { - - //envs that has a value with "file:" prefix needs to be loaded + // envs that has a "file:" prefix needs to be loaded var secrets []string for key, value := range envs { if strings.HasPrefix(value, "file:") { @@ -58,7 +57,6 @@ func (provider *Provider) getSecretFromFile(key string) (string, error) { lines := strings.Split(string(provider.SecretData), "\n") for _, line := range lines { split := strings.SplitN(line, "=", 2) - if split[0] == key { return split[1], nil } diff --git a/provider/file/file_test.go b/provider/file/file_test.go index 1e340cb..18d60ef 100644 --- a/provider/file/file_test.go +++ b/provider/file/file_test.go @@ -15,6 +15,7 @@ package file import ( + "bytes" "context" "os" "strings" @@ -23,51 +24,73 @@ import ( "github.com/bank-vaults/secret-init/provider" ) +func TestMain(m *testing.M) { + exitCode := m.Run() + + // teardown environment variables after tests are done + teardownEnvs() + + os.Exit(exitCode) +} + func TestNewFileProvider(t *testing.T) { - //create a new secret-file and write secrets into it + // create a new secret file and write secrets into it tmpfile := createTempFileWithContent(t) - defer os.Remove(tmpfile.Name()) - //create new environment variables - //for file-path and secrets to get + // create new environment variables + // for file-path and secrets to get setupEnvs(t, tmpfile) + var fileProvider provider.Provider providerName := os.Getenv("PROVIDER") - if providerName == "file" { - _, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + switch providerName { + case "file": + newFileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + fileProvider = newFileProvider if err != nil { t.Fatal(err) } - } else { + default: t.Fatalf("invalid provider specified: %s", providerName) } + + expectedSecretData, err := os.ReadFile(tmpfile.Name()) + if err != nil { + t.Fatal(err) + } + + // access the provider implementation to get secretdata + // then check if file provider is correctly created + // and file-path is read + if fileProvider, ok := fileProvider.(*Provider); !ok || !bytes.Equal(expectedSecretData, fileProvider.SecretData) { + t.Fatal("failed to create file provider") + } } func TestFileLoadSecrets(t *testing.T) { - //create a new secret-file and write secrets into it + // create a new secret-file and write secrets into it tmpfile := createTempFileWithContent(t) - defer os.Remove(tmpfile.Name()) - //create new environment variables - //for file-path and secrets to get + // create new environment variables + // for file-path and secrets to get setupEnvs(t, tmpfile) - var provider provider.Provider + var fileProvider provider.Provider providerName := os.Getenv("PROVIDER") - if providerName == "file" { - newProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + switch providerName { + case "file": + newFileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + fileProvider = newFileProvider if err != nil { t.Fatal(err) } - provider = newProvider - } else { + default: t.Fatalf("invalid provider specified: %s", providerName) } environ := make(map[string]string, len(os.Environ())) - for _, env := range os.Environ() { split := strings.SplitN(env, "=", 2) name := split[0] @@ -76,7 +99,7 @@ func TestFileLoadSecrets(t *testing.T) { } ctx := context.Background() - envs, err := provider.LoadSecrets(ctx, environ) + envs, err := fileProvider.LoadSecrets(ctx, environ) if err != nil { t.Fatal(err) } @@ -86,7 +109,8 @@ func TestFileLoadSecrets(t *testing.T) { "AWS_SECRET_ACCESS_KEY=s3cr3t", "AWS_ACCESS_KEY_ID=secretId", } - //check if secrets have been correctly loaded + + // check if secrets have been correctly loaded areEqual(t, envs, test) } @@ -111,7 +135,7 @@ func areEqual(t *testing.T, actual, expected []string) { for key, actualValue := range actualMap { expectedValue, ok := expectedMap[key] if !ok || actualValue != expectedValue { - t.Fatalf("Mismatch for key %s: actual: %s, expected: %s", key, actualValue, expectedValue) + t.Fatalf("mismatch for key %s: actual: %s, expected: %s", key, actualValue, expectedValue) } } } @@ -159,3 +183,11 @@ func setupEnvs(t *testing.T, tmpfile *os.File) { t.Fatal(err) } } + +func teardownEnvs() { + os.Unsetenv("PROVIDER") + os.Unsetenv("SECRETS_FILE_PATH") + os.Unsetenv("MYSQL_PASSWORD") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + os.Unsetenv("AWS_ACCESS_KEY_ID") +} From 5d2418ad1ecab8f43ddd45e81810bf6e96d6524e Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Mon, 20 Nov 2023 17:10:52 +0100 Subject: [PATCH 06/10] fix: reading from yaml instead of txt, minor fixes and refactor Signed-off-by: Bence Csati --- go.mod | 1 + go.sum | 4 +++ main.go | 34 +++++++++++++++---------- provider/file/file.go | 51 +++++++++++++++++++++++-------------- provider/file/file_test.go | 52 ++++++++++---------------------------- 5 files changed, 70 insertions(+), 72 deletions(-) diff --git a/go.mod b/go.mod index f6224cb..4d451b7 100644 --- a/go.mod +++ b/go.mod @@ -11,4 +11,5 @@ require ( require ( github.com/samber/lo v1.38.1 // indirect golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect + gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index 38c3e2a..a6c4a01 100644 --- a/go.sum +++ b/go.sum @@ -18,3 +18,7 @@ github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA= github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/main.go b/main.go index d9ce1d3..7d91973 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,22 @@ import ( "github.com/bank-vaults/secret-init/provider/file" ) +func NewProvider(providerName string) (provider.Provider, error) { + switch providerName { + case file.ProviderName: + provider, err := file.NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + if err != nil { + + return nil, err + } + + return provider, nil + default: + + return nil, errors.New("invalid provider specified") + } +} + func main() { var logger *slog.Logger { @@ -96,19 +112,9 @@ func main() { slog.SetDefault(logger) } - var provider provider.Provider - providerName := os.Getenv("PROVIDER") - switch providerName { - case "file": - newProvider, err := file.NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) - if err != nil { - logger.Error(fmt.Errorf("failed to create provider: %w", err).Error()) - - os.Exit(1) - } - provider = newProvider - default: - logger.Error("invalid provider specified.", slog.String("provider name", providerName)) + provider, err := NewProvider(os.Getenv("PROVIDER")) + if err != nil { + logger.Error(fmt.Errorf("failed to create provider: %w", err).Error()) os.Exit(1) } @@ -142,7 +148,7 @@ func main() { ctx := context.Background() envs, err := provider.LoadSecrets(ctx, environ) if err != nil { - logger.Error(fmt.Errorf("failed to load secrets from provider %w", err).Error()) + logger.Error(fmt.Errorf("failed to load secrets from provider: %w", err).Error()) os.Exit(1) } diff --git a/provider/file/file.go b/provider/file/file.go index 636285f..0a24cfa 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -20,31 +20,39 @@ import ( "os" "strings" + "gopkg.in/yaml.v2" + "github.com/bank-vaults/secret-init/provider" ) +const ProviderName = "file" + type Provider struct { - SecretsFilePath string - SecretData []byte + secretsFilePath string } func NewFileProvider(secretsFilePath string) (provider.Provider, error) { - data, err := os.ReadFile(secretsFilePath) - if err != nil { - return nil, fmt.Errorf("failed to read from file: %w", err) - } - return &Provider{SecretsFilePath: secretsFilePath, SecretData: data}, nil + return &Provider{secretsFilePath: secretsFilePath}, nil } func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) ([]string, error) { - // envs that has a "file:" prefix needs to be loaded + // extract secrets from the file to a map + secretsMap, err := provider.getSecretsFromFile() + if err != nil { + + return nil, fmt.Errorf("failed to load secrets: %w", err) + } + var secrets []string for key, value := range envs { if strings.HasPrefix(value, "file:") { - secret, err := provider.getSecretFromFile(key) - if err != nil { - return nil, fmt.Errorf("failed to load secret: %w", err) + // Check if the requested secret is in the loaded secret map + value = strings.TrimPrefix(value, "file:") + secret, ok := secretsMap[value] + if !ok { + + return nil, fmt.Errorf("secret %s not found", key) } secrets = append(secrets, fmt.Sprintf("%s=%s", key, secret)) } @@ -53,14 +61,19 @@ func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) return secrets, nil } -func (provider *Provider) getSecretFromFile(key string) (string, error) { - lines := strings.Split(string(provider.SecretData), "\n") - for _, line := range lines { - split := strings.SplitN(line, "=", 2) - if split[0] == key { - return split[1], nil - } +func (provider *Provider) getSecretsFromFile() (map[string]string, error) { + data, err := os.ReadFile(provider.secretsFilePath) + if err != nil { + + return nil, fmt.Errorf("failed to read secrets file: %w", err) + } + + secretsMap := make(map[string]string) + err = yaml.Unmarshal(data, &secretsMap) + if err != nil { + + return nil, fmt.Errorf("failed to unmarshal YAML: %w", err) } - return "", fmt.Errorf("key: '%s' not found in file", key) + return secretsMap, nil } diff --git a/provider/file/file_test.go b/provider/file/file_test.go index 18d60ef..9893537 100644 --- a/provider/file/file_test.go +++ b/provider/file/file_test.go @@ -15,13 +15,10 @@ package file import ( - "bytes" "context" "os" "strings" "testing" - - "github.com/bank-vaults/secret-init/provider" ) func TestMain(m *testing.M) { @@ -39,32 +36,17 @@ func TestNewFileProvider(t *testing.T) { defer os.Remove(tmpfile.Name()) // create new environment variables - // for file-path and secrets to get setupEnvs(t, tmpfile) - var fileProvider provider.Provider - providerName := os.Getenv("PROVIDER") - switch providerName { - case "file": - newFileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) - fileProvider = newFileProvider - if err != nil { - t.Fatal(err) - } - default: - t.Fatalf("invalid provider specified: %s", providerName) - } - - expectedSecretData, err := os.ReadFile(tmpfile.Name()) + fileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) if err != nil { t.Fatal(err) } - // access the provider implementation to get secretdata - // then check if file provider is correctly created - // and file-path is read - if fileProvider, ok := fileProvider.(*Provider); !ok || !bytes.Equal(expectedSecretData, fileProvider.SecretData) { - t.Fatal("failed to create file provider") + // check if file provider is correctly created + _, ok := fileProvider.(*Provider) + if !ok { + t.Fatal("provider is not of type file") } } @@ -77,17 +59,9 @@ func TestFileLoadSecrets(t *testing.T) { // for file-path and secrets to get setupEnvs(t, tmpfile) - var fileProvider provider.Provider - providerName := os.Getenv("PROVIDER") - switch providerName { - case "file": - newFileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) - fileProvider = newFileProvider - if err != nil { - t.Fatal(err) - } - default: - t.Fatalf("invalid provider specified: %s", providerName) + fileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + if err != nil { + t.Fatal(err) } environ := make(map[string]string, len(os.Environ())) @@ -141,8 +115,8 @@ func areEqual(t *testing.T, actual, expected []string) { } func createTempFileWithContent(t *testing.T) *os.File { - content := []byte("MYSQL_PASSWORD=3xtr3ms3cr3t\nAWS_SECRET_ACCESS_KEY=s3cr3t\nAWS_ACCESS_KEY_ID=secretId\n") - tmpfile, err := os.CreateTemp("", "secrets-*.txt") + content := []byte("sqlPassword: 3xtr3ms3cr3t\nawsSecretAccessKey: s3cr3t\nawsAccessKeyId: secretId\n") + tmpfile, err := os.CreateTemp("", "secrets-*.yaml") if err != nil { t.Fatal(err) } @@ -170,15 +144,15 @@ func setupEnvs(t *testing.T, tmpfile *os.File) { t.Fatal(err) } - err = os.Setenv("MYSQL_PASSWORD", "file:secret") + err = os.Setenv("MYSQL_PASSWORD", "file:sqlPassword") if err != nil { t.Fatal(err) } - err = os.Setenv("AWS_SECRET_ACCESS_KEY", "file:secret") + err = os.Setenv("AWS_SECRET_ACCESS_KEY", "file:awsSecretAccessKey") if err != nil { t.Fatal(err) } - err = os.Setenv("AWS_ACCESS_KEY_ID", "file:secret") + err = os.Setenv("AWS_ACCESS_KEY_ID", "file:awsAccessKeyId") if err != nil { t.Fatal(err) } From 65d24598680240ee57e55d7296b854f6e52a7010 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Tue, 21 Nov 2023 14:26:10 +0100 Subject: [PATCH 07/10] fix: moved providers interface to model package, provider selection logic to provider package, and minor fixes Signed-off-by: Bence Csati --- go.mod | 2 +- go.sum | 4 ++-- main.go | 19 +------------------ model/provider.go | 22 ++++++++++++++++++++++ provider/file/file.go | 20 ++++++++++---------- provider/file/file_test.go | 4 ++-- provider/provider.go | 23 +++++++++++++++++++---- 7 files changed, 57 insertions(+), 37 deletions(-) create mode 100644 model/provider.go diff --git a/go.mod b/go.mod index 4d451b7..f8f20ed 100644 --- a/go.mod +++ b/go.mod @@ -11,5 +11,5 @@ require ( require ( github.com/samber/lo v1.38.1 // indirect golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect - gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index a6c4a01..093651b 100644 --- a/go.sum +++ b/go.sum @@ -20,5 +20,5 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 7d91973..fe81646 100644 --- a/main.go +++ b/main.go @@ -33,25 +33,8 @@ import ( "github.com/spf13/cast" "github.com/bank-vaults/secret-init/provider" - "github.com/bank-vaults/secret-init/provider/file" ) -func NewProvider(providerName string) (provider.Provider, error) { - switch providerName { - case file.ProviderName: - provider, err := file.NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) - if err != nil { - - return nil, err - } - - return provider, nil - default: - - return nil, errors.New("invalid provider specified") - } -} - func main() { var logger *slog.Logger { @@ -112,7 +95,7 @@ func main() { slog.SetDefault(logger) } - provider, err := NewProvider(os.Getenv("PROVIDER")) + provider, err := provider.New(os.Getenv("PROVIDER")) if err != nil { logger.Error(fmt.Errorf("failed to create provider: %w", err).Error()) diff --git a/model/provider.go b/model/provider.go new file mode 100644 index 0000000..59ea5d4 --- /dev/null +++ b/model/provider.go @@ -0,0 +1,22 @@ +// Copyright © 2023 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import "context" + +// Provider is an interface for securely loading secrets based on environment variables. +type Provider interface { + LoadSecrets(ctx context.Context, envs map[string]string) ([]string, error) +} diff --git a/provider/file/file.go b/provider/file/file.go index 0a24cfa..887c22f 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -20,9 +20,9 @@ import ( "os" "strings" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" - "github.com/bank-vaults/secret-init/provider" + "github.com/bank-vaults/secret-init/model" ) const ProviderName = "file" @@ -31,7 +31,7 @@ type Provider struct { secretsFilePath string } -func NewFileProvider(secretsFilePath string) (provider.Provider, error) { +func NewProvider(secretsFilePath string) (model.Provider, error) { return &Provider{secretsFilePath: secretsFilePath}, nil } @@ -41,20 +41,20 @@ func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) secretsMap, err := provider.getSecretsFromFile() if err != nil { - return nil, fmt.Errorf("failed to load secrets: %w", err) + return nil, fmt.Errorf("failed to get secrets from file: %w", err) } var secrets []string - for key, value := range envs { - if strings.HasPrefix(value, "file:") { + for envKey, envValue := range envs { + if strings.HasPrefix(envValue, "file:") { // Check if the requested secret is in the loaded secret map - value = strings.TrimPrefix(value, "file:") - secret, ok := secretsMap[value] + envValue = strings.TrimPrefix(envValue, "file:") + secret, ok := secretsMap[envValue] if !ok { - return nil, fmt.Errorf("secret %s not found", key) + return nil, fmt.Errorf("secret %s not found", envKey) } - secrets = append(secrets, fmt.Sprintf("%s=%s", key, secret)) + secrets = append(secrets, fmt.Sprintf("%s=%s", envKey, secret)) } } diff --git a/provider/file/file_test.go b/provider/file/file_test.go index 9893537..3b87d7d 100644 --- a/provider/file/file_test.go +++ b/provider/file/file_test.go @@ -38,7 +38,7 @@ func TestNewFileProvider(t *testing.T) { // create new environment variables setupEnvs(t, tmpfile) - fileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + fileProvider, err := NewProvider(os.Getenv("SECRETS_FILE_PATH")) if err != nil { t.Fatal(err) } @@ -59,7 +59,7 @@ func TestFileLoadSecrets(t *testing.T) { // for file-path and secrets to get setupEnvs(t, tmpfile) - fileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + fileProvider, err := NewProvider(os.Getenv("SECRETS_FILE_PATH")) if err != nil { t.Fatal(err) } diff --git a/provider/provider.go b/provider/provider.go index 09954bd..ce619ba 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -14,9 +14,24 @@ package provider -import "context" +import ( + "errors" + "os" -// Provider is an interface for securely loading secrets based on environment variables. -type Provider interface { - LoadSecrets(ctx context.Context, envs map[string]string) ([]string, error) + "github.com/bank-vaults/secret-init/model" + "github.com/bank-vaults/secret-init/provider/file" +) + +func New(providerName string) (model.Provider, error) { + switch providerName { + case file.ProviderName: + provider, err := file.NewProvider(os.Getenv("SECRETS_FILE_PATH")) + if err != nil { + return nil, err + } + return provider, nil + + default: + return nil, errors.New("invalid provider specified") + } } From fc8e9493325b57826d677d9e9c8bda8dea7ac968 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Thu, 23 Nov 2023 10:57:27 +0100 Subject: [PATCH 08/10] fix: move back NewProvider constructor to main, providers interface to provider package Signed-off-by: Bence Csati --- main.go | 18 +++++++++++++++++- model/provider.go | 22 ---------------------- provider/file/file.go | 5 ++--- provider/provider.go | 23 ++++------------------- 4 files changed, 23 insertions(+), 45 deletions(-) delete mode 100644 model/provider.go diff --git a/main.go b/main.go index fe81646..93202cd 100644 --- a/main.go +++ b/main.go @@ -33,8 +33,24 @@ import ( "github.com/spf13/cast" "github.com/bank-vaults/secret-init/provider" + "github.com/bank-vaults/secret-init/provider/file" ) +func NewProvider(providerName string) (provider.Provider, error) { + switch providerName { + case file.ProviderName: + provider, err := file.NewProvider(os.Getenv("SECRETS_FILE_PATH")) + if err != nil { + return nil, err + } + + return provider, nil + + default: + return nil, errors.New("invalid provider specified") + } +} + func main() { var logger *slog.Logger { @@ -95,7 +111,7 @@ func main() { slog.SetDefault(logger) } - provider, err := provider.New(os.Getenv("PROVIDER")) + provider, err := NewProvider(os.Getenv("PROVIDER")) if err != nil { logger.Error(fmt.Errorf("failed to create provider: %w", err).Error()) diff --git a/model/provider.go b/model/provider.go deleted file mode 100644 index 59ea5d4..0000000 --- a/model/provider.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright © 2023 Bank-Vaults Maintainers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package model - -import "context" - -// Provider is an interface for securely loading secrets based on environment variables. -type Provider interface { - LoadSecrets(ctx context.Context, envs map[string]string) ([]string, error) -} diff --git a/provider/file/file.go b/provider/file/file.go index 887c22f..7b388a7 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -22,7 +22,7 @@ import ( "gopkg.in/yaml.v3" - "github.com/bank-vaults/secret-init/model" + "github.com/bank-vaults/secret-init/provider" ) const ProviderName = "file" @@ -31,8 +31,7 @@ type Provider struct { secretsFilePath string } -func NewProvider(secretsFilePath string) (model.Provider, error) { - +func NewProvider(secretsFilePath string) (provider.Provider, error) { return &Provider{secretsFilePath: secretsFilePath}, nil } diff --git a/provider/provider.go b/provider/provider.go index ce619ba..09954bd 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -14,24 +14,9 @@ package provider -import ( - "errors" - "os" +import "context" - "github.com/bank-vaults/secret-init/model" - "github.com/bank-vaults/secret-init/provider/file" -) - -func New(providerName string) (model.Provider, error) { - switch providerName { - case file.ProviderName: - provider, err := file.NewProvider(os.Getenv("SECRETS_FILE_PATH")) - if err != nil { - return nil, err - } - return provider, nil - - default: - return nil, errors.New("invalid provider specified") - } +// Provider is an interface for securely loading secrets based on environment variables. +type Provider interface { + LoadSecrets(ctx context.Context, envs map[string]string) ([]string, error) } From 29cad097fd4f5d051aaedaa7bca785ebc7b34702 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Fri, 24 Nov 2023 19:37:08 +0100 Subject: [PATCH 09/10] fix: Reworked file provider logic, made the code more cleaner, improved tests Signed-off-by: Bence Csati --- go.mod | 8 +- go.sum | 6 + main.go | 77 +++++++++++-- provider/file/file.go | 78 ++++++++----- provider/file/file_test.go | 231 +++++++++++++++---------------------- provider/provider.go | 2 +- 6 files changed, 218 insertions(+), 184 deletions(-) diff --git a/go.mod b/go.mod index f8f20ed..f302877 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,14 @@ require ( github.com/spf13/cast v1.5.1 ) +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + require ( github.com/samber/lo v1.38.1 // indirect + github.com/stretchr/testify v1.8.4 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect - gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 093651b..8e6c1e7 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -6,6 +8,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= @@ -16,6 +20,8 @@ github.com/samber/slog-syslog v1.0.0 h1:4tf8sNv9+qTQ6Fj8+N6U1ZEtUbqbAIzd+q26/Neg github.com/samber/slog-syslog v1.0.0/go.mod h1:jjupk+yHPVSuXuGhKleoClYc/HEaC+Ro5X4YYeBrt6g= github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA= github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/main.go b/main.go index 93202cd..2a8bcd3 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,7 @@ import ( func NewProvider(providerName string) (provider.Provider, error) { switch providerName { case file.ProviderName: - provider, err := file.NewProvider(os.Getenv("SECRETS_FILE_PATH")) + provider, err := file.NewProvider(os.DirFS("/secrets")) if err != nil { return nil, err } @@ -51,6 +51,59 @@ func NewProvider(providerName string) (provider.Provider, error) { } } +func CreateMapOfEnvs() map[string]string { + environ := make(map[string]string, len(os.Environ())) + for _, env := range os.Environ() { + split := strings.SplitN(env, "=", 2) + name := split[0] + value := split[1] + environ[name] = value + } + + return environ +} + +func ExtractPathsFromEnvs(envs map[string]string) []string { + var secretPaths []string + + for _, path := range envs { + if strings.HasPrefix(path, "file:") { + path = strings.TrimPrefix(path, "file://") + secretPaths = append(secretPaths, path) + } + } + + return secretPaths +} + +func CreateEnvsFromLoadedSecrets(envs map[string]string, secrets []string) ([]string, error) { + // Reverse the map so we can match + // the environment variable key to the secret + // by using the secret path + reversedEnvs := make(map[string]string) + for envKey, path := range envs { + if strings.HasPrefix(path, "file:") { + path = strings.TrimPrefix(path, "file://") + reversedEnvs[path] = envKey + } + } + + var secretsEnv []string + for _, secret := range secrets { + split := strings.SplitN(secret, "|", 2) + secretPath := split[0] + + secretValue := split[1] + secretKey, ok := reversedEnvs[secretPath] + if !ok { + return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", secretPath) + } + secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", secretKey, secretValue)) + } + + return secretsEnv, nil +} + func main() { var logger *slog.Logger { @@ -136,22 +189,24 @@ func main() { os.Exit(1) } - environ := make(map[string]string, len(os.Environ())) - for _, env := range os.Environ() { - split := strings.SplitN(env, "=", 2) - name := split[0] - value := split[1] - environ[name] = value - } + environ := CreateMapOfEnvs() + paths := ExtractPathsFromEnvs(environ) ctx := context.Background() - envs, err := provider.LoadSecrets(ctx, environ) + secrets, err := provider.LoadSecrets(ctx, paths) if err != nil { logger.Error(fmt.Errorf("failed to load secrets from provider: %w", err).Error()) os.Exit(1) } + secretsEnv, err := CreateEnvsFromLoadedSecrets(environ, secrets) + if err != nil { + logger.Error(fmt.Errorf("failed to create environment variables from loaded secrets: %w", err).Error()) + + os.Exit(1) + } + sigs := make(chan os.Signal, 1) if delayExec > 0 { @@ -164,7 +219,7 @@ func main() { if daemonMode { logger.Info("in daemon mode...") cmd := exec.Command(binary, entrypointCmd[1:]...) - cmd.Env = append(os.Environ(), envs...) + cmd.Env = append(os.Environ(), secretsEnv...) cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout @@ -213,7 +268,7 @@ func main() { os.Exit(cmd.ProcessState.ExitCode()) } - err = syscall.Exec(binary, entrypointCmd, envs) + err = syscall.Exec(binary, entrypointCmd, secretsEnv) if err != nil { logger.Error(fmt.Errorf("failed to exec process: %w", err).Error(), slog.String("entrypoint", fmt.Sprint(entrypointCmd))) diff --git a/provider/file/file.go b/provider/file/file.go index 7b388a7..f594368 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -17,10 +17,7 @@ package file import ( "context" "fmt" - "os" - "strings" - - "gopkg.in/yaml.v3" + "io/fs" "github.com/bank-vaults/secret-init/provider" ) @@ -28,51 +25,72 @@ import ( const ProviderName = "file" type Provider struct { - secretsFilePath string + fs fs.FS } -func NewProvider(secretsFilePath string) (provider.Provider, error) { - return &Provider{secretsFilePath: secretsFilePath}, nil -} +func NewProvider(fs fs.FS) (provider.Provider, error) { + if fs == nil { + return nil, fmt.Errorf("file system is nil") + } -func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) ([]string, error) { - // extract secrets from the file to a map - secretsMap, err := provider.getSecretsFromFile() + isEmpty, err := isFileSystemEmpty(fs) if err != nil { - - return nil, fmt.Errorf("failed to get secrets from file: %w", err) + return nil, fmt.Errorf("failed to check if file system is empty: %w", err) } + if isEmpty { + return nil, fmt.Errorf("file system is empty") + } + + return &Provider{fs: fs}, nil +} +func (provider *Provider) LoadSecrets(_ context.Context, paths []string) ([]string, error) { var secrets []string - for envKey, envValue := range envs { - if strings.HasPrefix(envValue, "file:") { - // Check if the requested secret is in the loaded secret map - envValue = strings.TrimPrefix(envValue, "file:") - secret, ok := secretsMap[envValue] - if !ok { - - return nil, fmt.Errorf("secret %s not found", envKey) - } - secrets = append(secrets, fmt.Sprintf("%s=%s", envKey, secret)) + + for i, path := range paths { + secret, err := provider.getSecretFromFile(path) + if err != nil { + return nil, fmt.Errorf("failed to get secret from file: %w", err) } + // Add the secret path with a "|" separator character + // to the secrets slice along with the secret + // so later we can match it to the environment key + secrets = append(secrets, paths[i]+"|"+secret) } return secrets, nil } -func (provider *Provider) getSecretsFromFile() (map[string]string, error) { - data, err := os.ReadFile(provider.secretsFilePath) +func isFileSystemEmpty(fsys fs.FS) (bool, error) { + dir, err := fs.ReadDir(fsys, ".") + fmt.Println(dir, err) if err != nil { + return false, err + } - return nil, fmt.Errorf("failed to read secrets file: %w", err) + for _, entry := range dir { + if entry.IsDir() || entry.Type().IsRegular() { + return false, nil + } } - secretsMap := make(map[string]string) - err = yaml.Unmarshal(data, &secretsMap) + return true, nil +} + +func (provider *Provider) getSecretFromFile(path string) (string, error) { + content, err := provider.readFile(path) if err != nil { + return "", err + } + + return string(content), nil +} - return nil, fmt.Errorf("failed to unmarshal YAML: %w", err) +func (provider *Provider) readFile(path string) ([]byte, error) { + content, err := fs.ReadFile(provider.fs, path) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) } - return secretsMap, nil + return content, nil } diff --git a/provider/file/file_test.go b/provider/file/file_test.go index 3b87d7d..619b7aa 100644 --- a/provider/file/file_test.go +++ b/provider/file/file_test.go @@ -16,152 +16,101 @@ package file import ( "context" - "os" - "strings" + "io/fs" "testing" -) - -func TestMain(m *testing.M) { - exitCode := m.Run() - - // teardown environment variables after tests are done - teardownEnvs() - - os.Exit(exitCode) -} - -func TestNewFileProvider(t *testing.T) { - // create a new secret file and write secrets into it - tmpfile := createTempFileWithContent(t) - defer os.Remove(tmpfile.Name()) - - // create new environment variables - setupEnvs(t, tmpfile) - - fileProvider, err := NewProvider(os.Getenv("SECRETS_FILE_PATH")) - if err != nil { - t.Fatal(err) - } - - // check if file provider is correctly created - _, ok := fileProvider.(*Provider) - if !ok { - t.Fatal("provider is not of type file") - } -} - -func TestFileLoadSecrets(t *testing.T) { - // create a new secret-file and write secrets into it - tmpfile := createTempFileWithContent(t) - defer os.Remove(tmpfile.Name()) - - // create new environment variables - // for file-path and secrets to get - setupEnvs(t, tmpfile) - - fileProvider, err := NewProvider(os.Getenv("SECRETS_FILE_PATH")) - if err != nil { - t.Fatal(err) - } - - environ := make(map[string]string, len(os.Environ())) - for _, env := range os.Environ() { - split := strings.SplitN(env, "=", 2) - name := split[0] - value := split[1] - environ[name] = value - } - - ctx := context.Background() - envs, err := fileProvider.LoadSecrets(ctx, environ) - if err != nil { - t.Fatal(err) - } - - test := []string{ - "MYSQL_PASSWORD=3xtr3ms3cr3t", - "AWS_SECRET_ACCESS_KEY=s3cr3t", - "AWS_ACCESS_KEY_ID=secretId", - } + "testing/fstest" - // check if secrets have been correctly loaded - areEqual(t, envs, test) -} - -func areEqual(t *testing.T, actual, expected []string) { - actualMap := make(map[string]string, len(expected)) - expectedMap := make(map[string]string, len(expected)) - - for _, env := range actual { - split := strings.SplitN(env, "=", 2) - key := split[0] - value := split[1] - actualMap[key] = value - } - - for _, env := range expected { - split := strings.SplitN(env, "=", 2) - key := split[0] - value := split[1] - expectedMap[key] = value - } - - for key, actualValue := range actualMap { - expectedValue, ok := expectedMap[key] - if !ok || actualValue != expectedValue { - t.Fatalf("mismatch for key %s: actual: %s, expected: %s", key, actualValue, expectedValue) - } - } -} - -func createTempFileWithContent(t *testing.T) *os.File { - content := []byte("sqlPassword: 3xtr3ms3cr3t\nawsSecretAccessKey: s3cr3t\nawsAccessKeyId: secretId\n") - tmpfile, err := os.CreateTemp("", "secrets-*.yaml") - if err != nil { - t.Fatal(err) - } - - _, err = tmpfile.Write(content) - if err != nil { - t.Fatal(err) - } + "github.com/stretchr/testify/assert" +) - err = tmpfile.Close() - if err != nil { - t.Fatal(err) +func TestNewProvider(t *testing.T) { + tests := []struct { + name string + fs fs.FS + wantErr bool + wantType bool + }{ + { + name: "Valid file system", + fs: fstest.MapFS{ + "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, + "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + }, + wantErr: false, + wantType: true, + }, + { + name: "Nil file system", + fs: nil, + wantErr: true, + wantType: false, + }, + { + name: "Empty file system", + fs: fstest.MapFS{}, + wantErr: true, + wantType: false, + }, + } + + for _, tt := range tests { + ttp := tt + t.Run(ttp.name, func(t *testing.T) { + + prov, err := NewProvider(ttp.fs) + if (err != nil) != ttp.wantErr { + t.Fatalf("NewProvider() error = %v, wantErr %v", err, ttp.wantErr) + return + } + // Use type assertion to check if the provider is of the correct type + _, ok := prov.(*Provider) + if ok != ttp.wantType { + t.Fatalf("NewProvider() = %v, wantType %v", ok, ttp.wantType) + } + }) } - - return tmpfile } -func setupEnvs(t *testing.T, tmpfile *os.File) { - err := os.Setenv("PROVIDER", "file") - if err != nil { - t.Fatal(err) - } - err = os.Setenv("SECRETS_FILE_PATH", tmpfile.Name()) - if err != nil { - t.Fatal(err) +func TestLoadSecrets(t *testing.T) { + tests := []struct { + name string + fs fs.FS + paths []string + wantErr bool + wantData []string + }{ + { + name: "Load secrets successfully", + fs: fstest.MapFS{ + "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, + "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + }, + paths: []string{"test/secrets/sqlpass.txt", "test/secrets/awsaccess.txt"}, + wantErr: false, + wantData: []string{"test/secrets/sqlpass.txt|3xtr3ms3cr3t", "test/secrets/awsaccess.txt|s3cr3t"}, + }, + { + name: "Fail to load secrets due to invalid path", + fs: fstest.MapFS{ + "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, + "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + }, + paths: []string{"test/secrets/mistake/sqlpass.txt", "test/secrets/mistake/awsaccess.txt"}, + wantErr: true, + wantData: nil, + }, + } + + for _, tt := range tests { + ttp := tt + t.Run(ttp.name, func(t *testing.T) { + provider, err := NewProvider(ttp.fs) + if assert.NoError(t, err, "Unexpected error") { + secrets, err := provider.LoadSecrets(context.Background(), ttp.paths) + assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") + + assert.Equal(t, ttp.wantData, secrets, "Unexpected secrets loaded") + } + }) } - - err = os.Setenv("MYSQL_PASSWORD", "file:sqlPassword") - if err != nil { - t.Fatal(err) - } - err = os.Setenv("AWS_SECRET_ACCESS_KEY", "file:awsSecretAccessKey") - if err != nil { - t.Fatal(err) - } - err = os.Setenv("AWS_ACCESS_KEY_ID", "file:awsAccessKeyId") - if err != nil { - t.Fatal(err) - } -} - -func teardownEnvs() { - os.Unsetenv("PROVIDER") - os.Unsetenv("SECRETS_FILE_PATH") - os.Unsetenv("MYSQL_PASSWORD") - os.Unsetenv("AWS_SECRET_ACCESS_KEY") - os.Unsetenv("AWS_ACCESS_KEY_ID") } diff --git a/provider/provider.go b/provider/provider.go index 09954bd..c690a0c 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -18,5 +18,5 @@ import "context" // Provider is an interface for securely loading secrets based on environment variables. type Provider interface { - LoadSecrets(ctx context.Context, envs map[string]string) ([]string, error) + LoadSecrets(ctx context.Context, paths []string) ([]string, error) } From 725664877abc85c64e09ef73fc3281cca6edb052 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Tue, 28 Nov 2023 21:46:36 +0100 Subject: [PATCH 10/10] fix: Factored out main logic, added secret struct, fixed affected parts Signed-off-by: Bence Csati --- env.go | 83 ++++++++++++++++++++++++++++++++++++++ main.go | 61 ++-------------------------- provider/file/file.go | 58 +++++++------------------- provider/file/file_test.go | 37 ++++++++++------- provider/provider.go | 8 +++- 5 files changed, 130 insertions(+), 117 deletions(-) create mode 100644 env.go diff --git a/env.go b/env.go new file mode 100644 index 0000000..606fa77 --- /dev/null +++ b/env.go @@ -0,0 +1,83 @@ +// Copyright © 2023 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + "os" + "strings" + + "github.com/bank-vaults/secret-init/provider" + "github.com/bank-vaults/secret-init/provider/file" +) + +func GetEnvironMap() map[string]string { + environ := make(map[string]string, len(os.Environ())) + for _, env := range os.Environ() { + split := strings.SplitN(env, "=", 2) + name := split[0] + value := split[1] + environ[name] = value + } + + return environ +} + +func ExtractPathsFromEnvs(envs map[string]string) []string { + var secretPaths []string + + for _, path := range envs { + if p, path := getProviderPath(path); p != nil { + secretPaths = append(secretPaths, path) + } + } + + return secretPaths +} + +func CreateSecretEnvsFrom(envs map[string]string, secrets []provider.Secret) ([]string, error) { + // Reverse the map so we can match + // the environment variable key to the secret + // by using the secret path + reversedEnvs := make(map[string]string) + for envKey, path := range envs { + if p, path := getProviderPath(path); p != nil { + reversedEnvs[path] = envKey + } + } + + var secretsEnv []string + for _, secret := range secrets { + path := secret.Path + value := secret.Value + key, ok := reversedEnvs[path] + if !ok { + return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", path) + } + secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", key, value)) + } + + return secretsEnv, nil +} + +// Returns the detected provider name and path with removed prefix +func getProviderPath(path string) (*string, string) { + if strings.HasPrefix(path, "file:") { + var fileProviderName = file.ProviderName + return &fileProviderName, strings.TrimPrefix(path, "file:") + } + + return nil, path +} diff --git a/main.go b/main.go index 2a8bcd3..5c42d0d 100644 --- a/main.go +++ b/main.go @@ -24,7 +24,6 @@ import ( "os/exec" "os/signal" "slices" - "strings" "syscall" "time" @@ -39,7 +38,7 @@ import ( func NewProvider(providerName string) (provider.Provider, error) { switch providerName { case file.ProviderName: - provider, err := file.NewProvider(os.DirFS("/secrets")) + provider, err := file.NewProvider(os.DirFS("/")) if err != nil { return nil, err } @@ -51,59 +50,6 @@ func NewProvider(providerName string) (provider.Provider, error) { } } -func CreateMapOfEnvs() map[string]string { - environ := make(map[string]string, len(os.Environ())) - for _, env := range os.Environ() { - split := strings.SplitN(env, "=", 2) - name := split[0] - value := split[1] - environ[name] = value - } - - return environ -} - -func ExtractPathsFromEnvs(envs map[string]string) []string { - var secretPaths []string - - for _, path := range envs { - if strings.HasPrefix(path, "file:") { - path = strings.TrimPrefix(path, "file://") - secretPaths = append(secretPaths, path) - } - } - - return secretPaths -} - -func CreateEnvsFromLoadedSecrets(envs map[string]string, secrets []string) ([]string, error) { - // Reverse the map so we can match - // the environment variable key to the secret - // by using the secret path - reversedEnvs := make(map[string]string) - for envKey, path := range envs { - if strings.HasPrefix(path, "file:") { - path = strings.TrimPrefix(path, "file://") - reversedEnvs[path] = envKey - } - } - - var secretsEnv []string - for _, secret := range secrets { - split := strings.SplitN(secret, "|", 2) - secretPath := split[0] - - secretValue := split[1] - secretKey, ok := reversedEnvs[secretPath] - if !ok { - return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", secretPath) - } - secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", secretKey, secretValue)) - } - - return secretsEnv, nil -} - func main() { var logger *slog.Logger { @@ -189,7 +135,7 @@ func main() { os.Exit(1) } - environ := CreateMapOfEnvs() + environ := GetEnvironMap() paths := ExtractPathsFromEnvs(environ) ctx := context.Background() @@ -199,8 +145,7 @@ func main() { os.Exit(1) } - - secretsEnv, err := CreateEnvsFromLoadedSecrets(environ, secrets) + secretsEnv, err := CreateSecretEnvsFrom(environ, secrets) if err != nil { logger.Error(fmt.Errorf("failed to create environment variables from loaded secrets: %w", err).Error()) diff --git a/provider/file/file.go b/provider/file/file.go index f594368..ba6a080 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "io/fs" + "strings" "github.com/bank-vaults/secret-init/provider" ) @@ -33,64 +34,33 @@ func NewProvider(fs fs.FS) (provider.Provider, error) { return nil, fmt.Errorf("file system is nil") } - isEmpty, err := isFileSystemEmpty(fs) - if err != nil { - return nil, fmt.Errorf("failed to check if file system is empty: %w", err) - } - if isEmpty { - return nil, fmt.Errorf("file system is empty") - } - return &Provider{fs: fs}, nil } -func (provider *Provider) LoadSecrets(_ context.Context, paths []string) ([]string, error) { - var secrets []string +func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Secret, error) { + var secrets []provider.Secret - for i, path := range paths { - secret, err := provider.getSecretFromFile(path) + for _, path := range paths { + secret, err := p.getSecretFromFile(path) if err != nil { return nil, fmt.Errorf("failed to get secret from file: %w", err) } - // Add the secret path with a "|" separator character - // to the secrets slice along with the secret - // so later we can match it to the environment key - secrets = append(secrets, paths[i]+"|"+secret) - } - - return secrets, nil -} - -func isFileSystemEmpty(fsys fs.FS) (bool, error) { - dir, err := fs.ReadDir(fsys, ".") - fmt.Println(dir, err) - if err != nil { - return false, err - } - for _, entry := range dir { - if entry.IsDir() || entry.Type().IsRegular() { - return false, nil - } + secrets = append(secrets, provider.Secret{ + Path: path, + Value: secret, + }) } - return true, nil + return secrets, nil } -func (provider *Provider) getSecretFromFile(path string) (string, error) { - content, err := provider.readFile(path) +func (p *Provider) getSecretFromFile(filepath string) (string, error) { + filepath = strings.TrimLeft(filepath, "/") + content, err := fs.ReadFile(p.fs, filepath) if err != nil { - return "", err + return "", fmt.Errorf("failed to read file: %w", err) } return string(content), nil } - -func (provider *Provider) readFile(path string) ([]byte, error) { - content, err := fs.ReadFile(provider.fs, path) - if err != nil { - return nil, fmt.Errorf("failed to read file: %w", err) - } - - return content, nil -} diff --git a/provider/file/file_test.go b/provider/file/file_test.go index 619b7aa..50ddddb 100644 --- a/provider/file/file_test.go +++ b/provider/file/file_test.go @@ -21,6 +21,8 @@ import ( "testing/fstest" "github.com/stretchr/testify/assert" + + "github.com/bank-vaults/secret-init/provider" ) func TestNewProvider(t *testing.T) { @@ -35,6 +37,7 @@ func TestNewProvider(t *testing.T) { fs: fstest.MapFS{ "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, }, wantErr: false, wantType: true, @@ -45,18 +48,11 @@ func TestNewProvider(t *testing.T) { wantErr: true, wantType: false, }, - { - name: "Empty file system", - fs: fstest.MapFS{}, - wantErr: true, - wantType: false, - }, } for _, tt := range tests { ttp := tt t.Run(ttp.name, func(t *testing.T) { - prov, err := NewProvider(ttp.fs) if (err != nil) != ttp.wantErr { t.Fatalf("NewProvider() error = %v, wantErr %v", err, ttp.wantErr) @@ -77,25 +73,39 @@ func TestLoadSecrets(t *testing.T) { fs fs.FS paths []string wantErr bool - wantData []string + wantData []provider.Secret }{ { name: "Load secrets successfully", fs: fstest.MapFS{ "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, + }, + paths: []string{ + "test/secrets/sqlpass.txt", + "test/secrets/awsaccess.txt", + "test/secrets/awsid.txt", + }, + wantErr: false, + wantData: []provider.Secret{ + {Path: "test/secrets/sqlpass.txt", Value: "3xtr3ms3cr3t"}, + {Path: "test/secrets/awsaccess.txt", Value: "s3cr3t"}, + {Path: "test/secrets/awsid.txt", Value: "secretId"}, }, - paths: []string{"test/secrets/sqlpass.txt", "test/secrets/awsaccess.txt"}, - wantErr: false, - wantData: []string{"test/secrets/sqlpass.txt|3xtr3ms3cr3t", "test/secrets/awsaccess.txt|s3cr3t"}, }, { name: "Fail to load secrets due to invalid path", fs: fstest.MapFS{ "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, + }, + paths: []string{ + "test/secrets/mistake/sqlpass.txt", + "test/secrets/mistake/awsaccess.txt", + "test/secrets/mistake/awsid.txt", }, - paths: []string{"test/secrets/mistake/sqlpass.txt", "test/secrets/mistake/awsaccess.txt"}, wantErr: true, wantData: nil, }, @@ -108,8 +118,7 @@ func TestLoadSecrets(t *testing.T) { if assert.NoError(t, err, "Unexpected error") { secrets, err := provider.LoadSecrets(context.Background(), ttp.paths) assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") - - assert.Equal(t, ttp.wantData, secrets, "Unexpected secrets loaded") + assert.ElementsMatch(t, ttp.wantData, secrets, "Unexpected secrets loaded") } }) } diff --git a/provider/provider.go b/provider/provider.go index c690a0c..14e07fc 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -18,5 +18,11 @@ import "context" // Provider is an interface for securely loading secrets based on environment variables. type Provider interface { - LoadSecrets(ctx context.Context, paths []string) ([]string, error) + LoadSecrets(ctx context.Context, paths []string) ([]Secret, error) +} + +// Secret holds Provider-specific secret data. +type Secret struct { + Path string + Value string }