From 0f37ccd408eee6258246888c5990de75a24ba00d Mon Sep 17 00:00:00 2001 From: Jay Conrod Date: Wed, 27 Nov 2024 07:54:40 -0800 Subject: [PATCH] REC-110: refactor to remove CacheAlert (#58) This is the beginning of a refactoring to move some LoadStorer implementations into appState. Although CacheAlert satisfied the LoadStorer interface type, it was a decorator that didn't actually load or store any tokens or provide useful or meaningful abstraction. It seems better to squash this into appState and remove unnecessary abstraction. This also makes the test a little more realistic. Also refactored fake token generation to make testing easier. --- cmd/engflow_auth/BUILD | 6 +- cmd/engflow_auth/main.go | 24 ++++---- cmd/engflow_auth/main_test.go | 48 +++++++++------ cmd/engflow_auth/tokens.go | 74 +++++++++++++++++++++++ internal/oauthtoken/BUILD | 3 - internal/oauthtoken/cache_alert.go | 69 --------------------- internal/oauthtoken/cache_alert_test.go | 80 ------------------------- internal/oauthtoken/fake.go | 28 +++++++++ 8 files changed, 149 insertions(+), 183 deletions(-) create mode 100644 cmd/engflow_auth/tokens.go delete mode 100644 internal/oauthtoken/cache_alert.go delete mode 100644 internal/oauthtoken/cache_alert_test.go diff --git a/cmd/engflow_auth/BUILD b/cmd/engflow_auth/BUILD index 78e500b..73d7769 100644 --- a/cmd/engflow_auth/BUILD +++ b/cmd/engflow_auth/BUILD @@ -3,7 +3,10 @@ load("//infra:visibility.bzl", "RELEASE_ARTIFACT") go_library( name = "engflow_auth_lib", - srcs = ["main.go"], + srcs = [ + "main.go", + "tokens.go", + ], importpath = "github.com/EngFlow/auth/cmd/engflow_auth", visibility = ["//visibility:private"], deps = [ @@ -13,6 +16,7 @@ go_library( "//internal/oauthdevice", "//internal/oauthtoken", "@com_github_engflow_credential_helper_go//:credential-helper-go", + "@com_github_golang_jwt_jwt_v5//:jwt", "@com_github_urfave_cli_v2//:cli", "@org_golang_x_oauth2//:oauth2", ], diff --git a/cmd/engflow_auth/main.go b/cmd/engflow_auth/main.go index be4c678..641df74 100644 --- a/cmd/engflow_auth/main.go +++ b/cmd/engflow_auth/main.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/fs" "net" "net/url" @@ -53,6 +54,7 @@ type appState struct { browserOpener browser.Opener authenticator oauthdevice.Authenticator tokenStore oauthtoken.LoadStorer + stderr io.Writer } type ExportedToken struct { @@ -112,15 +114,11 @@ func (r *appState) build(cliCtx *cli.Context) error { return autherr.CodedErrorf(autherr.CodeBadParams, "unknown token store type %q", writeStoreName) } - r.tokenStore = - oauthtoken.NewCacheAlert( - oauthtoken.NewFallback( - /* gets Store() operations */ writeStore, - /* gets Load() operations */ keyring, fileStore, - ), - cliCtx.App.ErrWriter, - ) + r.tokenStore = oauthtoken.NewFallback( + /* gets Store() operations */ writeStore, + /* gets Load() operations */ keyring, fileStore) } + r.stderr = cliCtx.App.ErrWriter return nil } @@ -136,7 +134,7 @@ func (r *appState) get(cliCtx *cli.Context) error { if err != nil { return autherr.CodedErrorf(autherr.CodeBadParams, "failed to parse cluster URL %q from request: %w", req.URI, err) } - token, err := r.tokenStore.Load(clusterURL.Host) + token, err := r.loadToken(clusterURL.Host) if err != nil { return autherr.ReauthRequired(clusterURL.Host) } @@ -165,7 +163,7 @@ func (r *appState) export(cliCtx *cli.Context) error { return autherr.CodedErrorf(autherr.CodeBadParams, "invalid cluster: %w", err) } - token, err := r.tokenStore.Load(clusterURL.Host) + token, err := r.loadToken(clusterURL.Host) if err != nil { if reauthErr := (*autherr.CodedError)(nil); errors.As(err, &reauthErr) && reauthErr.Code == autherr.CodeReauthRequired { return reauthErr @@ -207,7 +205,7 @@ func (r *appState) import_(cliCtx *cli.Context) error { var storeErrs []error for _, storeURL := range storeURLs { - if err := r.tokenStore.Store(storeURL.Host, token.Token); err != nil { + if err := r.storeToken(storeURL.Host, token.Token); err != nil { storeErrs = append(storeErrs, fmt.Errorf("failed to save token for host %q: %w", storeURL.Host, err)) } } @@ -289,7 +287,7 @@ Visit %s for help.`, var storeErrs []error for _, storeURL := range storeURLs { - if err := r.tokenStore.Store(storeURL.Host, token); err != nil { + if err := r.storeToken(storeURL.Host, token); err != nil { storeErrs = append(storeErrs, fmt.Errorf("failed to save token for host %q: %w", storeURL.Host, err)) } } @@ -323,7 +321,7 @@ func (r *appState) logout(cliCtx *cli.Context) error { return autherr.CodedErrorf(autherr.CodeBadParams, "invalid cluster: %w", err) } - if err := r.tokenStore.Delete(clusterURL.Host); errors.Is(err, fs.ErrNotExist) { + if err := r.deleteToken(clusterURL.Host); errors.Is(err, fs.ErrNotExist) { return &autherr.CodedError{Code: autherr.CodeBadParams, Err: fmt.Errorf("no credentials found for cluster %q", clusterURL.Host)} } else if err != nil { return &autherr.CodedError{Code: autherr.CodeTokenStoreFailure, Err: err} diff --git a/cmd/engflow_auth/main_test.go b/cmd/engflow_auth/main_test.go index c73324f..e116144 100644 --- a/cmd/engflow_auth/main_test.go +++ b/cmd/engflow_auth/main_test.go @@ -83,17 +83,18 @@ func codedErrorContains(t *testing.T, gotErr error, code int, wantMsg string) bo } type fakeAuth struct { - res *oauth2.DeviceAuthResponse - fetchCodeErr error - fetchTokenErr error + deviceResponse *oauth2.DeviceAuthResponse + token *oauth2.Token + fetchCodeErr error + fetchTokenErr error } func (f *fakeAuth) FetchCode(ctx context.Context, authEndpint *oauth2.Endpoint) (*oauth2.DeviceAuthResponse, error) { - return f.res, f.fetchCodeErr + return f.deviceResponse, f.fetchCodeErr } func (f *fakeAuth) FetchToken(ctx context.Context, authRes *oauth2.DeviceAuthResponse) (*oauth2.Token, error) { - return nil, f.fetchTokenErr + return f.token, f.fetchTokenErr } type fakeBrowser struct { @@ -220,7 +221,7 @@ func TestRun(t *testing.T) { desc: "login happy path", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, }, @@ -229,7 +230,7 @@ func TestRun(t *testing.T) { desc: "login with alias", args: []string{"login", "--alias", "cluster.local.example.com", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, }, @@ -243,7 +244,7 @@ func TestRun(t *testing.T) { desc: "login with alias with store errors", args: []string{"login", "--alias", "cluster.local.example.com", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, }, @@ -257,7 +258,7 @@ func TestRun(t *testing.T) { desc: "login with host and port", args: []string{"login", "cluster.example.com:8080"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com:8080/with/auth/code", }, }, @@ -272,7 +273,7 @@ func TestRun(t *testing.T) { desc: "login code fetch failure", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, fetchCodeErr: errors.New("fetch_code_fail"), @@ -284,7 +285,7 @@ func TestRun(t *testing.T) { desc: "login code fetch RetrieveError", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, fetchCodeErr: &oauth2.RetrieveError{}, @@ -296,7 +297,7 @@ func TestRun(t *testing.T) { desc: "login code fetch unexpected HTML", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, fetchCodeErr: autherr.UnexpectedHTML, @@ -308,7 +309,7 @@ func TestRun(t *testing.T) { desc: "login browser open failure", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, }, @@ -322,7 +323,7 @@ func TestRun(t *testing.T) { desc: "login token fetch failure", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, fetchTokenErr: errors.New("fetch_token_fail"), @@ -334,7 +335,7 @@ func TestRun(t *testing.T) { desc: "login token store failure", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, }, @@ -348,7 +349,7 @@ func TestRun(t *testing.T) { desc: "login with file-backed token storage", args: []string{"login", "--store=file", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, }, @@ -357,7 +358,7 @@ func TestRun(t *testing.T) { desc: "login with keyring-backed token storage", args: []string{"login", "--store=keyring", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ + deviceResponse: &oauth2.DeviceAuthResponse{ VerificationURIComplete: "https://cluster.example.com/with/auth/code", }, }, @@ -374,6 +375,19 @@ func TestRun(t *testing.T) { wantCode: autherr.CodeBadParams, wantErr: "flag provided but not defined", }, + { + desc: "login with changed subject", + args: []string{"login", "cluster.example.com"}, + tokenStore: oauthtoken.NewFakeTokenStore().WithTokenForSubject( + "cluster.example.com", "alice"), + authenticator: &fakeAuth{ + deviceResponse: &oauth2.DeviceAuthResponse{ + VerificationURIComplete: "https://cluster.example.com/with/auth/code", + }, + token: oauthtoken.NewFakeTokenForSubject("bob"), + }, + wantStderrContaining: []string{"Login identity has changed"}, + }, { desc: "logout without cluster", args: []string{"logout"}, diff --git a/cmd/engflow_auth/tokens.go b/cmd/engflow_auth/tokens.go new file mode 100644 index 0000000..8f86f2e --- /dev/null +++ b/cmd/engflow_auth/tokens.go @@ -0,0 +1,74 @@ +// Copyright 2024 EngFlow Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + + "github.com/golang-jwt/jwt/v5" + "golang.org/x/oauth2" +) + +// loadToken loads a token for the given cluster or returns an error equivalent +// to fs.ErrNotExist if the token is not found in any store. +// +// loadToken may contain logic specific to this app and should be called +// by commands instead of calling LoadStorer.Load directly. +func (r *appState) loadToken(cluster string) (*oauth2.Token, error) { + return r.tokenStore.Load(cluster) +} + +// storeToken stores a token for the given cluster in one of the backends. +// +// storeToken may contain logic specific to this app and should be called +// by commands instead of calling LoadStorer.Store directly. For example, +// storeToken prints a message if the token's subject has changed. +func (r *appState) storeToken(cluster string, token *oauth2.Token) error { + oldToken, err := r.loadToken(cluster) + if err == nil { + r.warnIfSubjectChanged(cluster, oldToken, token) + } + return r.tokenStore.Store(cluster, token) +} + +// warnIfSubjectChanged prints a warning on stderr if the new token belongs to +// a different user than the previously stored token. The user is reminded to +// shutdown Bazel since it caches tokens in memory to avoid running actions +// with the old credential, which is probably still valid. +func (r *appState) warnIfSubjectChanged(cluster string, oldToken, newToken *oauth2.Token) { + // Disable claims validation, since expired tokens should be allowed to + // parse. + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + oldClaims, newClaims := &jwt.RegisteredClaims{}, &jwt.RegisteredClaims{} + // Unverified parsing, since issuing a warning vs. not is not a security + // concern. + if _, _, err := parser.ParseUnverified(oldToken.AccessToken, oldClaims); err != nil { + return + } + if _, _, err := parser.ParseUnverified(newToken.AccessToken, newClaims); err != nil { + return + } + if oldClaims.Subject != newClaims.Subject { + fmt.Fprintf(r.stderr, "WARNING: Login identity has changed since last login to %q.\nPlease run `bazel shutdown` in current workspaces in order to ensure bazel picks up new credentials.\n", cluster) + } +} + +// deleteToken removes a token from all of the backends. +// +// deleteToken may contain logic specific to this app and should be called +// by commands instead of calling LoadStorer.Delete directly. +func (r *appState) deleteToken(cluster string) error { + return r.tokenStore.Delete(cluster) +} diff --git a/internal/oauthtoken/BUILD b/internal/oauthtoken/BUILD index 933dbab..44a7745 100644 --- a/internal/oauthtoken/BUILD +++ b/internal/oauthtoken/BUILD @@ -3,7 +3,6 @@ load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "oauthtoken", srcs = [ - "cache_alert.go", "debug.go", "fake.go", "fallback.go", @@ -23,14 +22,12 @@ go_library( go_test( name = "oauthtoken_test", srcs = [ - "cache_alert_test.go", "fallback_test.go", "keyring_test.go", "load_storer_test.go", ], embed = [":oauthtoken"], deps = [ - "@com_github_golang_jwt_jwt_v5//:jwt", "@com_github_google_uuid//:uuid", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/internal/oauthtoken/cache_alert.go b/internal/oauthtoken/cache_alert.go deleted file mode 100644 index fdd044f..0000000 --- a/internal/oauthtoken/cache_alert.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 EngFlow Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package oauthtoken - -import ( - "fmt" - "io" - - "github.com/golang-jwt/jwt/v5" - "golang.org/x/oauth2" -) - -// CacheAlert is a tokenLoadStorer that detects when a token's subject for -// a different cluster is changing, and produces a warning over an appropriate -// communication channel. -type CacheAlert struct { - LoadStorer - stderr io.Writer -} - -func NewCacheAlert(impl LoadStorer, stderr io.Writer) LoadStorer { - return &CacheAlert{ - LoadStorer: impl, - stderr: stderr, - } -} - -func (a *CacheAlert) Store(cluster string, token *oauth2.Token) error { - oldToken, err := a.Load(cluster) - if err != nil { - // Failed to fetch any sort of previous valid token. Defer to the - // wrapped implementation; we'll assume that the token didn't exist - // previously (and therefore no need to issue a warning). - return a.LoadStorer.Store(cluster, token) - } - - // Disable claims validation, since expired tokens should be allowed to - // parse. - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) - oldClaims, newClaims := &jwt.RegisteredClaims{}, &jwt.RegisteredClaims{} - // Unverified parsing, since issuing a warning vs. not is not a security - // concern. - _, _, err = parser.ParseUnverified(oldToken.AccessToken, oldClaims) - if err != nil { - return a.LoadStorer.Store(cluster, token) - } - _, _, err = parser.ParseUnverified(token.AccessToken, newClaims) - if err != nil { - return a.LoadStorer.Store(cluster, token) - } - - if oldClaims.Subject != newClaims.Subject { - fmt.Fprintf(a.stderr, "WARNING: Login identity has changed since last login to %q.\nPlease run `bazel shutdown` in current workspaces in order to ensure bazel picks up new credentials.\n", cluster) - } - - return a.LoadStorer.Store(cluster, token) -} diff --git a/internal/oauthtoken/cache_alert_test.go b/internal/oauthtoken/cache_alert_test.go deleted file mode 100644 index c63386b..0000000 --- a/internal/oauthtoken/cache_alert_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2024 EngFlow Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package oauthtoken - -import ( - "bytes" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -func mustTokenForSubject(t *testing.T, name string) string { - t.Helper() - now := time.Now() - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - Issuer: "engflow unit tests", - Subject: name, - Audience: nil, - ExpiresAt: jwt.NewNumericDate(now.Add(time.Minute)), - NotBefore: jwt.NewNumericDate(now), - IssuedAt: jwt.NewNumericDate(now), - }) - tokenStr, err := token.SignedString([]byte("some signing key")) - require.NoError(t, err) - - return tokenStr -} - -func TestTokenCacheWarning(t *testing.T) { - var testStderr bytes.Buffer - tokenStore := NewCacheAlert(NewFakeTokenStore(), &testStderr) - - testTokenAlice := &oauth2.Token{ - AccessToken: mustTokenForSubject(t, "alice"), - TokenType: "Bearer", - } - testTokenBob := &oauth2.Token{ - AccessToken: mustTokenForSubject(t, "bob"), - TokenType: "Bearer", - } - - // Storing an initial token should produce no warning - err := tokenStore.Store("default", testTokenAlice) - require.NoError(t, err) - assert.Len(t, testStderr.String(), 0) - err = tokenStore.Store("special", testTokenAlice) - require.NoError(t, err) - assert.Len(t, testStderr.String(), 0) - - // Storing a token with a different principal for a given cluster should - // produce a warning - err = tokenStore.Store("default", testTokenBob) - require.NoError(t, err) - assert.Contains(t, testStderr.String(), "Login identity has changed") - assert.Contains(t, testStderr.String(), "bazel shutdown") - testStderr.Reset() - - // Storing a token with the same principal for a given cluster should - // produce no warning - err = tokenStore.Store("special", testTokenAlice) - require.NoError(t, err) - assert.Len(t, testStderr.String(), 0) -} diff --git a/internal/oauthtoken/fake.go b/internal/oauthtoken/fake.go index f1d48dc..7425ee6 100644 --- a/internal/oauthtoken/fake.go +++ b/internal/oauthtoken/fake.go @@ -17,7 +17,9 @@ package oauthtoken import ( "fmt" "io/fs" + "time" + "github.com/golang-jwt/jwt/v5" "golang.org/x/oauth2" ) @@ -71,6 +73,10 @@ func (f *FakeTokenStore) WithToken(cluster string, token *oauth2.Token) *FakeTok return f } +func (f *FakeTokenStore) WithTokenForSubject(cluster, subject string) *FakeTokenStore { + return f.WithToken(cluster, NewFakeTokenForSubject(subject)) +} + func (f *FakeTokenStore) WithLoadErr(err error) *FakeTokenStore { f.LoadErr = err return f @@ -85,3 +91,25 @@ func (f *FakeTokenStore) WithDeleteErr(err error) *FakeTokenStore { f.DeleteErr = err return f } + +func NewFakeTokenForSubject(subject string) *oauth2.Token { + now := time.Now() + expiry := now.Add(time.Hour) + payload := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + Issuer: "engflow unit tests", + Subject: subject, + Audience: nil, + ExpiresAt: jwt.NewNumericDate(expiry), + NotBefore: jwt.NewNumericDate(now), + IssuedAt: jwt.NewNumericDate(now), + }) + tokenStr, err := payload.SignedString([]byte("some signing key")) + if err != nil { + panic(err) + } + return &oauth2.Token{ + AccessToken: tokenStr, + TokenType: "Bearer", + Expiry: expiry, + } +}