Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add provider selection logic, file provider and tests #17

Merged
merged 10 commits into from
Nov 29, 2023
83 changes: 83 additions & 0 deletions env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright © 2023 Bank-Vaults Maintainers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"fmt"
"os"
"strings"

"github.com/bank-vaults/secret-init/provider"
"github.com/bank-vaults/secret-init/provider/file"
)

func GetEnvironMap() map[string]string {
environ := make(map[string]string, len(os.Environ()))
for _, env := range os.Environ() {
split := strings.SplitN(env, "=", 2)
name := split[0]
value := split[1]
environ[name] = value
}

return environ
}

func ExtractPathsFromEnvs(envs map[string]string) []string {
var secretPaths []string

for _, path := range envs {
if p, path := getProviderPath(path); p != nil {
secretPaths = append(secretPaths, path)
}
}

return secretPaths
}

func CreateSecretEnvsFrom(envs map[string]string, secrets []provider.Secret) ([]string, error) {
// Reverse the map so we can match
// the environment variable key to the secret
// by using the secret path
reversedEnvs := make(map[string]string)
for envKey, path := range envs {
if p, path := getProviderPath(path); p != nil {
reversedEnvs[path] = envKey
}
}

var secretsEnv []string
for _, secret := range secrets {
path := secret.Path
value := secret.Value
key, ok := reversedEnvs[path]
if !ok {
return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", path)
}
secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", key, value))
}

return secretsEnv, nil
}

