Skip to content

Commit

Permalink
feat(vault provider tests): Add more tests, and minor fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Bence Csati <[email protected]>
  • Loading branch information
csatib02 committed Dec 21, 2023
1 parent 46216ef commit 4f7ae51
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 52 deletions.
6 changes: 3 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@ import (
func NewProvider(providerName string, logger *slog.Logger, sigs chan os.Signal) (provider.Provider, error) {
switch providerName {
case file.ProviderName:
config := file.NewConfig(logger)
config := file.NewConfig()
provider, err := file.NewProvider(config)
if err != nil {
return nil, fmt.Errorf("failed to create file provider: %w", err)
}

return provider, nil
case vault.ProviderName:
config, err := vault.NewConfig(logger, sigs)
config, err := vault.NewConfig()
if err != nil {
return nil, fmt.Errorf("failed to create vault config: %w", err)
}

provider, err := vault.NewProvider(config)
provider, err := vault.NewProvider(config, logger, sigs)
if err != nil {
return nil, fmt.Errorf("failed to create vault provider: %w", err)
}
Expand Down
7 changes: 4 additions & 3 deletions provider/file/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
package file

import (
"log/slog"
"fmt"
"os"
)

Expand All @@ -28,10 +28,11 @@ type Config struct {
MountPath string `json:"mountPath"`
}

func NewConfig(logger *slog.Logger) *Config {
func NewConfig() *Config {
mountPath, ok := os.LookupEnv(EnvPrefix + "MOUNT_PATH")
if !ok {
logger.Warn("Mount path not provided. Using default.", slog.String("Default Mount Path", DefaultMountPath))
fmt.Printf("Mount path not provided. Using default. Default Mount Path: %s\n", DefaultMountPath)

mountPath = DefaultMountPath
}

Expand Down
3 changes: 1 addition & 2 deletions provider/file/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package file

import (
"log/slog"
"os"
"testing"

Expand Down Expand Up @@ -48,7 +47,7 @@ func TestConfig(t *testing.T) {
for envKey, envVal := range ttp.env {
os.Setenv(envKey, envVal)
}
config := NewConfig(slog.Default())
config := NewConfig()

assert.Equal(t, ttp.wantMountPath, config.MountPath, "Unexpected mount path")
})
Expand Down
16 changes: 8 additions & 8 deletions provider/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ func TestNewProvider(t *testing.T) {

func TestLoadSecrets(t *testing.T) {
tests := []struct {
name string
paths []string
wantErr bool
wantData []provider.Secret
name string
paths []string
wantErr bool
wantSecrets []provider.Secret
}{
{
name: "Load secrets successfully",
Expand All @@ -67,7 +67,7 @@ func TestLoadSecrets(t *testing.T) {
"test/secrets/awsid.txt",
},
wantErr: false,
wantData: []provider.Secret{
wantSecrets: []provider.Secret{
{Path: "test/secrets/sqlpass.txt", Value: "3xtr3ms3cr3t"},
{Path: "test/secrets/awsaccess.txt", Value: "s3cr3t"},
{Path: "test/secrets/awsid.txt", Value: "secretId"},
Expand All @@ -80,8 +80,8 @@ func TestLoadSecrets(t *testing.T) {
"test/secrets/mistake/awsaccess.txt",
"test/secrets/mistake/awsid.txt",
},
wantErr: true,
wantData: nil,
wantErr: true,
wantSecrets: nil,
},
}

Expand All @@ -97,7 +97,7 @@ func TestLoadSecrets(t *testing.T) {
secrets, err := provider.LoadSecrets(context.Background(), ttp.paths)

assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")
assert.Equal(t, ttp.wantData, secrets, "Unexpected secrets")
assert.Equal(t, ttp.wantSecrets, secrets, "Unexpected secrets")
})
}
}
23 changes: 8 additions & 15 deletions provider/vault/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
package vault

import (
"errors"
"log/slog"
"fmt"
"os"
"strings"

"github.com/spf13/cast"
)

const (
EnvPrefix = "VAULT_"
EnvPrefix = "VAULT_"
SecretInitDaemonEnv = "SECRET_INIT_DAEMON"
// The special value for SECRET_INIT which marks that the login token needs to be passed through to the application
// which was acquired during the vault client initialization.
vaultLogin = "vault:login"
Expand All @@ -44,8 +44,6 @@ type Config struct {
IgnoreMissingSecrets bool `json:"ignoreMissingSecrets"`
FromPath string `json:"fromPath"`
RevokeToken bool `json:"revokeToken"`
Logger *slog.Logger
Sigs chan os.Signal
}

type envType struct {
Expand Down Expand Up @@ -85,7 +83,7 @@ var sanitizeEnvmap = map[string]envType{
"SECRET_INIT_DAEMON": {login: false},
}

func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) {
func NewConfig() (*Config, error) {
var (
role, authPath, authMethod string
hasRole, hasPath, hasAuthMethod bool
Expand All @@ -102,9 +100,7 @@ func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) {
if b, err := os.ReadFile(tokenFile); err == nil {
vaultToken = string(b)
} else {
logger.Error("could not read vault token file", slog.String("file", tokenFile))

return nil, err
return nil, fmt.Errorf("failed to read token file: %w", err)
}
} else {
if isLogin {
Expand All @@ -115,9 +111,8 @@ func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) {
authPath, hasPath = os.LookupEnv(EnvPrefix + "PATH")
authMethod, hasAuthMethod = os.LookupEnv(EnvPrefix + "AUTH_METHOD")
if !hasRole || !hasPath || !hasAuthMethod {
logger.Error("Incomplete authentication configuration. Make sure VAULT_ROLE, VAULT_PATH, and VAULT_AUTH_METHOD are set.")

return nil, errors.New("incomplete authentication configuration")
return nil, fmt.Errorf("incomplete authentication configuration %s, %s, and %s",
"VAULT_ROLE", "VAULT_PATH", "VAULT_AUTH_METHOD")
}
}

Expand All @@ -138,7 +133,7 @@ func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) {
transitKeyID := os.Getenv(EnvPrefix + "TRANSIT_KEY_ID")
transitPath := os.Getenv(EnvPrefix + "TRANSIT_PATH")
transitBatchSize := cast.ToInt(os.Getenv(EnvPrefix + "TRANSIT_BATCH_SIZE"))
daemonMode := cast.ToBool(os.Getenv("SECRET_INIT_DAEMON_MODE"))
daemonMode := cast.ToBool(os.Getenv(SecretInitDaemonEnv))
// Used both for reading secrets and transit encryption
ignoreMissingSecrets := cast.ToBool(os.Getenv(EnvPrefix + "IGNORE_MISSING_SECRETS"))

Expand All @@ -159,7 +154,5 @@ func NewConfig(logger *slog.Logger, sigs chan os.Signal) (*Config, error) {
IgnoreMissingSecrets: ignoreMissingSecrets,
FromPath: fromPath,
RevokeToken: revokeToken,
Logger: logger,
Sigs: sigs,
}, nil
}
38 changes: 23 additions & 15 deletions provider/vault/config_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package vault

import (
"log/slog"
"os"
"path/filepath"
"testing"
Expand All @@ -10,9 +9,6 @@ import (
)

func TestConfig(t *testing.T) {
// mock logger, sigs, and tokenfile
logger := slog.Default()
sigs := make(chan os.Signal, 1)
tokenFile := newTokenFile(t)
defer os.Remove(tokenFile)

Expand All @@ -25,15 +21,28 @@ func TestConfig(t *testing.T) {
{
name: "Valid login configuration with Token",
env: map[string]string{
"VAULT_TOKEN": vaultLogin,
"VAULT_TOKEN_FILE": tokenFile,
"VAULT_TOKEN": vaultLogin,
"VAULT_TOKEN_FILE": tokenFile,
"VAULT_PASSTHROUGH": "VAULT_AGENT_ADDR, VAULT_CLI_NO_COLOR",
"VAULT_TRANSIT_KEY_ID": "test-key",
"VAULT_TRANSIT_PATH": "transit",
"VAULT_TRANSIT_BATCH_SIZE": "10",
"SECRET_INIT_DAEMON": "true",
"VAULT_IGNORE_MISSING_SECRETS": "true",
"VAULT_REVOKE_TOKEN": "true",
"VAULT_FROM_PATH": "secret/data/test",
},
wantConfig: &Config{
Islogin: true,
Token: "root",
TokenFile: tokenFile,
Logger: logger,
Sigs: sigs,
Islogin: true,
Token: "root",
TokenFile: tokenFile,
TransitKeyID: "test-key",
TransitPath: "transit",
TransitBatchSize: 10,
DaemonMode: true,
IgnoreMissingSecrets: true,
FromPath: "secret/data/test",
RevokeToken: true,
},
wantErr: false,
},
Expand All @@ -51,8 +60,6 @@ func TestConfig(t *testing.T) {
Role: "test-app-role",
AuthPath: "auth/approle/test/login",
AuthMethod: "test-approle",
Logger: logger,
Sigs: sigs,
},
wantErr: false,
},
Expand All @@ -65,7 +72,7 @@ func TestConfig(t *testing.T) {
wantErr: true,
},
{
name: "Invalid login configuration missing role/path credentials",
name: "Invalid login configuration missing role - path credentials",
env: map[string]string{
"VAULT_PATH": "auth/approle/test/login",
"VAULT_AUTH_METHOD": "test-approle",
Expand All @@ -82,7 +89,7 @@ func TestConfig(t *testing.T) {
os.Setenv(envKey, envVal)
}

config, err := NewConfig(logger, sigs)
config, err := NewConfig()

assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")
assert.Equal(t, ttp.wantConfig, config, "Unexpected config")
Expand All @@ -106,5 +113,6 @@ func newTokenFile(t *testing.T) string {
if err != nil {
t.Fatalf("Failed to write to a temporary token file: %v", err)
}

return tokenFile.Name()
}
13 changes: 7 additions & 6 deletions provider/vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"log/slog"
"os"
"strings"

"github.com/bank-vaults/internal/injector"
Expand Down Expand Up @@ -57,8 +58,8 @@ func (s *sanitized) append(key string, value string) {
}
}

func NewProvider(config *Config) (provider.Provider, error) {
clientOptions := []vault.ClientOption{vault.ClientLogger(clientLogger{config.Logger})}
func NewProvider(config *Config, logger *slog.Logger, sigs chan os.Signal) (provider.Provider, error) {
clientOptions := []vault.ClientOption{vault.ClientLogger(clientLogger{logger})}
if config.TokenFile != "" {
clientOptions = append(clientOptions, vault.ClientToken(config.Token))
} else {
Expand All @@ -72,7 +73,7 @@ func NewProvider(config *Config) (provider.Provider, error) {

client, err := vault.NewClientWithOptions(clientOptions...)
if err != nil {
config.Logger.Error(fmt.Errorf("failed to create vault client: %w", err).Error())
logger.Error(fmt.Errorf("failed to create vault client: %w", err).Error())

return nil, err
}
Expand All @@ -88,8 +89,8 @@ func NewProvider(config *Config) (provider.Provider, error) {
var secretRenewer injector.SecretRenewer

if config.DaemonMode {
secretRenewer = daemonSecretRenewer{client: client, sigs: config.Sigs, logger: config.Logger}
config.Logger.Info("Daemon mode enabled. Will renew secrets in the background.")
secretRenewer = daemonSecretRenewer{client: client, sigs: sigs, logger: logger}
logger.Info("Daemon mode enabled. Will renew secrets in the background.")
}

return &Provider{
Expand All @@ -99,7 +100,7 @@ func NewProvider(config *Config) (provider.Provider, error) {
secretRenewer: secretRenewer,
fromPath: config.FromPath,
revokeToken: config.RevokeToken,
logger: config.Logger,
logger: logger,
}, nil
}

Expand Down
64 changes: 64 additions & 0 deletions provider/vault/vault_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package vault

import (
"bytes"
"log/slog"
"os"
"testing"

"github.com/bank-vaults/internal/injector"
"github.com/stretchr/testify/assert"
)

func TestNewProvider(t *testing.T) {
tests := []struct {
name string
config *Config
wantInjectorConfig injector.Config
wantErr bool
wantType bool
}{
{
name: "Valid Provider with Token",
config: &Config{
Islogin: true,
TokenFile: "root",
Token: "root",
TransitKeyID: "test-key",
TransitPath: "transit",
TransitBatchSize: 10,
DaemonMode: true,
IgnoreMissingSecrets: true,
FromPath: "secret/data/test",
RevokeToken: true,
},
wantErr: false,
wantType: true,
},
{
name: "Fail to create vault client",
config: &Config{},
wantErr: true,
wantType: false,
},
}

for _, tt := range tests {
ttp := tt

// Redirect logs to avoid polluting the test output
var buf bytes.Buffer
handler := slog.NewTextHandler(&buf, nil)
logger := slog.New(handler)

t.Run(ttp.name, func(t *testing.T) {
provider, err := NewProvider(ttp.config, logger, make(chan os.Signal))

assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")
assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type")
})

buf.Truncate(0)
}

}

0 comments on commit 4f7ae51

Please sign in to comment.