diff --git a/.gitleaksignore b/.gitleaksignore index 6ce65c7b5..3a28cc2aa 100644 --- a/.gitleaksignore +++ b/.gitleaksignore @@ -14,3 +14,7 @@ b0ba7f6ed181c23a5c6532cf6f124b731d390f86:internal/presenters/testdata/with-ignor 237c7f05ec087733fa7929ee9fa3db2bd56bdba4:pkg/logging/scrubbingLogWriter_test.go:github-pat:183 237c7f05ec087733fa7929ee9fa3db2bd56bdba4:pkg/logging/scrubbingLogWriter_test.go:snyk-api-token:208 0481cdb4d07351149e65a57ebc9ad5b983896849:pkg/auth/oauth2authenticator.go:generic-api-key:30 +0a646dd3b9eeca0463fd8240b87957106f4b71f3:pkg/app/app_test.go:jwt:57 +997b87d4bed3623bdf34dcaa29b63647f1f0460a:pkg/app/app_test.go:jwt:97 +5a4b6d6a75be3cbfc5122dd7d6d4ead5142b6429:pkg/app/app_test.go:jwt:95 +internal/auth/oauth_test.go:jwt:29 \ No newline at end of file diff --git a/internal/api/urls.go b/internal/api/urls.go index 6c064cba6..a239634a8 100644 --- a/internal/api/urls.go +++ b/internal/api/urls.go @@ -52,9 +52,7 @@ func GetCanonicalApiUrlFromString(userDefinedUrl string) (string, error) { return GetCanonicalApiUrl(*url) } -func GetCanonicalApiUrl(url url.URL) (string, error) { - var result string - +func GetCanonicalApiAsUrl(url url.URL) (url.URL, error) { // for localhost we don't change the host, since there are no subdomains if isImmutableHost(url.Host) { url.Path = strings.Replace(url.Path, "/v1", "", 1) @@ -71,8 +69,16 @@ func GetCanonicalApiUrl(url url.URL) (string, error) { url.RawQuery = "" } - result = url.String() - return result, nil + return url, nil +} + +func GetCanonicalApiUrl(url url.URL) (string, error) { + result, err := GetCanonicalApiAsUrl(url) + if err != nil { + return "", err + } + + return result.String(), nil } func DeriveAppUrl(canonicalUrl string) (string, error) { diff --git a/pkg/app/app.go b/pkg/app/app.go index f70cacb7a..23f6d5dff 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -19,6 +19,7 @@ import ( "github.com/snyk/go-application-framework/internal/constants" "github.com/snyk/go-application-framework/internal/presenters" "github.com/snyk/go-application-framework/internal/utils" + "github.com/snyk/go-application-framework/pkg/auth" "github.com/snyk/go-application-framework/pkg/configuration" localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows" pkg_utils "github.com/snyk/go-application-framework/pkg/utils" @@ -75,11 +76,18 @@ func defaultFuncOrganization(engine workflow.Engine, config configuration.Config return callback } -func defaultFuncApiUrl(logger *zerolog.Logger) configuration.DefaultValueFunction { +func defaultFuncApiUrl(config configuration.Configuration, logger *zerolog.Logger) configuration.DefaultValueFunction { callback := func(existingValue interface{}) interface{} { urlString := constants.SNYK_DEFAULT_API_URL - if existingValue != nil { + urlFromOauthToken, err := auth.GetAudienceClaimFromOauthToken(config.GetString(auth.CONFIG_KEY_OAUTH_TOKEN)) + if err != nil { + logger.Warn().Err(err).Msg("failed to read oauth token") + } + + if len(urlFromOauthToken) > 0 && len(urlFromOauthToken[0]) > 0 { + urlString = urlFromOauthToken[0] + } else if existingValue != nil { // configured value takes precedence if temp, ok := existingValue.(string); ok { urlString = temp } @@ -186,10 +194,11 @@ func initConfiguration(engine workflow.Engine, config configuration.Configuratio config.AddDefaultValue(configuration.AUTHENTICATION_SUBDOMAINS, configuration.StandardDefaultValueFunction([]string{"deeproxy"})) config.AddDefaultValue(configuration.MAX_THREADS, configuration.StandardDefaultValueFunction(runtime.NumCPU())) config.AddDefaultValue(presenters.CONFIG_JSON_STRIP_WHITESPACES, configuration.StandardDefaultValueFunction(true)) + config.AddDefaultValue(auth.CONFIG_KEY_ALLOWED_HOST_REGEXP, configuration.StandardDefaultValueFunction(`^api(\.(.+))?\.snyk|snykgov\.io$`)) // set default filesize threshold to 512MB config.AddDefaultValue(configuration.IN_MEMORY_THRESHOLD_BYTES, configuration.StandardDefaultValueFunction(constants.SNYK_DEFAULT_IN_MEMORY_THRESHOLD_MB)) - config.AddDefaultValue(configuration.API_URL, defaultFuncApiUrl(logger)) + config.AddDefaultValue(configuration.API_URL, defaultFuncApiUrl(config, logger)) config.AddDefaultValue(configuration.TEMP_DIR_PATH, defaultTempDirectory(engine, config, logger)) config.AddDefaultValue(configuration.WEB_APP_URL, func(existingValue any) any { diff --git a/pkg/app/app_test.go b/pkg/app/app_test.go index 2d7b407c7..2853ca6a0 100644 --- a/pkg/app/app_test.go +++ b/pkg/app/app_test.go @@ -3,6 +3,7 @@ package app import ( "errors" "fmt" + "log" "net/http" "os" "path/filepath" @@ -16,6 +17,7 @@ import ( "github.com/snyk/go-application-framework/internal/api" "github.com/snyk/go-application-framework/internal/constants" "github.com/snyk/go-application-framework/internal/mocks" + "github.com/snyk/go-application-framework/pkg/auth" "github.com/snyk/go-application-framework/pkg/configuration" "github.com/snyk/go-application-framework/pkg/runtimeinfo" "github.com/snyk/go-application-framework/pkg/workflow" @@ -58,7 +60,8 @@ func Test_AddsDefaultFunctionForCustomConfigFiles(t *testing.T) { } func Test_CreateAppEngine(t *testing.T) { - engine := CreateAppEngine() + localConfig := configuration.NewWithOpts() + engine := CreateAppEngineWithOptions(WithConfiguration(localConfig)) assert.NotNil(t, engine) err := engine.Init() @@ -70,7 +73,8 @@ func Test_CreateAppEngine(t *testing.T) { } func Test_CreateAppEngine_config_replaceV1inApi(t *testing.T) { - engine := CreateAppEngine() + localConfig := configuration.NewWithOpts() + engine := CreateAppEngineWithOptions(WithConfiguration(localConfig)) assert.NotNil(t, engine) err := engine.Init() @@ -85,6 +89,39 @@ func Test_CreateAppEngine_config_replaceV1inApi(t *testing.T) { assert.Equal(t, expectApiUrl, actualApiUrl) } +func Test_CreateAppEngine_config_OauthAudHasPrecedence(t *testing.T) { + config := configuration.New() + config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, + // JWT generated at https://jwt.io with claim: + // "aud": ["https://api.example.com"] + `{"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJhdWQiOlsiaHR0cHM6Ly9hcGkuZXhhbXBsZS5jb20iXX0.hWq0fKukObQSkphAdyEC7-m4jXIb4VdWyQySmmgy0GU"}`, + ) + logger := log.New(os.Stderr, "", 0) + + t.Run("Audience claim takes precedence of configured value", func(t *testing.T) { + expectedApiUrl := "https://api.example.com" + localConfig := config.Clone() + localConfig.Set(configuration.API_URL, "https://api.dev.snyk.io") + + engine := CreateAppEngineWithOptions(WithConfiguration(localConfig), WithLogger(logger)) + assert.NotNil(t, engine) + + actualApiUrl := localConfig.GetString(configuration.API_URL) + assert.Equal(t, expectedApiUrl, actualApiUrl) + }) + + t.Run("nothing configured", func(t *testing.T) { + expectedApiUrl := "https://api.example.com" + localConfig := config.Clone() + + engine := CreateAppEngineWithOptions(WithConfiguration(localConfig), WithLogger(logger)) + assert.NotNil(t, engine) + + actualApiUrl := localConfig.GetString(configuration.API_URL) + assert.Equal(t, expectedApiUrl, actualApiUrl) + }) +} + func Test_initConfiguration_updateDefaultOrgId(t *testing.T) { orgName := "someOrgName" orgId := "someOrgId" diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go new file mode 100644 index 000000000..dd1cd85f6 --- /dev/null +++ b/pkg/auth/oauth.go @@ -0,0 +1,79 @@ +package auth + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "strings" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/jws" +) + +type arrayClaimSet struct { + // NOTE: The original jws package models audience with a string, not a + // []string. This fails to parse Snyk JWTs. + Aud []string `json:"aud"` +} + +// oauthApiUrl returns the API URL specified by the audience claim in a JWT +// token established by a prior OAuth authentication flow. +// +// Returns an empty string if an OAuth token is not available, cannot be parsed, +// or lacks such an audience claim, along with an error that may have occurred +// in the attempt to parse it. +func GetAudienceClaimFromOauthToken(oauthTokenString string) ([]string, error) { + if oauthTokenString == "" { + return []string{}, nil + } + var token oauth2.Token + if err := json.Unmarshal([]byte(oauthTokenString), &token); err != nil { + return []string{}, err + } + + return readAudience(&token) +} + +// readAudience returns the first audience claim from an OAuth2 access token, or +// an error which prevented its parsing. +// +// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3 +// Audience can be an array or a single value. +// +// If the claim is not present, an empty string is returned. +// +// This function was derived from https://pkg.go.dev/golang.org/x/oauth2/jws#Decode, +// which is licensed as follows: +// +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +func readAudience(token *oauth2.Token) ([]string, error) { + // decode returned id token to get expiry + s := strings.Split(token.AccessToken, ".") + if len(s) < 2 { + return []string{}, errors.New("jws: invalid token received") + } + + decoded, err := base64.RawURLEncoding.DecodeString(s[1]) + if err != nil { + return []string{}, err + } + + // try decode as array + c := arrayClaimSet{} + err = json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&c) + if err == nil { + return c.Aud, nil + } else { + // try decode as single value + claimset := jws.ClaimSet{} + err = json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&claimset) + if err != nil { + return []string{}, err + } + + return []string{claimset.Aud}, nil + } +} diff --git a/pkg/auth/oauth2authenticator.go b/pkg/auth/oauth2authenticator.go index cdd7acfd0..7a07f3832 100644 --- a/pkg/auth/oauth2authenticator.go +++ b/pkg/auth/oauth2authenticator.go @@ -14,26 +14,32 @@ import ( "math/big" "net" "net/http" + "net/url" + "regexp" + "strings" "sync" "time" "github.com/pkg/browser" + "github.com/rs/zerolog" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + "github.com/snyk/go-application-framework/internal/api" "github.com/snyk/go-application-framework/pkg/configuration" ) const ( //nolint:gosec // not a token value, but a configuration key - CONFIG_KEY_OAUTH_TOKEN string = "INTERNAL_OAUTH_TOKEN_STORAGE" - OAUTH_CLIENT_ID string = "b56d4c2e-b9e1-4d27-8773-ad47eafb0956" - CALLBACK_HOSTNAME string = "127.0.0.1" - CALLBACK_PATH string = "/authorization-code/callback" - TIMEOUT_SECONDS = 120 * time.Second - AUTHENTICATED_MESSAGE = "Your account has been authenticated." - PARAMETER_CLIENT_ID string = "client-id" - PARAMETER_CLIENT_SECRET string = "client-secret" + CONFIG_KEY_ALLOWED_HOST_REGEXP = "INTERNAL_OAUTH_ALLOWED_HOSTS" + CONFIG_KEY_OAUTH_TOKEN string = "INTERNAL_OAUTH_TOKEN_STORAGE" + OAUTH_CLIENT_ID string = "b56d4c2e-b9e1-4d27-8773-ad47eafb0956" + CALLBACK_HOSTNAME string = "127.0.0.1" + CALLBACK_PATH string = "/authorization-code/callback" + TIMEOUT_SECONDS = 120 * time.Second + AUTHENTICATED_MESSAGE = "Your account has been authenticated." + PARAMETER_CLIENT_ID string = "client-id" + PARAMETER_CLIENT_SECRET string = "client-secret" ) type GrantType int @@ -58,6 +64,7 @@ type oAuth2Authenticator struct { token *oauth2.Token headless bool grantType GrantType + logger *zerolog.Logger openBrowserFunc func(authUrl string) shutdownServerFunc func(server *http.Server) tokenRefresherFunc func(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2.Token) (*oauth2.Token, error) @@ -193,6 +200,9 @@ func NewOAuth2Authenticator(config configuration.Configuration, httpClient *http func NewOAuth2AuthenticatorWithOpts(config configuration.Configuration, opts ...OAuth2AuthenticatorOption) Authenticator { o := &oAuth2Authenticator{} + nopLogger := zerolog.Nop() + + o.logger = &nopLogger o.config = config //nolint:errcheck // breaking api change needed to fix this o.token, _ = GetOAuthToken(config) @@ -287,6 +297,7 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { var responseCode string var responseState string var responseError string + var responseInstance string verifier, err := createVerifier(128) if err != nil { return err @@ -312,30 +323,14 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { mux.HandleFunc(CALLBACK_PATH, func(w http.ResponseWriter, r *http.Request) { responseError = html.EscapeString(r.URL.Query().Get("error")) if len(responseError) > 0 { - details := html.EscapeString(r.URL.Query().Get("error_description")) - - tmpl := template.New("") - tmpl, tmplError := tmpl.Parse(errorResponsePage) - if tmplError != nil { - return - } - - data := struct { - Reason string - Description string - }{ - Reason: responseError, - Description: details, - } - - tmplError = tmpl.Execute(w, data) - if tmplError != nil { + if writeCallbackErrorResponse(w, r.URL.Query(), responseError) { return } } else { appUrl := o.config.GetString(configuration.WEB_APP_URL) responseCode = html.EscapeString(r.URL.Query().Get("code")) responseState = html.EscapeString(r.URL.Query().Get("state")) + responseInstance = html.EscapeString(r.URL.Query().Get("instance")) w.Header().Add("Location", appUrl+"/authenticated?type=oauth") w.WriteHeader(http.StatusMovedPermanently) } @@ -353,8 +348,7 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { // fill redirect url now that the port is known o.oauthConfig.RedirectURL = getRedirectUri(port) - - url := o.oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, + authCodeUrl := o.oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("code_challenge", codeChallenge), oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("response_type", "code"), @@ -362,7 +356,7 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { oauth2.SetAuthURLParam("version", "2021-08-11~experimental")) // launch browser - go o.openBrowserFunc(url) + go o.openBrowserFunc(authCodeUrl) timedOut := false timer := time.AfterFunc(TIMEOUT_SECONDS, func() { @@ -388,6 +382,11 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { return fmt.Errorf("incorrect response state: %s != %s", responseState, state) } + modifyTokenErr := o.modifyTokenUrl(responseInstance) + if modifyTokenErr != nil { + return modifyTokenErr + } + // Use the custom HTTP client when requesting a token. if o.httpClient != nil { ctx = context.WithValue(ctx, oauth2.HTTPClient, o.httpClient) @@ -402,6 +401,99 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { return err } +func writeCallbackErrorResponse(w http.ResponseWriter, q url.Values, responseError string) bool { + tmpl := template.New("") + tmpl, tmplError := tmpl.Parse(errorResponsePage) + if tmplError != nil { + return true + } + + data := struct { + Reason string + Description string + }{ + Reason: responseError, + Description: html.EscapeString(q.Get("error_description")), + } + + tmplError = tmpl.Execute(w, data) + + return tmplError != nil +} + +func (o *oAuth2Authenticator) modifyTokenUrl(responseInstance string) error { + if responseInstance == "" { + return nil + } + + o.logger.Info().Msg("Instance specified in callback " + responseInstance) + authHost, err := redirectAuthHost(responseInstance) + if err != nil { + // todo error-catalog error + return err + } + + redirectAuthHostRE := o.config.GetString(CONFIG_KEY_ALLOWED_HOST_REGEXP) + o.logger.Info().Msgf("Validating with regexp: \"%s\"", redirectAuthHostRE) + isValidHost, err := isValidAuthHost(authHost, redirectAuthHostRE) + if err != nil { + return err + } + + if !isValidHost { + o.logger.Info().Msg("Instance specified in callback was invalid:" + authHost) + return fmt.Errorf("specified instance is an invalid host") + } + + oauthTokenUrl, urlParseErr := url.Parse(o.oauthConfig.Endpoint.TokenURL) + if urlParseErr != nil { + return fmt.Errorf("failed to parse auth url: %w", urlParseErr) + } + if oauthTokenUrl.Host == authHost { + o.logger.Info().Msgf("Instance specified in callback (%s) matches pre-configured value (%s)", authHost, oauthTokenUrl.Host) + return nil + } + + o.logger.Info().Msgf("Instance specified in callback (%s) does not match pre-configured value (%s)", authHost, oauthTokenUrl.Host) + oauthTokenUrl.Host = authHost + o.oauthConfig.Endpoint.TokenURL = oauthTokenUrl.String() + o.logger.Info().Msgf("New token url endpoint is: %s", o.oauthConfig.Endpoint.TokenURL) + + return nil +} + +func redirectAuthHost(instance string) (string, error) { + // handle both cases if instance is a URL or just a host + if !strings.HasPrefix(instance, "http") { + instance = "https://" + instance + } + + instanceUrl, err := url.Parse(instance) + if err != nil { + return "", err + } + + canonicalizedInstanceUrl, err := api.GetCanonicalApiAsUrl(*instanceUrl) + if err != nil { + return "", err + } + + return canonicalizedInstanceUrl.Host, nil +} + +func isValidAuthHost(authHost string, hostRegularExpression string) (bool, error) { + if len(hostRegularExpression) == 0 { + return false, fmt.Errorf("regular expression to check host names must not be empty") + } + + r, err := regexp.Compile(hostRegularExpression) + if err != nil { + return false, err + } + + return r.MatchString(authHost), nil +} + func (o *oAuth2Authenticator) AddAuthenticationHeader(request *http.Request) error { if request == nil { return fmt.Errorf("request must not be nil") diff --git a/pkg/auth/oauth2authenticator_test.go b/pkg/auth/oauth2authenticator_test.go index 6b2a549f2..675597504 100644 --- a/pkg/auth/oauth2authenticator_test.go +++ b/pkg/auth/oauth2authenticator_test.go @@ -3,6 +3,7 @@ package auth import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -14,6 +15,52 @@ import ( "github.com/snyk/go-application-framework/pkg/configuration" ) +func headlessOpenBrowserFunc(t *testing.T) func(url string) { + t.Helper() + return func(url string) { + fmt.Printf("Mock opening browser... %s", url) + _, err := http.DefaultClient.Get(url) + if err != nil { + fmt.Printf("Error opening browser: %s", err) + } + } +} + +func mockOAuth2TokenHandler(t *testing.T) http.HandlerFunc { + t.Helper() + return func(w http.ResponseWriter, r *http.Request) { + newToken := &oauth2.Token{ + AccessToken: "a", + TokenType: "b", + Expiry: time.Now().Add(60 * time.Second).UTC(), + } + data, err := json.Marshal(newToken) + assert.Nil(t, err) + + w.Header().Set("Content-Type", "application/json;charset=UTF-8") + _, err = w.Write(data) + assert.Nil(t, err) + } +} + +func mockAuthorizeHandler(state string, instance string) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + // Redirect to the redirect_uri with a mock authorization code + redirectURI := r.URL.Query().Get("redirect_uri") + + if state == "" { + state = r.URL.Query().Get("state") + } + + instanceParam := "" + if instance != "" { + instanceParam = "&instance=" + instance + } + + http.Redirect(w, r, redirectURI+"?code=mock-auth-code&state="+state+instanceParam, http.StatusFound) + } +} + func Test_GetVerifier(t *testing.T) { expectedCount := 23 verifier, err := createVerifier(expectedCount) @@ -308,3 +355,123 @@ func Test_Authenticate_CredentialsGrant(t *testing.T) { token := config.GetString(CONFIG_KEY_OAUTH_TOKEN) assert.NotEmpty(t, token) } + +func Test_isValidAuthHost(t *testing.T) { + testCases := []struct { + authHost string + expected bool + }{ + {"api.au.snyk.io", true}, + {"api.example.snyk.io", true}, + {"api.snyk.io", true}, + {"api.snykgov.io", true}, + {"api.pre-release.snykgov.io", true}, + {"snyk.io", false}, + {"api.example.com", false}, + } + + for _, tc := range testCases { + actual, err := isValidAuthHost(tc.authHost, `^api(\.(.+))?\.snyk|snykgov\.io$`) + assert.NoError(t, err) + + if actual != tc.expected { + t.Errorf("isValidAuthHost(%q) = %v, want %v", tc.authHost, actual, tc.expected) + } + } +} + +func Test_Authenticate_AuthorizationCode(t *testing.T) { + t.Run("happy", func(t *testing.T) { + config := configuration.NewWithOpts() + + // Create mock server for successful oauth2 flow + mux := http.NewServeMux() + mux.HandleFunc("/oauth2/authorize", mockAuthorizeHandler("", "")) + mux.HandleFunc("/oauth2/token", mockOAuth2TokenHandler(t)) + ts := httptest.NewServer(mux) + defer ts.Close() + + config.Set(configuration.API_URL, ts.URL) + config.Set(configuration.WEB_APP_URL, ts.URL) + + authenticator := NewOAuth2AuthenticatorWithOpts( + config, + WithOpenBrowserFunc(headlessOpenBrowserFunc(t)), + ) + + err := authenticator.Authenticate() + assert.Nil(t, err) + + assert.Equal(t, "{\"access_token\":\"a\",\"token_type\":\"b\",\"expiry\":\"0001-01-01T00:00:00Z\"}", config.GetString(CONFIG_KEY_OAUTH_TOKEN)) + }) + + t.Run("supports redirect to valid instance", func(t *testing.T) { + tokenServer := httptest.NewServer(mockOAuth2TokenHandler(t)) + defer tokenServer.Close() + + // Create mock server for successful oauth2 flow + mux := http.NewServeMux() + mux.HandleFunc("/oauth2/authorize", mockAuthorizeHandler("", tokenServer.URL)) + initialAuthServer := httptest.NewServer(mux) + defer initialAuthServer.Close() + + config := configuration.NewInMemory() + config.Set(CONFIG_KEY_ALLOWED_HOST_REGEXP, ".*") + config.Set(configuration.API_URL, initialAuthServer.URL) + config.Set(configuration.WEB_APP_URL, initialAuthServer.URL) + + authenticator := NewOAuth2AuthenticatorWithOpts( + config, + WithOpenBrowserFunc(headlessOpenBrowserFunc(t)), + ) + + err := authenticator.Authenticate() + assert.NoError(t, err) + assert.Equal(t, "{\"access_token\":\"a\",\"token_type\":\"b\",\"expiry\":\"0001-01-01T00:00:00Z\"}", config.GetString(CONFIG_KEY_OAUTH_TOKEN)) + }) + + t.Run("does not redirect to invalid instance", func(t *testing.T) { + config := configuration.NewInMemory() + config.Set(CONFIG_KEY_ALLOWED_HOST_REGEXP, `^api(\.(.+))?\.snyk|snykgov\.io$`) + + // Create mock server for successful oauth2 flow + mux := http.NewServeMux() + mux.HandleFunc("/oauth2/authorize", mockAuthorizeHandler("", "api.malicioussnyk.io")) + mux.HandleFunc("/oauth2/token", mockOAuth2TokenHandler(t)) + + ts := httptest.NewServer(mux) + defer ts.Close() + + config.Set(configuration.WEB_APP_URL, ts.URL) + config.Set(configuration.API_URL, ts.URL) + authenticator := NewOAuth2AuthenticatorWithOpts( + config, + WithOpenBrowserFunc(headlessOpenBrowserFunc(t)), + ) + + err := authenticator.Authenticate() + assert.ErrorContains(t, err, "invalid host") + }) + + t.Run("fails with malformed state", func(t *testing.T) { + config := configuration.NewInMemory() + + // Create mock server for unsuccessful oauth2 flow + mux := http.NewServeMux() + mux.HandleFunc("/oauth2/authorize", mockAuthorizeHandler("incorrect-state", "")) + mux.HandleFunc("/oauth2/token", mockOAuth2TokenHandler(t)) + ts := httptest.NewServer(mux) + defer ts.Close() + + config.Set(configuration.API_URL, ts.URL) + config.Set(configuration.WEB_APP_URL, ts.URL) + + authenticator := NewOAuth2AuthenticatorWithOpts( + config, + WithOpenBrowserFunc(headlessOpenBrowserFunc(t)), + ) + + err := authenticator.Authenticate() + assert.ErrorContains(t, err, "incorrect response state") + }) +} diff --git a/pkg/auth/oauth2authenticatoroptions.go b/pkg/auth/oauth2authenticatoroptions.go index 2f74c3cbd..72f0d5ebc 100644 --- a/pkg/auth/oauth2authenticatoroptions.go +++ b/pkg/auth/oauth2authenticatoroptions.go @@ -4,6 +4,7 @@ import ( "context" "net/http" + "github.com/rs/zerolog" "golang.org/x/oauth2" ) @@ -21,6 +22,12 @@ func WithShutdownServerFunc(shutdownServerFunc func(server *http.Server)) OAuth2 } } +func WithLogger(logger *zerolog.Logger) OAuth2AuthenticatorOption { + return func(authenticator *oAuth2Authenticator) { + authenticator.logger = logger + } +} + func WithTokenRefresherFunc(refreshFunc func(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2.Token) (*oauth2.Token, error)) OAuth2AuthenticatorOption { return func(authenticator *oAuth2Authenticator) { authenticator.tokenRefresherFunc = refreshFunc diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go new file mode 100644 index 000000000..8f2c76b28 --- /dev/null +++ b/pkg/auth/oauth_test.go @@ -0,0 +1,89 @@ +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" + "golang.org/x/oauth2/jws" +) + +func getAccessTokenWithSingleAudienceClaim(t *testing.T, audience string) string { + t.Helper() + header := &jws.Header{} + claims := &jws.ClaimSet{ + Aud: audience, + } + pk, err := rsa.GenerateKey(rand.Reader, 1023) + assert.NoError(t, err) + + accessToken, err := jws.Encode(header, claims, pk) + assert.NoError(t, err) + + return accessToken +} + +func getAccessTokenWithMultpleAudienceClaim() string { + return "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJhdWQiOlsiaHR0cHM6Ly9hcGkuZXhhbXBsZS5jb20iXX0.hWq0fKukObQSkphAdyEC7-m4jXIb4VdWyQySmmgy0GU" +} + +func Test_ReadAudience_SingleClaim(t *testing.T) { + expectedString := "api.eu.snyk.io" + expectedAudience := []string{expectedString} + token := oauth2.Token{ + AccessToken: getAccessTokenWithSingleAudienceClaim(t, expectedString), + } + + actualAudience, err := readAudience(&token) + assert.NoError(t, err) + + assert.Equal(t, expectedAudience, actualAudience) +} + +func Test_ReadAudience_ArrayClaim(t *testing.T) { + expectedAudience := []string{"https://api.example.com"} + token := oauth2.Token{ + AccessToken: getAccessTokenWithMultpleAudienceClaim(), + } + + actualAudience, err := readAudience(&token) + assert.NoError(t, err) + + assert.Equal(t, expectedAudience, actualAudience) +} + +func Test_GetAudienceClaimFromOauthToken(t *testing.T) { + t.Run("Happy path", func(t *testing.T) { + expectedString := "api.eu.snyk.io" + expectedAudience := []string{expectedString} + token := oauth2.Token{ + AccessToken: getAccessTokenWithSingleAudienceClaim(t, expectedString), + } + + tokenBytes, err := json.Marshal(token) + assert.NoError(t, err) + + actualClaims, err := GetAudienceClaimFromOauthToken(string(tokenBytes)) + assert.NoError(t, err) + assert.Equal(t, expectedAudience, actualClaims) + }) + + t.Run("empty token string", func(t *testing.T) { + expectedAudience := []string{} + + actualClaims, err := GetAudienceClaimFromOauthToken("") + assert.NoError(t, err) + assert.Equal(t, expectedAudience, actualClaims) + }) + + t.Run("random string value", func(t *testing.T) { + expectedAudience := []string{} + + actualClaims, err := GetAudienceClaimFromOauthToken("aihsfdhajksh") + assert.Error(t, err) + assert.Equal(t, expectedAudience, actualClaims) + }) +} diff --git a/pkg/configuration/configuration.go b/pkg/configuration/configuration.go index d95c8e887..5a3edc80f 100644 --- a/pkg/configuration/configuration.go +++ b/pkg/configuration/configuration.go @@ -64,6 +64,7 @@ type Configuration interface { GetSupportedEnvVarPrefixes() []string SetFiles(files ...string) GetFiles() []string + ReloadConfig() error } // extendedViper is a wrapper around the viper library. @@ -668,3 +669,7 @@ func (ev *extendedViper) GetFiles() []string { return ev.configFiles } + +func (ev *extendedViper) ReloadConfig() error { + return ev.viper.ReadInConfig() +} diff --git a/pkg/local_workflows/auth_workflow.go b/pkg/local_workflows/auth_workflow.go index 951dd8439..0df2c2525 100644 --- a/pkg/local_workflows/auth_workflow.go +++ b/pkg/local_workflows/auth_workflow.go @@ -68,6 +68,7 @@ func authEntryPoint(invocationCtx workflow.InvocationContext, _ []workflow.Data) auth.WithHttpClient(httpClient), auth.WithOpenBrowserFunc(OpenBrowser), auth.WithShutdownServerFunc(auth.ShutdownServer), + auth.WithLogger(logger), ) err = entryPointDI(config, logger, engine, authenticator) diff --git a/pkg/local_workflows/config_utils/sanitycheck.go b/pkg/local_workflows/config_utils/sanitycheck.go new file mode 100644 index 000000000..8fc0823a1 --- /dev/null +++ b/pkg/local_workflows/config_utils/sanitycheck.go @@ -0,0 +1,47 @@ +package config_utils + +import ( + "fmt" + "strings" + + "github.com/snyk/go-application-framework/internal/api" + "github.com/snyk/go-application-framework/pkg/auth" + "github.com/snyk/go-application-framework/pkg/configuration" + "github.com/snyk/go-application-framework/pkg/utils" +) + +type SanityCheckResult struct { + Description string +} + +func CheckSanity(config configuration.Configuration) []SanityCheckResult { + var result []SanityCheckResult + + keys := []string{configuration.API_URL, configuration.AUTHENTICATION_TOKEN, configuration.AUTHENTICATION_BEARER_TOKEN, configuration.ORGANIZATION} + for _, key := range keys { + keysSpecified := config.GetAllKeysThatContainValues(key) + if len(keysSpecified) > 1 { + result = append(result, SanityCheckResult{ + Description: fmt.Sprintf("Possible unexpected behavior, the following configuration values might override each other %s", strings.ToUpper(strings.Join(keysSpecified, ", "))), + }) + } + } + + if keysSpecified := config.GetAllKeysThatContainValues(configuration.API_URL); len(keysSpecified) > 0 { + audience, err := auth.GetAudienceClaimFromOauthToken(config.GetString(auth.CONFIG_KEY_OAUTH_TOKEN)) + if err == nil && len(audience) > 0 { + clonedConfig := config.Clone() + clonedConfig.AddDefaultValue(configuration.API_URL, nil) + configuredValue := clonedConfig.GetString(configuration.API_URL) + differentApiUrlsSpecified := utils.ValueOf(api.GetCanonicalApiUrlFromString(audience[0])) != utils.ValueOf(api.GetCanonicalApiUrlFromString(configuredValue)) + + if differentApiUrlsSpecified { + result = append(result, SanityCheckResult{ + Description: fmt.Sprintf("Using API Url from authentication material, therefore ignoring the specified %s.", strings.ToUpper(strings.Join(keysSpecified, ", "))), + }) + } + } + } + + return result +} diff --git a/pkg/local_workflows/config_utils/sanitycheck_test.go b/pkg/local_workflows/config_utils/sanitycheck_test.go new file mode 100644 index 000000000..c900f7168 --- /dev/null +++ b/pkg/local_workflows/config_utils/sanitycheck_test.go @@ -0,0 +1,82 @@ +package config_utils + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" + "golang.org/x/oauth2/jws" + + "github.com/snyk/go-application-framework/pkg/auth" + "github.com/snyk/go-application-framework/pkg/configuration" +) + +func Test_CheckSanity_ApiUrl(t *testing.T) { + expectedAudience := "hello.world" + header := &jws.Header{} + claims := &jws.ClaimSet{ + Aud: expectedAudience, + } + pk, err := rsa.GenerateKey(rand.Reader, 1023) + assert.NoError(t, err) + + accessToken, err := jws.Encode(header, claims, pk) + assert.NoError(t, err) + + token := oauth2.Token{ + AccessToken: accessToken, + } + + tokenBytes, err := json.Marshal(token) + assert.NoError(t, err) + + t.Run("different url from auth material", func(t *testing.T) { + // Create a configuration with duplicate keys + config := configuration.NewWithOpts() + config.Set(configuration.API_URL, "https://api1.example.com") + config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, string(tokenBytes)) + + result := CheckSanity(config) + + expectedDescription := "Using API Url from authentication material, therefore ignoring the specified" + assert.Len(t, result, 1) + assert.Contains(t, result[0].Description, expectedDescription) + }) + + t.Run("same url auth material", func(t *testing.T) { + // Create a configuration with duplicate keys + config := configuration.NewWithOpts() + config.Set(configuration.API_URL, expectedAudience) + config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, string(tokenBytes)) + + result := CheckSanity(config) + assert.Len(t, result, 0) + }) + + t.Run("no auth material", func(t *testing.T) { + // Create a configuration with duplicate keys + config := configuration.NewWithOpts() + config.Set(configuration.API_URL, expectedAudience) + + result := CheckSanity(config) + assert.Len(t, result, 0) + }) +} + +func Test_CheckSanity_Token(t *testing.T) { + alternativeTokenVariable := "my_token" + expectedDescription := "Possible unexpected behavior, the following configuration values might override each other " + + config := configuration.NewWithOpts() + config.Set(configuration.AUTHENTICATION_TOKEN, "random1") + config.Set(alternativeTokenVariable, "random2") + config.AddAlternativeKeys(configuration.AUTHENTICATION_TOKEN, []string{alternativeTokenVariable}) + + result := CheckSanity(config) + assert.Len(t, result, 1) + + assert.Contains(t, result[0].Description, expectedDescription) +}