From 370dae2e84aabdf4b6a4c9f14c9783eefe94a841 Mon Sep 17 00:00:00 2001 From: Bence Csati Date: Sat, 11 Nov 2023 11:33:47 +0100 Subject: [PATCH] feat(file_test): adjust tests Signed-off-by: Bence Csati --- main.go | 1 - provider/file/file.go | 4 +-- provider/file/file_test.go | 72 +++++++++++++++++++++++++++----------- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/main.go b/main.go index f698398..d9ce1d3 100644 --- a/main.go +++ b/main.go @@ -132,7 +132,6 @@ func main() { } environ := make(map[string]string, len(os.Environ())) - for _, env := range os.Environ() { split := strings.SplitN(env, "=", 2) name := split[0] diff --git a/provider/file/file.go b/provider/file/file.go index 52818be..636285f 100644 --- a/provider/file/file.go +++ b/provider/file/file.go @@ -38,8 +38,7 @@ func NewFileProvider(secretsFilePath string) (provider.Provider, error) { } func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) ([]string, error) { - - //envs that has a value with "file:" prefix needs to be loaded + // envs that has a "file:" prefix needs to be loaded var secrets []string for key, value := range envs { if strings.HasPrefix(value, "file:") { @@ -58,7 +57,6 @@ func (provider *Provider) getSecretFromFile(key string) (string, error) { lines := strings.Split(string(provider.SecretData), "\n") for _, line := range lines { split := strings.SplitN(line, "=", 2) - if split[0] == key { return split[1], nil } diff --git a/provider/file/file_test.go b/provider/file/file_test.go index 1e340cb..18d60ef 100644 --- a/provider/file/file_test.go +++ b/provider/file/file_test.go @@ -15,6 +15,7 @@ package file import ( + "bytes" "context" "os" "strings" @@ -23,51 +24,73 @@ import ( "github.com/bank-vaults/secret-init/provider" ) +func TestMain(m *testing.M) { + exitCode := m.Run() + + // teardown environment variables after tests are done + teardownEnvs() + + os.Exit(exitCode) +} + func TestNewFileProvider(t *testing.T) { - //create a new secret-file and write secrets into it + // create a new secret file and write secrets into it tmpfile := createTempFileWithContent(t) - defer os.Remove(tmpfile.Name()) - //create new environment variables - //for file-path and secrets to get + // create new environment variables + // for file-path and secrets to get setupEnvs(t, tmpfile) + var fileProvider provider.Provider providerName := os.Getenv("PROVIDER") - if providerName == "file" { - _, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + switch providerName { + case "file": + newFileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + fileProvider = newFileProvider if err != nil { t.Fatal(err) } - } else { + default: t.Fatalf("invalid provider specified: %s", providerName) } + + expectedSecretData, err := os.ReadFile(tmpfile.Name()) + if err != nil { + t.Fatal(err) + } + + // access the provider implementation to get secretdata + // then check if file provider is correctly created + // and file-path is read + if fileProvider, ok := fileProvider.(*Provider); !ok || !bytes.Equal(expectedSecretData, fileProvider.SecretData) { + t.Fatal("failed to create file provider") + } } func TestFileLoadSecrets(t *testing.T) { - //create a new secret-file and write secrets into it + // create a new secret-file and write secrets into it tmpfile := createTempFileWithContent(t) - defer os.Remove(tmpfile.Name()) - //create new environment variables - //for file-path and secrets to get + // create new environment variables + // for file-path and secrets to get setupEnvs(t, tmpfile) - var provider provider.Provider + var fileProvider provider.Provider providerName := os.Getenv("PROVIDER") - if providerName == "file" { - newProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + switch providerName { + case "file": + newFileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH")) + fileProvider = newFileProvider if err != nil { t.Fatal(err) } - provider = newProvider - } else { + default: t.Fatalf("invalid provider specified: %s", providerName) } environ := make(map[string]string, len(os.Environ())) - for _, env := range os.Environ() { split := strings.SplitN(env, "=", 2) name := split[0] @@ -76,7 +99,7 @@ func TestFileLoadSecrets(t *testing.T) { } ctx := context.Background() - envs, err := provider.LoadSecrets(ctx, environ) + envs, err := fileProvider.LoadSecrets(ctx, environ) if err != nil { t.Fatal(err) } @@ -86,7 +109,8 @@ func TestFileLoadSecrets(t *testing.T) { "AWS_SECRET_ACCESS_KEY=s3cr3t", "AWS_ACCESS_KEY_ID=secretId", } - //check if secrets have been correctly loaded + + // check if secrets have been correctly loaded areEqual(t, envs, test) } @@ -111,7 +135,7 @@ func areEqual(t *testing.T, actual, expected []string) { for key, actualValue := range actualMap { expectedValue, ok := expectedMap[key] if !ok || actualValue != expectedValue { - t.Fatalf("Mismatch for key %s: actual: %s, expected: %s", key, actualValue, expectedValue) + t.Fatalf("mismatch for key %s: actual: %s, expected: %s", key, actualValue, expectedValue) } } } @@ -159,3 +183,11 @@ func setupEnvs(t *testing.T, tmpfile *os.File) { t.Fatal(err) } } + +func teardownEnvs() { + os.Unsetenv("PROVIDER") + os.Unsetenv("SECRETS_FILE_PATH") + os.Unsetenv("MYSQL_PASSWORD") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + os.Unsetenv("AWS_ACCESS_KEY_ID") +}