Skip to content

Commit

Permalink
feat(file_test): adjust tests
Browse files Browse the repository at this point in the history
Signed-off-by: Bence Csati <[email protected]>
  • Loading branch information
csatib02 committed Nov 11, 2023
1 parent 6a3f369 commit 370dae2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 24 deletions.
1 change: 0 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ func main() {
}

environ := make(map[string]string, len(os.Environ()))

for _, env := range os.Environ() {
split := strings.SplitN(env, "=", 2)
name := split[0]
Expand Down
4 changes: 1 addition & 3 deletions provider/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ func NewFileProvider(secretsFilePath string) (provider.Provider, error) {
}

func (provider *Provider) LoadSecrets(_ context.Context, envs map[string]string) ([]string, error) {

//envs that has a value with "file:" prefix needs to be loaded
// envs that has a "file:" prefix needs to be loaded
var secrets []string
for key, value := range envs {
if strings.HasPrefix(value, "file:") {
Expand All @@ -58,7 +57,6 @@ func (provider *Provider) getSecretFromFile(key string) (string, error) {
lines := strings.Split(string(provider.SecretData), "\n")
for _, line := range lines {
split := strings.SplitN(line, "=", 2)

if split[0] == key {
return split[1], nil
}
Expand Down
72 changes: 52 additions & 20 deletions provider/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package file

import (
"bytes"
"context"
"os"
"strings"
Expand All @@ -23,51 +24,73 @@ import (
"github.com/bank-vaults/secret-init/provider"
)

func TestMain(m *testing.M) {
exitCode := m.Run()

// teardown environment variables after tests are done
teardownEnvs()

os.Exit(exitCode)
}

func TestNewFileProvider(t *testing.T) {
//create a new secret-file and write secrets into it
// create a new secret file and write secrets into it
tmpfile := createTempFileWithContent(t)

defer os.Remove(tmpfile.Name())

//create new environment variables
//for file-path and secrets to get
// create new environment variables
// for file-path and secrets to get
setupEnvs(t, tmpfile)

var fileProvider provider.Provider
providerName := os.Getenv("PROVIDER")
if providerName == "file" {
_, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH"))
switch providerName {
case "file":
newFileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH"))
fileProvider = newFileProvider
if err != nil {
t.Fatal(err)
}
} else {
default:
t.Fatalf("invalid provider specified: %s", providerName)
}

expectedSecretData, err := os.ReadFile(tmpfile.Name())
if err != nil {
t.Fatal(err)
}

// access the provider implementation to get secretdata
// then check if file provider is correctly created
// and file-path is read
if fileProvider, ok := fileProvider.(*Provider); !ok || !bytes.Equal(expectedSecretData, fileProvider.SecretData) {
t.Fatal("failed to create file provider")
}
}

func TestFileLoadSecrets(t *testing.T) {
//create a new secret-file and write secrets into it
// create a new secret-file and write secrets into it
tmpfile := createTempFileWithContent(t)

defer os.Remove(tmpfile.Name())

//create new environment variables
//for file-path and secrets to get
// create new environment variables
// for file-path and secrets to get
setupEnvs(t, tmpfile)

var provider provider.Provider
var fileProvider provider.Provider
providerName := os.Getenv("PROVIDER")
if providerName == "file" {
newProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH"))
switch providerName {
case "file":
newFileProvider, err := NewFileProvider(os.Getenv("SECRETS_FILE_PATH"))
fileProvider = newFileProvider
if err != nil {
t.Fatal(err)
}
provider = newProvider
} else {
default:
t.Fatalf("invalid provider specified: %s", providerName)
}

environ := make(map[string]string, len(os.Environ()))

for _, env := range os.Environ() {
split := strings.SplitN(env, "=", 2)
name := split[0]
Expand All @@ -76,7 +99,7 @@ func TestFileLoadSecrets(t *testing.T) {
}

ctx := context.Background()
envs, err := provider.LoadSecrets(ctx, environ)
envs, err := fileProvider.LoadSecrets(ctx, environ)
if err != nil {
t.Fatal(err)
}
Expand All @@ -86,7 +109,8 @@ func TestFileLoadSecrets(t *testing.T) {
"AWS_SECRET_ACCESS_KEY=s3cr3t",
"AWS_ACCESS_KEY_ID=secretId",
}
//check if secrets have been correctly loaded

// check if secrets have been correctly loaded
areEqual(t, envs, test)
}

Expand All @@ -111,7 +135,7 @@ func areEqual(t *testing.T, actual, expected []string) {
for key, actualValue := range actualMap {
expectedValue, ok := expectedMap[key]
if !ok || actualValue != expectedValue {
t.Fatalf("Mismatch for key %s: actual: %s, expected: %s", key, actualValue, expectedValue)
t.Fatalf("mismatch for key %s: actual: %s, expected: %s", key, actualValue, expectedValue)
}
}
}
Expand Down Expand Up @@ -159,3 +183,11 @@ func setupEnvs(t *testing.T, tmpfile *os.File) {
t.Fatal(err)
}
}

func teardownEnvs() {
os.Unsetenv("PROVIDER")
os.Unsetenv("SECRETS_FILE_PATH")
os.Unsetenv("MYSQL_PASSWORD")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Unsetenv("AWS_ACCESS_KEY_ID")
}

0 comments on commit 370dae2

Please sign in to comment.