diff --git a/.gitleaksignore b/.gitleaksignore index 6ce65c7b..cec0319d 100644 --- a/.gitleaksignore +++ b/.gitleaksignore @@ -14,3 +14,4 @@ b0ba7f6ed181c23a5c6532cf6f124b731d390f86:internal/presenters/testdata/with-ignor 237c7f05ec087733fa7929ee9fa3db2bd56bdba4:pkg/logging/scrubbingLogWriter_test.go:github-pat:183 237c7f05ec087733fa7929ee9fa3db2bd56bdba4:pkg/logging/scrubbingLogWriter_test.go:snyk-api-token:208 0481cdb4d07351149e65a57ebc9ad5b983896849:pkg/auth/oauth2authenticator.go:generic-api-key:30 +0a646dd3b9eeca0463fd8240b87957106f4b71f3:pkg/app/app_test.go:jwt:57 diff --git a/pkg/app/app.go b/pkg/app/app.go index 3605b8e6..c7412932 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -1,6 +1,10 @@ package app import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" "io" "log" "net/http" @@ -12,10 +16,12 @@ import ( "github.com/rs/zerolog" zlog "github.com/rs/zerolog/log" "github.com/snyk/go-httpauth/pkg/httpauth" + "golang.org/x/oauth2" "github.com/snyk/go-application-framework/internal/api" "github.com/snyk/go-application-framework/internal/constants" "github.com/snyk/go-application-framework/internal/utils" + "github.com/snyk/go-application-framework/pkg/auth" "github.com/snyk/go-application-framework/pkg/configuration" localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows" pkg_utils "github.com/snyk/go-application-framework/pkg/utils" @@ -72,7 +78,7 @@ func defaultFuncOrganization(engine workflow.Engine, config configuration.Config return callback } -func defaultFuncApiUrl(logger *zerolog.Logger) configuration.DefaultValueFunction { +func defaultFuncApiUrl(config configuration.Configuration, logger *zerolog.Logger) configuration.DefaultValueFunction { callback := func(existingValue interface{}) interface{} { urlString := constants.SNYK_DEFAULT_API_URL @@ -80,6 +86,10 @@ func defaultFuncApiUrl(logger *zerolog.Logger) configuration.DefaultValueFunctio if temp, ok := existingValue.(string); ok { urlString = temp } + } else if u, err := oauthApiUrl(config); err != nil { + logger.Warn().Err(err).Msg("failed to read oauth token") + } else if u != "" { + urlString = u } apiString, err := api.GetCanonicalApiUrlFromString(urlString) @@ -91,6 +101,61 @@ func defaultFuncApiUrl(logger *zerolog.Logger) configuration.DefaultValueFunctio return callback } +// oauthApiUrl returns the API URL specified by the audience claim in a JWT +// token established by a prior OAuth authentication flow. +// +// Returns an empty string if an OAuth token is not available, cannot be parsed, +// or lacks such an audience claim, along with an error that may have occurred +// in the attempt to parse it. +func oauthApiUrl(config configuration.Configuration) (string, error) { + oauthTokenString, ok := config.Get(auth.CONFIG_KEY_OAUTH_TOKEN).(string) + if !ok || oauthTokenString == "" { + return "", nil + } + var token oauth2.Token + if err := json.Unmarshal([]byte(oauthTokenString), &token); err != nil { + return "", err + } + return readAudience(&token) +} + +// readAudience returns the first audience claim from an OAuth2 access token, or +// an error which prevented its parsing. +// +// If the claim is not present, an empty string is returned. +// +// This function was derived from https://pkg.go.dev/golang.org/x/oauth2/jws#Decode, +// which is licensed as follows: +// +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +func readAudience(token *oauth2.Token) (string, error) { + // decode returned id token to get expiry + s := strings.Split(token.AccessToken, ".") + if len(s) < 2 { + // TODO(jbd): Provide more context about the error. + return "", errors.New("jws: invalid token received") + } + decoded, err := base64.RawURLEncoding.DecodeString(s[1]) + if err != nil { + return "", err + } + c := struct { + // NOTE: The original jws package models audience with a string, not a + // []string. This fails to parse Snyk JWTs. + Aud []string `json:"aud"` + }{} + err = json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&c) + if err != nil { + return "", err + } + if len(c.Aud) > 0 { + return c.Aud[0], nil + } + return "", nil +} + func defaultInputDirectory() configuration.DefaultValueFunction { callback := func(existingValue interface{}) interface{} { if existingValue == nil { @@ -184,7 +249,7 @@ func initConfiguration(engine workflow.Engine, config configuration.Configuratio config.AddDefaultValue(configuration.MAX_THREADS, configuration.StandardDefaultValueFunction(runtime.NumCPU())) // set default filesize threshold to 512MB config.AddDefaultValue(configuration.IN_MEMORY_THRESHOLD_BYTES, configuration.StandardDefaultValueFunction(constants.SNYK_DEFAULT_IN_MEMORY_THRESHOLD_MB)) - config.AddDefaultValue(configuration.API_URL, defaultFuncApiUrl(logger)) + config.AddDefaultValue(configuration.API_URL, defaultFuncApiUrl(config, logger)) config.AddDefaultValue(configuration.TEMP_DIR_PATH, defaultTempDirectory(engine, config, logger)) config.AddDefaultValue(configuration.WEB_APP_URL, func(existingValue any) any { diff --git a/pkg/app/app_test.go b/pkg/app/app_test.go index 5b99a8c3..a844088f 100644 --- a/pkg/app/app_test.go +++ b/pkg/app/app_test.go @@ -3,6 +3,7 @@ package app import ( "errors" "fmt" + "log" "net/http" "os" "testing" @@ -14,6 +15,7 @@ import ( "github.com/snyk/go-application-framework/internal/api" "github.com/snyk/go-application-framework/internal/constants" "github.com/snyk/go-application-framework/internal/mocks" + "github.com/snyk/go-application-framework/pkg/auth" "github.com/snyk/go-application-framework/pkg/configuration" "github.com/snyk/go-application-framework/pkg/runtimeinfo" "github.com/snyk/go-application-framework/pkg/workflow" @@ -47,6 +49,21 @@ func Test_CreateAppEngine_config_replaceV1inApi(t *testing.T) { assert.Equal(t, expectApiUrl, actualApiUrl) } +func Test_CreateAppEngine_config_oauthApiUrl(t *testing.T) { + config := configuration.New() + config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, + // JWT generated at https://jwt.io with claim: + // "aud": ["https://api.example.com"] + `{"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJhdWQiOlsiaHR0cHM6Ly9hcGkuZXhhbXBsZS5jb20iXX0.hWq0fKukObQSkphAdyEC7-m4jXIb4VdWyQySmmgy0GU"}`, + ) + logger := log.New(os.Stderr, "", 0) + engine := CreateAppEngineWithOptions(WithConfiguration(config), WithLogger(logger)) + initConfiguration(engine, config, engine.GetLogger(), nil) + + actualApiUrl := config.GetString(configuration.API_URL) + assert.Equal(t, "https://api.example.com", actualApiUrl) +} + func Test_initConfiguration_updateDefaultOrgId(t *testing.T) { orgName := "someOrgName" orgId := "someOrgId" diff --git a/pkg/auth/oauth2authenticator.go b/pkg/auth/oauth2authenticator.go index cdd7acfd..4d19fe78 100644 --- a/pkg/auth/oauth2authenticator.go +++ b/pkg/auth/oauth2authenticator.go @@ -14,6 +14,8 @@ import ( "math/big" "net" "net/http" + "net/url" + "regexp" "sync" "time" @@ -287,6 +289,7 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { var responseCode string var responseState string var responseError string + var responseInstance string verifier, err := createVerifier(128) if err != nil { return err @@ -336,6 +339,7 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { appUrl := o.config.GetString(configuration.WEB_APP_URL) responseCode = html.EscapeString(r.URL.Query().Get("code")) responseState = html.EscapeString(r.URL.Query().Get("state")) + responseInstance = html.EscapeString(r.URL.Query().Get("instance")) w.Header().Add("Location", appUrl+"/authenticated?type=oauth") w.WriteHeader(http.StatusMovedPermanently) } @@ -388,6 +392,23 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { return fmt.Errorf("incorrect response state: %s != %s", responseState, state) } + if responseInstance != "" { + authHost := redirectAuthHost(responseInstance) + if err != nil { + return fmt.Errorf("invalid instance: %q", responseInstance) + } + if !isValidAuthHost(authHost) { + return fmt.Errorf("invalid instance: %q", responseInstance) + } + + authURL, err := url.Parse(o.oauthConfig.Endpoint.AuthURL) + if err != nil { + return fmt.Errorf("failed to parse auth url: %w", err) + } + authURL.Host = authHost + o.oauthConfig.Endpoint.AuthURL = authURL.String() + } + // Use the custom HTTP client when requesting a token. if o.httpClient != nil { ctx = context.WithValue(ctx, oauth2.HTTPClient, o.httpClient) @@ -402,6 +423,16 @@ func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error { return err } +func redirectAuthHost(instance string) string { + return fmt.Sprintf("api.%s", instance) +} + +var redirectAuthHostRE = regexp.MustCompile(`^api\.(.+)\.snyk\.io$`) + +func isValidAuthHost(authHost string) bool { + return redirectAuthHostRE.MatchString(authHost) +} + func (o *oAuth2Authenticator) AddAuthenticationHeader(request *http.Request) error { if request == nil { return fmt.Errorf("request must not be nil")