Skip to content

Commit

Permalink
feat: set api url from jwt audience claim
Browse files Browse the repository at this point in the history
Snyk's OAuth implementation is capable of indicating the environment
which the user is authenticated into and authorized to access.

This is specified in the audience JWT claim ("aud"). Snyk's
implementation of this claim contains an array of strings, per RFC 7519.

If set and non-empty, the first audience URL is taken as the default API
URL that the client should use, unless the endpoint was specifically
configured.
  • Loading branch information
cmars committed Oct 8, 2024
1 parent f9b5256 commit 5f18602
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 2 deletions.
69 changes: 67 additions & 2 deletions pkg/app/app.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package app

import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"io"
"log"
"net/http"
Expand All @@ -12,10 +16,12 @@ import (
"github.com/rs/zerolog"
zlog "github.com/rs/zerolog/log"
"github.com/snyk/go-httpauth/pkg/httpauth"
"golang.org/x/oauth2"

"github.com/snyk/go-application-framework/internal/api"
"github.com/snyk/go-application-framework/internal/constants"
"github.com/snyk/go-application-framework/internal/utils"
"github.com/snyk/go-application-framework/pkg/auth"
"github.com/snyk/go-application-framework/pkg/configuration"
localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows"
pkg_utils "github.com/snyk/go-application-framework/pkg/utils"
Expand Down Expand Up @@ -72,14 +78,18 @@ func defaultFuncOrganization(engine workflow.Engine, config configuration.Config
return callback
}

func defaultFuncApiUrl(logger *zerolog.Logger) configuration.DefaultValueFunction {
func defaultFuncApiUrl(config configuration.Configuration, logger *zerolog.Logger) configuration.DefaultValueFunction {
callback := func(existingValue interface{}) interface{} {
urlString := constants.SNYK_DEFAULT_API_URL

if existingValue != nil {
if temp, ok := existingValue.(string); ok {
urlString = temp
}
} else if u, err := oauthApiUrl(config); err != nil {
logger.Warn().Err(err).Msg("failed to read oauth token")
} else if u != "" {
urlString = u
}

apiString, err := api.GetCanonicalApiUrlFromString(urlString)
Expand All @@ -91,6 +101,61 @@ func defaultFuncApiUrl(logger *zerolog.Logger) configuration.DefaultValueFunctio
return callback
}

// oauthApiUrl returns the API URL specified by the audience claim in a JWT
// token established by a prior OAuth authentication flow.
//
// Returns an empty string if an OAuth token is not available, cannot be parsed,
// or lacks such an audience claim, along with an error that may have occurred
// in the attempt to parse it.
func oauthApiUrl(config configuration.Configuration) (string, error) {
oauthTokenString, ok := config.Get(auth.CONFIG_KEY_OAUTH_TOKEN).(string)
if !ok || oauthTokenString == "" {
return "", nil
}
var token oauth2.Token
if err := json.Unmarshal([]byte(oauthTokenString), &token); err != nil {
return "", err
}
return readAudience(&token)
}

// readAudience returns the first audience claim from an OAuth2 access token, or
// an error which prevented its parsing.
//
// If the claim is not present, an empty string is returned.
//
// This function was derived from https://pkg.go.dev/golang.org/x/oauth2/jws#Decode,
// which is licensed as follows:
//
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
func readAudience(token *oauth2.Token) (string, error) {
// decode returned id token to get expiry
s := strings.Split(token.AccessToken, ".")
if len(s) < 2 {
// TODO(jbd): Provide more context about the error.
return "", errors.New("jws: invalid token received")
}
decoded, err := base64.RawURLEncoding.DecodeString(s[1])
if err != nil {
return "", err
}
c := struct {
// NOTE: The original jws package models audience with a string, not a
// []string. This fails to parse Snyk JWTs.
Aud []string `json:"aud"`
}{}
err = json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&c)
if err != nil {
return "", err
}
if len(c.Aud) > 0 {
return c.Aud[0], nil
}
return "", nil
}

func defaultInputDirectory() configuration.DefaultValueFunction {
callback := func(existingValue interface{}) interface{} {
if existingValue == nil {
Expand Down Expand Up @@ -184,7 +249,7 @@ func initConfiguration(engine workflow.Engine, config configuration.Configuratio
config.AddDefaultValue(configuration.MAX_THREADS, configuration.StandardDefaultValueFunction(runtime.NumCPU()))
// set default filesize threshold to 512MB
config.AddDefaultValue(configuration.IN_MEMORY_THRESHOLD_BYTES, configuration.StandardDefaultValueFunction(constants.SNYK_DEFAULT_IN_MEMORY_THRESHOLD_MB))
config.AddDefaultValue(configuration.API_URL, defaultFuncApiUrl(logger))
config.AddDefaultValue(configuration.API_URL, defaultFuncApiUrl(config, logger))
config.AddDefaultValue(configuration.TEMP_DIR_PATH, defaultTempDirectory(engine, config, logger))

config.AddDefaultValue(configuration.WEB_APP_URL, func(existingValue any) any {
Expand Down
17 changes: 17 additions & 0 deletions pkg/app/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package app
import (
"errors"
"fmt"
"log"
"net/http"
"os"
"testing"
Expand All @@ -14,6 +15,7 @@ import (
"github.com/snyk/go-application-framework/internal/api"
"github.com/snyk/go-application-framework/internal/constants"
"github.com/snyk/go-application-framework/internal/mocks"
"github.com/snyk/go-application-framework/pkg/auth"
"github.com/snyk/go-application-framework/pkg/configuration"
"github.com/snyk/go-application-framework/pkg/runtimeinfo"
"github.com/snyk/go-application-framework/pkg/workflow"
Expand Down Expand Up @@ -47,6 +49,21 @@ func Test_CreateAppEngine_config_replaceV1inApi(t *testing.T) {
assert.Equal(t, expectApiUrl, actualApiUrl)
}

func Test_CreateAppEngine_config_oauthApiUrl(t *testing.T) {
config := configuration.New()
config.Set(auth.CONFIG_KEY_OAUTH_TOKEN,
// JWT generated at https://jwt.io with claim:
// "aud": ["https://api.example.com"]
`{"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJhdWQiOlsiaHR0cHM6Ly9hcGkuZXhhbXBsZS5jb20iXX0.hWq0fKukObQSkphAdyEC7-m4jXIb4VdWyQySmmgy0GU"}`,
)
logger := log.New(os.Stderr, "", 0)
engine := CreateAppEngineWithOptions(WithConfiguration(config), WithLogger(logger))
initConfiguration(engine, config, engine.GetLogger(), nil)

actualApiUrl := config.GetString(configuration.API_URL)
assert.Equal(t, "https://api.example.com", actualApiUrl)
}

func Test_initConfiguration_updateDefaultOrgId(t *testing.T) {
orgName := "someOrgName"
orgId := "someOrgId"
Expand Down

0 comments on commit 5f18602

Please sign in to comment.