Skip to content

Commit

Permalink
feat(file): Add config, adjust tests
Browse files Browse the repository at this point in the history
Signed-off-by: Bence Csati <[email protected]>
  • Loading branch information
csatib02 committed Dec 15, 2023
1 parent b2bc5c6 commit 37a4c51
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 52 deletions.
13 changes: 7 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ import (
func NewProvider(providerName string) (provider.Provider, error) {
switch providerName {
case file.ProviderName:
provider, err := file.NewProvider(os.DirFS("/"))
config := file.NewConfig()
provider, err := file.NewProvider(config)
if err != nil {
return nil, err
}
Expand All @@ -55,7 +56,7 @@ func main() {
{
var level slog.Level

err := level.UnmarshalText([]byte(os.Getenv("VAULT_LOG_LEVEL")))
err := level.UnmarshalText([]byte(os.Getenv("SECRET_INIT_LOG_LEVEL")))
if err != nil { // Silently fall back to info level
level = slog.LevelInfo
}
Expand All @@ -68,7 +69,7 @@ func main() {

router := slogmulti.Router()

if cast.ToBool(os.Getenv("VAULT_JSON_LOG")) {
if cast.ToBool(os.Getenv("SECRET_INIT_JSON_LOG")) {
// Send logs with level higher than warning to stderr
router = router.Add(
slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn}),
Expand All @@ -94,7 +95,7 @@ func main() {
)
}

if logServerAddr := os.Getenv("VAULT_ENV_LOG_SERVER"); logServerAddr != "" {
if logServerAddr := os.Getenv("SECRET_INIT_LOG_SERVER"); logServerAddr != "" {
writer, err := net.Dial("udp", logServerAddr)

// We silently ignore syslog connection errors for the lack of a better solution
Expand Down Expand Up @@ -123,8 +124,8 @@ func main() {
os.Exit(1)
}

daemonMode := cast.ToBool(os.Getenv("VAULT_ENV_DAEMON"))
delayExec := cast.ToDuration(os.Getenv("VAULT_ENV_DELAY"))
daemonMode := cast.ToBool(os.Getenv("SECRET_INIT_DAEMON"))
delayExec := cast.ToDuration(os.Getenv("SECRET_INIT_DELAY"))

entrypointCmd := os.Args[1:]

Expand Down
21 changes: 21 additions & 0 deletions provider/file/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package file

import "os"

const (
EnvPrefix = "FILE_"
DefaultMountPath = "/"
)

type Config struct {
MountPath string `json:"mountPath"`
}

func NewConfig() *Config {
mountPath, ok := os.LookupEnv(EnvPrefix + "MOUNT_PATH")
if !ok {
mountPath = DefaultMountPath
}

return &Config{MountPath: mountPath}
}
41 changes: 41 additions & 0 deletions provider/file/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package file

import (
"os"
"testing"
)

func TestConfig(t *testing.T) {
tests := []struct {
name string
env map[string]string
wantMountPath string
}{
{
name: "Default mount path",
env: map[string]string{},
wantMountPath: "/",
},
{
name: "Custom mount path",
env: map[string]string{
"FILE_MOUNT_PATH": "test/secrets",
},
wantMountPath: "test/secrets",
},
}

for _, tt := range tests {
ttp := tt
t.Run(ttp.name, func(t *testing.T) {
for envKey, envVal := range ttp.env {
os.Setenv(envKey, envVal)
}

config := NewConfig()
if config.MountPath != ttp.wantMountPath {
t.Errorf("NewConfig() = %v, wantMountPath %v", config.MountPath, ttp.wantMountPath)
}
})
}
}
14 changes: 7 additions & 7 deletions provider/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"io/fs"
"os"
"strings"

"github.com/bank-vaults/secret-init/provider"
Expand All @@ -29,10 +30,8 @@ type Provider struct {
fs fs.FS
}

func NewProvider(fs fs.FS) (provider.Provider, error) {
if fs == nil {
return nil, fmt.Errorf("file system is nil")
}
func NewProvider(config *Config) (provider.Provider, error) {
fs := os.DirFS(config.MountPath)

return &Provider{fs: fs}, nil
}
Expand All @@ -55,9 +54,10 @@ func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Se
return secrets, nil
}

func (p *Provider) getSecretFromFile(filepath string) (string, error) {
filepath = strings.TrimLeft(filepath, "/")
content, err := fs.ReadFile(p.fs, filepath)
func (p *Provider) getSecretFromFile(path string) (string, error) {
path = strings.TrimLeft(path, "/")
fmt.Println("file path:", path, "fs:", p.fs)
content, err := fs.ReadFile(p.fs, path)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
}
Expand Down
54 changes: 15 additions & 39 deletions provider/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package file

import (
"context"
"io/fs"
"testing"
"testing/fstest"

Expand All @@ -28,60 +27,39 @@ import (
func TestNewProvider(t *testing.T) {
tests := []struct {
name string
fs fs.FS
config *Config
wantErr bool
wantType bool
}{
{
name: "Valid file system",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
name: "Valid config",
config: &Config{
MountPath: "test/secrets",
},
wantErr: false,
wantType: true,
},
{
name: "Nil file system",
fs: nil,
wantErr: true,
wantType: false,
},
}

for _, tt := range tests {
ttp := tt
t.Run(ttp.name, func(t *testing.T) {
prov, err := NewProvider(ttp.fs)
if (err != nil) != ttp.wantErr {
t.Fatalf("NewProvider() error = %v, wantErr %v", err, ttp.wantErr)
return
}
// Use type assertion to check if the provider is of the correct type
_, ok := prov.(*Provider)
if ok != ttp.wantType {
t.Fatalf("NewProvider() = %v, wantType %v", ok, ttp.wantType)
}
provider, err := NewProvider(ttp.config)
assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")
assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type")
})
}
}

func TestLoadSecrets(t *testing.T) {
tests := []struct {
name string
fs fs.FS
paths []string
wantErr bool
wantData []provider.Secret
}{
{
name: "Load secrets successfully",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
},
paths: []string{
"test/secrets/sqlpass.txt",
"test/secrets/awsaccess.txt",
Expand All @@ -96,11 +74,6 @@ func TestLoadSecrets(t *testing.T) {
},
{
name: "Fail to load secrets due to invalid path",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
},
paths: []string{
"test/secrets/mistake/sqlpass.txt",
"test/secrets/mistake/awsaccess.txt",
Expand All @@ -114,12 +87,15 @@ func TestLoadSecrets(t *testing.T) {
for _, tt := range tests {
ttp := tt
t.Run(ttp.name, func(t *testing.T) {
provider, err := NewProvider(ttp.fs)
if assert.NoError(t, err, "Unexpected error") {
secrets, err := provider.LoadSecrets(context.Background(), ttp.paths)
assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")
assert.ElementsMatch(t, ttp.wantData, secrets, "Unexpected secrets loaded")
fs := fstest.MapFS{
"test/secrets/sqlpass.txt": {Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": {Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": {Data: []byte("secretId")},
}
provider := Provider{fs: fs}
secrets, err := provider.LoadSecrets(context.Background(), ttp.paths)
assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")
assert.Equal(t, ttp.wantData, secrets, "Unexpected secrets")
})
}
}

0 comments on commit 37a4c51

Please sign in to comment.