// Returns the detected provider name and path with removed prefix
func getProviderPath(path string) (*string, string) {
if strings.HasPrefix(path, "file:") {
var fileProviderName = file.ProviderName
return &fileProviderName, strings.TrimPrefix(path, "file:")
}

return nil, path
}
7 changes: 7 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@ require (
github.com/spf13/cast v1.5.1
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

require (
github.com/samber/lo v1.38.1 // indirect
github.com/stretchr/testify v1.8.4
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
Expand All @@ -6,6 +8,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
Expand All @@ -16,5 +20,11 @@ github.com/samber/slog-syslog v1.0.0 h1:4tf8sNv9+qTQ6Fj8+N6U1ZEtUbqbAIzd+q26/Neg
github.com/samber/slog-syslog v1.0.0/go.mod h1:jjupk+yHPVSuXuGhKleoClYc/HEaC+Ro5X4YYeBrt6g=
github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA=
github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
41 changes: 35 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,24 @@ import (
"github.com/spf13/cast"

"github.com/bank-vaults/secret-init/provider"
"github.com/bank-vaults/secret-init/provider/file"
)

func NewProvider(providerName string) (provider.Provider, error) {
csatib02 marked this conversation as resolved.
Show resolved Hide resolved
switch providerName {
case file.ProviderName:
provider, err := file.NewProvider(os.DirFS("/"))
if err != nil {
return nil, err
}

return provider, nil

default:
return nil, errors.New("invalid provider specified")
}
}

func main() {
var logger *slog.Logger
{
Expand Down Expand Up @@ -94,8 +110,12 @@ func main() {
slog.SetDefault(logger)
}

// TODO: enable providers
var provider provider.Provider
provider, err := NewProvider(os.Getenv("PROVIDER"))
if err != nil {
logger.Error(fmt.Errorf("failed to create provider: %w", err).Error())

os.Exit(1)
}

if len(os.Args) == 1 {
logger.Error("no command is given, vault-env can't determine the entrypoint (command), please specify it explicitly or let the webhook query it (see documentation)")
Expand All @@ -115,10 +135,19 @@ func main() {
os.Exit(1)
}

environ := GetEnvironMap()
paths := ExtractPathsFromEnvs(environ)

ctx := context.Background()
envs, err := provider.LoadSecrets(ctx, os.Environ())
secrets, err := provider.LoadSecrets(ctx, paths)
if err != nil {
logger.Error(fmt.Errorf("failed to load secrets from provider: %w", err).Error())

os.Exit(1)
}
secretsEnv, err := CreateSecretEnvsFrom(environ, secrets)
if err != nil {
logger.Error("could not retrieve secrets from the provider.", err)
logger.Error(fmt.Errorf("failed to create environment variables from loaded secrets: %w", err).Error())

os.Exit(1)
}
Expand All @@ -135,7 +164,7 @@ func main() {
if daemonMode {
logger.Info("in daemon mode...")
cmd := exec.Command(binary, entrypointCmd[1:]...)
cmd.Env = append(os.Environ(), envs...)
cmd.Env = append(os.Environ(), secretsEnv...)
cmd.Stdin = os.Stdin
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
Expand Down Expand Up @@ -184,7 +213,7 @@ func main() {

os.Exit(cmd.ProcessState.ExitCode())
}
err = syscall.Exec(binary, entrypointCmd, envs)
err = syscall.Exec(binary, entrypointCmd, secretsEnv)
if err != nil {
logger.Error(fmt.Errorf("failed to exec process: %w", err).Error(), slog.String("entrypoint", fmt.Sprint(entrypointCmd)))

Expand Down
66 changes: 66 additions & 0 deletions provider/file/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright © 2023 Bank-Vaults Maintainers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package file

import (
"context"
"fmt"
"io/fs"
"strings"

"github.com/bank-vaults/secret-init/provider"
)

const ProviderName = "file"

type Provider struct {
fs fs.FS
}

func NewProvider(fs fs.FS) (provider.Provider, error) {
if fs == nil {
return nil, fmt.Errorf("file system is nil")
}

csatib02 marked this conversation as resolved.
Show resolved Hide resolved
return &Provider{fs: fs}, nil
}

func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Secret, error) {
var secrets []provider.Secret

for _, path := range paths {
secret, err := p.getSecretFromFile(path)
if err != nil {
return nil, fmt.Errorf("failed to get secret from file: %w", err)
}

secrets = append(secrets, provider.Secret{
Path: path,
Value: secret,
})
}

return secrets, nil
}

func (p *Provider) getSecretFromFile(filepath string) (string, error) {
filepath = strings.TrimLeft(filepath, "/")
content, err := fs.ReadFile(p.fs, filepath)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
}

return string(content), nil
}
125 changes: 125 additions & 0 deletions provider/file/file_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright © 2023 Bank-Vaults Maintainers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package file

import (
"context"
"io/fs"
"testing"
"testing/fstest"

"github.com/stretchr/testify/assert"

"github.com/bank-vaults/secret-init/provider"
)

func TestNewProvider(t *testing.T) {
tests := []struct {
name string
fs fs.FS
wantErr bool
wantType bool
}{
{
name: "Valid file system",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
},
wantErr: false,
wantType: true,
},
{
name: "Nil file system",
fs: nil,
wantErr: true,
wantType: false,
},
}

for _, tt := range tests {
ttp := tt
t.Run(ttp.name, func(t *testing.T) {
prov, err := NewProvider(ttp.fs)
if (err != nil) != ttp.wantErr {
t.Fatalf("NewProvider() error = %v, wantErr %v", err, ttp.wantErr)
return
}
// Use type assertion to check if the provider is of the correct type
_, ok := prov.(*Provider)
if ok != ttp.wantType {
t.Fatalf("NewProvider() = %v, wantType %v", ok, ttp.wantType)
}
})
}
}

func TestLoadSecrets(t *testing.T) {
tests := []struct {
name string
fs fs.FS
paths []string
wantErr bool
wantData []provider.Secret
}{
{
name: "Load secrets successfully",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
},
paths: []string{
"test/secrets/sqlpass.txt",
"test/secrets/awsaccess.txt",
"test/secrets/awsid.txt",
},
wantErr: false,
wantData: []provider.Secret{
{Path: "test/secrets/sqlpass.txt", Value: "3xtr3ms3cr3t"},
{Path: "test/secrets/awsaccess.txt", Value: "s3cr3t"},
{Path: "test/secrets/awsid.txt", Value: "secretId"},
},
},
{
name: "Fail to load secrets due to invalid path",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
},
paths: []string{
"test/secrets/mistake/sqlpass.txt",
"test/secrets/mistake/awsaccess.txt",
"test/secrets/mistake/awsid.txt",
},
wantErr: true,
wantData: nil,
},
}

for _, tt := range tests {
ttp := tt
t.Run(ttp.name, func(t *testing.T) {
provider, err := NewProvider(ttp.fs)
if assert.NoError(t, err, "Unexpected error") {
secrets, err := provider.LoadSecrets(context.Background(), ttp.paths)
assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")
assert.ElementsMatch(t, ttp.wantData, secrets, "Unexpected secrets loaded")
}
})
}
}
Loading