diff --git a/cache/cache.go b/cache/cache.go index ae26b7b..d24ea61 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,113 +1,41 @@ -/* - * Copyright 2020 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package cache import ( - "fmt" "strings" - "sync" "github.com/netflix/weep/creds" - "github.com/netflix/weep/errors" log "github.com/sirupsen/logrus" + "github.com/spf13/viper" ) -var GlobalCache CredentialCache - -type CredentialCache struct { - RoleCredentials map[string]*creds.RefreshableProvider - DefaultRole string - mu sync.RWMutex -} +var GlobalCache WeepCache func init() { - GlobalCache = CredentialCache{ - RoleCredentials: make(map[string]*creds.RefreshableProvider), - } -} - -// getCacheSlug returns a string unique to a particular combination of a role and chain of roles to assume. -func getCacheSlug(role string, assume []string) string { + var err error + cacheType := viper.GetString("cache.type") + log.Debugf("initializing %s cache", cacheType) + switch cacheType { + case "memory": + GlobalCache = NewMemoryCache() + case "file": + GlobalCache, err = NewFileCache() + if err != nil { + log.Fatalf("failed to initialize file cache: %v", err) + } + default: + log.Fatal("invalid cache type specified") + } +} + +type WeepCache interface { + Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) + GetOrSet(client *creds.Client, role string, region string, assumeChain []string) (*creds.RefreshableProvider, error) + SetDefault(client *creds.Client, role string, region string, assumeChain []string) error + GetDefault() (*creds.RefreshableProvider, error) +} + +// getSlug returns a string unique to a particular combination of a role and chain of roles to assume. +func getSlug(role string, assume []string) string { elements := append([]string{role}, assume...) return strings.Join(elements, "/") } - -func (cc *CredentialCache) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) { - log.WithFields(log.Fields{ - "role": role, - "assumeChain": assumeChain, - }).Info("retrieving credentials") - c, ok := cc.get(getCacheSlug(role, assumeChain)) - if ok { - log.Debugf("found credentials for %s in cache", role) - return c, nil - } - return nil, errors.NoCredentialsFoundInCache -} - -func (cc *CredentialCache) GetOrSet(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) { - c, err := cc.Get(role, assumeChain) - if err == nil { - return c, nil - } - log.Debugf("no credentials for %s in cache, creating", role) - - c, err = cc.set(client, role, region, assumeChain) - if err != nil { - return nil, err - } - - return c, nil -} - -func (cc *CredentialCache) SetDefault(client *creds.Client, role, region string, assumeChain []string) error { - _, err := cc.set(client, role, region, assumeChain) - if err != nil { - return err - } - cc.DefaultRole = getCacheSlug(role, assumeChain) - return nil -} - -func (cc *CredentialCache) GetDefault() (*creds.RefreshableProvider, error) { - if cc.DefaultRole == "" { - return nil, errors.NoDefaultRoleSet - } - c, ok := cc.get(cc.DefaultRole) - if ok { - return c, nil - } - return nil, errors.NoCredentialsFoundInCache -} - -func (cc *CredentialCache) get(slug string) (*creds.RefreshableProvider, bool) { - cc.mu.RLock() - defer cc.mu.RUnlock() - c, ok := cc.RoleCredentials[slug] - return c, ok -} - -func (cc *CredentialCache) set(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) { - c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false) - if err != nil { - return nil, fmt.Errorf("could not generate creds: %w", err) - } - cc.mu.Lock() - defer cc.mu.Unlock() - cc.RoleCredentials[getCacheSlug(role, assumeChain)] = c - return c, nil -} diff --git a/cache/file.go b/cache/file.go new file mode 100644 index 0000000..8396fe8 --- /dev/null +++ b/cache/file.go @@ -0,0 +1,128 @@ +package cache + +import ( + "encoding/json" + "fmt" + + "github.com/boltdb/bolt" + "github.com/netflix/weep/creds" + "github.com/netflix/weep/errors" + log "github.com/sirupsen/logrus" +) + +const BUCKET = "credentials" + +type FileDB struct { + db *bolt.DB +} + +func NewFileCache() (*FileDB, error) { + db, err := bolt.Open("weep.db", 0600, nil) + if err != nil { + return nil, err + } + + fdb := &FileDB{ + db: db, + } + err = fdb.setup() + if err != nil { + return nil, err + } + + return fdb, nil +} + +func (f *FileDB) setup() error { + err := f.db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucketIfNotExists([]byte(BUCKET)) + if err != nil { + return err + } + return nil + }) + return err +} + +func (f *FileDB) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) { + log.WithFields(log.Fields{ + "role": role, + "assumeChain": assumeChain, + "cacheType": "file", + }).Info("retrieving credentials") + c, err := f.get(getSlug(role, assumeChain)) + if err != nil { + return nil, errors.NoCredentialsFoundInCache + } + return c, nil +} + +func (f *FileDB) GetOrSet(client *creds.Client, role string, region string, assumeChain []string) (*creds.RefreshableProvider, error) { + c, err := f.Get(role, assumeChain) + if err == nil { + return c, nil + } + log.Debugf("no credentials for %s in cache, creating", role) + + c, err = f.set(client, role, region, assumeChain) + if err != nil { + return nil, err + } + + return c, nil +} + +func (f *FileDB) SetDefault(client *creds.Client, role string, region string, assumeChain []string) error { + // TODO + return nil +} + +func (f *FileDB) GetDefault() (*creds.RefreshableProvider, error) { + // TODO + return nil, nil +} + +func (f *FileDB) get(slug string) (*creds.RefreshableProvider, error) { + credentials := &creds.RefreshableProvider{} + err := f.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(BUCKET)) + result := b.Get([]byte(slug)) + err := json.Unmarshal(result, credentials) + if err != nil { + return nil + } + return nil + }) + if err != nil { + return credentials, err + } + err = credentials.EnsureRefreshed() + if err != nil { + return credentials, err + } + return credentials, nil +} + +func (f *FileDB) set(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) { + c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false) + if err != nil { + return nil, fmt.Errorf("could not generate creds: %w", err) + } + data, err := json.Marshal(c) + slug := getSlug(role, assumeChain) + if err != nil { + return nil, fmt.Errorf("could not marshal creds: %w", err) + } + err = f.db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(BUCKET)) + err := b.Put([]byte(slug), data) + if err != nil { + return err + } + return nil + }) + if err != nil { + return nil, err + } + return c, nil +} diff --git a/cache/memory.go b/cache/memory.go new file mode 100644 index 0000000..b6148b2 --- /dev/null +++ b/cache/memory.go @@ -0,0 +1,88 @@ +package cache + +import ( + "fmt" + "sync" + + "github.com/netflix/weep/creds" + "github.com/netflix/weep/errors" + log "github.com/sirupsen/logrus" +) + +type InMemory struct { + RoleCredentials map[string]*creds.RefreshableProvider + DefaultRole string + mu sync.RWMutex +} + +func NewMemoryCache() *InMemory { + return &InMemory{ + RoleCredentials: make(map[string]*creds.RefreshableProvider), + } +} + +func (cc *InMemory) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) { + log.WithFields(log.Fields{ + "role": role, + "assumeChain": assumeChain, + }).Info("retrieving credentials") + c, ok := cc.get(getSlug(role, assumeChain)) + if ok { + log.Debugf("found credentials for %s in cache", role) + return c, nil + } + return nil, errors.NoCredentialsFoundInCache +} + +func (cc *InMemory) GetOrSet(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) { + c, err := cc.Get(role, assumeChain) + if err == nil { + return c, nil + } + log.Debugf("no credentials for %s in cache, creating", role) + + c, err = cc.set(client, role, region, assumeChain) + if err != nil { + return nil, err + } + + return c, nil +} + +func (cc *InMemory) SetDefault(client *creds.Client, role, region string, assumeChain []string) error { + _, err := cc.set(client, role, region, assumeChain) + if err != nil { + return err + } + cc.DefaultRole = getSlug(role, assumeChain) + return nil +} + +func (cc *InMemory) GetDefault() (*creds.RefreshableProvider, error) { + if cc.DefaultRole == "" { + return nil, errors.NoDefaultRoleSet + } + c, ok := cc.get(cc.DefaultRole) + if ok { + return c, nil + } + return nil, errors.NoCredentialsFoundInCache +} + +func (cc *InMemory) get(slug string) (*creds.RefreshableProvider, bool) { + cc.mu.RLock() + defer cc.mu.RUnlock() + c, ok := cc.RoleCredentials[slug] + return c, ok +} + +func (cc *InMemory) set(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) { + c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false) + if err != nil { + return nil, fmt.Errorf("could not generate creds: %w", err) + } + cc.mu.Lock() + defer cc.mu.Unlock() + cc.RoleCredentials[getSlug(role, assumeChain)] = c + return c, nil +} diff --git a/cache/cache_test.go b/cache/memory_test.go similarity index 98% rename from cache/cache_test.go rename to cache/memory_test.go index 8fffc66..4ec63a7 100644 --- a/cache/cache_test.go +++ b/cache/memory_test.go @@ -87,7 +87,7 @@ func TestCredentialCache_Get(t *testing.T) { for i, tc := range cases { t.Logf("test case %d: %s", i, tc.Description) - testCache := CredentialCache{ + testCache := InMemory{ RoleCredentials: tc.CacheContents, } actualResult, actualError := testCache.Get(tc.Role, tc.AssumeChain) @@ -167,7 +167,7 @@ func TestCredentialCache_GetDefault(t *testing.T) { for i, tc := range cases { t.Logf("test case %d: %s", i, tc.Description) - testCache := CredentialCache{ + testCache := InMemory{ RoleCredentials: tc.CacheContents, DefaultRole: tc.DefaultRole, } @@ -183,7 +183,7 @@ func TestCredentialCache_GetDefault(t *testing.T) { } func TestCredentialCache_SetDefault(t *testing.T) { - testCache := CredentialCache{ + testCache := InMemory{ RoleCredentials: map[string]*creds.RefreshableProvider{}, } expectedRole := "a" @@ -255,7 +255,7 @@ func TestCredentialCache_GetOrSet(t *testing.T) { for i, tc := range cases { t.Logf("test case %d: %s", i, tc.Description) - testCache := CredentialCache{ + testCache := InMemory{ RoleCredentials: tc.CacheContents, } client, err := creds.GetTestClient(creds.ConsolemeCredentialResponseType{ diff --git a/config/config.go b/config/config.go index 1c42068..a6d2eb2 100644 --- a/config/config.go +++ b/config/config.go @@ -31,6 +31,7 @@ import ( func init() { // Set default configuration values here viper.SetTypeByDefaultValue(true) + viper.SetDefault("cache.type", "memory") viper.SetDefault("log_file", getDefaultLogFile()) viper.SetDefault("mtls_settings.old_cert_message", "mTLS certificate is too old, please refresh mtls certificate") viper.SetDefault("server.http_timeout", 20) diff --git a/creds/refreshable.go b/creds/refreshable.go index 281db11..5b57c7e 100644 --- a/creds/refreshable.go +++ b/creds/refreshable.go @@ -62,6 +62,11 @@ func (rp *RefreshableProvider) AutoRefresh() { } } +func (rp *RefreshableProvider) EnsureRefreshed() error { + _, err := rp.checkAndRefresh(10) + return err +} + func (rp *RefreshableProvider) checkAndRefresh(threshold int) (bool, error) { log.Debugf("checking credentials for %s", rp.Role) // refresh creds if we're within 10 minutes of them expiring @@ -103,14 +108,14 @@ func (rp *RefreshableProvider) refresh() error { } rp.Expiration = newCreds.Expiration - rp.value.AccessKeyID = newCreds.AccessKeyId - rp.value.SessionToken = newCreds.SessionToken - rp.value.SecretAccessKey = newCreds.SecretAccessKey - rp.value.AccessKeyID = newCreds.AccessKeyId + rp.Value.AccessKeyID = newCreds.AccessKeyId + rp.Value.SessionToken = newCreds.SessionToken + rp.Value.SecretAccessKey = newCreds.SecretAccessKey + rp.Value.AccessKeyID = newCreds.AccessKeyId rp.LastRefreshed = Time(time.Now()) rp.RoleArn = newCreds.RoleArn - if rp.value.ProviderName == "" { - rp.value.ProviderName = "WeepRefreshableProvider" + if rp.Value.ProviderName == "" { + rp.Value.ProviderName = "WeepRefreshableProvider" } log.Debugf("successfully refreshed credentials for %s", rp.Role) return nil @@ -120,7 +125,7 @@ func (rp *RefreshableProvider) refresh() error { func (rp *RefreshableProvider) Retrieve() (credentials.Value, error) { rp.mu.RLock() defer rp.mu.RUnlock() - return rp.value, nil + return rp.Value, nil } // IsExpired always returns false because we should never have expired credentials diff --git a/creds/refreshable_test.go b/creds/refreshable_test.go index 17bfd2a..ec5ca5b 100644 --- a/creds/refreshable_test.go +++ b/creds/refreshable_test.go @@ -200,8 +200,8 @@ func TestRefreshableProvider_refresh(t *testing.T) { AssumeChain: tc.AssumeChain, } // pre-refresh checks - if rp.value.SessionToken != "" || rp.value.AccessKeyID != "" || rp.value.SecretAccessKey != "" || rp.value.ProviderName != "" { - t.Errorf("%s failed: credential values should not exist: %v", tc.Description, rp.value) + if rp.Value.SessionToken != "" || rp.Value.AccessKeyID != "" || rp.Value.SecretAccessKey != "" || rp.Value.ProviderName != "" { + t.Errorf("%s failed: credential values should not exist: %v", tc.Description, rp.Value) continue } if rp.Expiration != zeroTime { @@ -216,8 +216,8 @@ func TestRefreshableProvider_refresh(t *testing.T) { } else { continue } - if rp.value.SessionToken == "" || rp.value.AccessKeyID == "" || rp.value.SecretAccessKey == "" || rp.value.ProviderName == "" { - t.Errorf("%s failed: credential values should not be empty: %v", tc.Description, rp.value) + if rp.Value.SessionToken == "" || rp.Value.AccessKeyID == "" || rp.Value.SecretAccessKey == "" || rp.Value.ProviderName == "" { + t.Errorf("%s failed: credential values should not be empty: %v", tc.Description, rp.Value) } if rp.Expiration == zeroTime { t.Errorf("%s failed: Expiration should be set, got %v", tc.Description, rp.Expiration) @@ -318,7 +318,7 @@ func TestRefreshableProvider_Retrieve(t *testing.T) { } rp := RefreshableProvider{ - value: expected, + Value: expected, } result, err := rp.Retrieve() diff --git a/creds/types.go b/creds/types.go index ea7956b..c04090e 100644 --- a/creds/types.go +++ b/creds/types.go @@ -33,7 +33,7 @@ type AwsCredentials struct { } type RefreshableProvider struct { - value credentials.Value + Value credentials.Value mu sync.RWMutex client *Client retries int diff --git a/go.mod b/go.mod index 9d6ba1c..9aebe7f 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.15 require ( github.com/aws/aws-sdk-go v1.36.7 + github.com/boltdb/bolt v1.3.1 github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/gobuffalo/here v0.6.2 // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b diff --git a/go.sum b/go.sum index 49a3d49..08320f9 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQ github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c h1:+0HFd5KSZ/mm3JmhmrDukiId5iR6w4+BdFtfSy4yWIc= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= +github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4= +github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE=