Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implementation of AWS OIDC #232

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/v1alpha1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,9 @@ type AWSOIDCExchangeToken struct {
// AwsRoleArn is the AWS IAM Role with the permission to use specific resources in AWS account
// which maps to the temporary AWS security credentials exchanged using the authentication token issued by OIDC provider.
AwsRoleArn string `json:"awsRoleArn"`

// ProxyURL can be used when communication with STS.
ProxyURL string `json:"proxyUrl,omitempty"`
mathetake marked this conversation as resolved.
Show resolved Hide resolved
}

// LLMRequestCost configures each request cost.
Expand Down
2 changes: 1 addition & 1 deletion filterapi/filterconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ type BackendAuth struct {
// AWSAuth defines the credentials needed to access AWS.
type AWSAuth struct {
CredentialFileName string `json:"credentialFileName,omitempty"`
Region string `json:"region"`
Region string `json:"region,omitempty"`
}

// APIKeyAuth defines the file that will be mounted to the external proc.
Expand Down
9 changes: 6 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ require (
github.com/aws/aws-sdk-go-v2 v1.34.0
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8
github.com/aws/aws-sdk-go-v2/config v1.29.2
github.com/aws/aws-sdk-go-v2/credentials v1.17.55
github.com/aws/aws-sdk-go-v2/service/sts v1.33.10
github.com/coreos/go-oidc/v3 v3.12.0
github.com/envoyproxy/gateway v1.3.0-rc.1
github.com/envoyproxy/go-control-plane/envoy v1.32.3
github.com/go-logr/logr v1.4.2
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/cel-go v0.23.0
github.com/google/go-cmp v0.6.0
github.com/openai/openai-go v0.1.0-alpha.49
github.com/stretchr/testify v1.10.0
go.uber.org/zap v1.27.0
golang.org/x/exp v0.0.0-20250128144449-3edf0e91c1ae
golang.org/x/oauth2 v0.25.0
google.golang.org/grpc v1.70.0
google.golang.org/protobuf v1.36.4
k8s.io/api v0.32.1
Expand All @@ -31,7 +36,6 @@ require (
require (
cel.dev/expr v0.19.1 // indirect
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.55 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.25 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.29 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.29 // indirect
Expand All @@ -40,7 +44,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.10 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.12 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.11 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.10 // indirect
github.com/aws/smithy-go v1.22.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
Expand All @@ -52,6 +55,7 @@ require (
github.com/evanphx/json-patch/v5 v5.9.0 // indirect
github.com/fsnotify/fsnotify v1.8.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/go-logr/zapr v1.3.0 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/jsonreference v0.21.0 // indirect
Expand Down Expand Up @@ -86,7 +90,6 @@ require (
github.com/x448/float16 v0.8.4 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/oauth2 v0.25.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/term v0.28.0 // indirect
Expand Down
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 h1:QVw89YDxXxEe+l8gU8ETbOasdwEV+avkR75ZzsVV9WI=
github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo=
github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
Expand All @@ -56,6 +58,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk=
github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
Expand All @@ -72,6 +76,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
Expand Down Expand Up @@ -184,6 +190,8 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/exp v0.0.0-20250128144449-3edf0e91c1ae h1:COZdc9Ut6wLq7MO9GIYxfZl4n4ScmgqQLoHocKXrxco=
golang.org/x/exp v0.0.0-20250128144449-3edf0e91c1ae/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
Expand Down
12 changes: 11 additions & 1 deletion internal/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
gwapiv1b1 "sigs.k8s.io/gateway-api/apis/v1beta1"

aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1"
"github.com/envoyproxy/ai-gateway/internal/controller/oidc"
)

func init() { MustInitializeScheme(scheme) }
Expand Down Expand Up @@ -115,6 +116,14 @@ func StartControllers(ctx context.Context, config *rest.Config, logger logr.Logg
if err = mgr.Start(ctx); err != nil { // This blocks until the manager is stopped.
return fmt.Errorf("failed to start controller manager: %w", err)
}

handler, err := oidc.NewOIDCHandler(&logger, c)
if err != nil {
return fmt.Errorf("failed to create OIDC handler: %w", err)
}

go handler.UpdateCredentials(ctx)

return nil
}

Expand Down Expand Up @@ -181,8 +190,9 @@ func backendSecurityPolicyIndexFunc(o client.Object) []string {
awsCreds := backendSecurityPolicy.Spec.AWSCredentials
if awsCreds.CredentialsFile != nil {
key = getSecretNameAndNamespace(awsCreds.CredentialsFile.SecretRef, backendSecurityPolicy.Namespace)
} else if awsCreds.OIDCExchangeToken != nil {
key = getSecretNameAndNamespace(&awsCreds.OIDCExchangeToken.OIDC.ClientSecret, backendSecurityPolicy.Namespace)
}
// TODO: OIDC.
}
return []string{key}
}
Expand Down
78 changes: 78 additions & 0 deletions internal/controller/oidc/aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package oidc

import (
"context"
"fmt"
"net/http"
"net/url"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const OidcAwsPrefix = "oidc-aws-"

func getSTSCredentials(region, roleArn, proxyURL, accessToken string) (aws.Credentials, error) {
// create sts client
stsCfg := aws.Config{
Region: region,
}
if proxyURL != "" {
stsCfg.HTTPClient = &http.Client{
Transport: &http.Transport{
Proxy: func(*http.Request) (*url.URL, error) {
return url.Parse(proxyURL)
},
},
}
}
stsClient := sts.NewFromConfig(stsCfg)
credentialsCache := aws.NewCredentialsCache(stscreds.NewWebIdentityRoleProvider(
stsClient,
roleArn,
IdentityTokenValue(accessToken),
))
return credentialsCache.Retrieve(context.TODO())
}

func updateOrCreateAWSSecret(k8sClient client.Client, credentials aws.Credentials, namespace, bspKey string) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make the functions small and do one thing.
Having a function name with Or in it is a sign something isn't right ;)

namespaceName := types.NamespacedName{
Namespace: namespace,
Name: fmt.Sprintf("%s%s", OidcAwsPrefix, bspKey),
}
credentialSecret := corev1.Secret{}
err := k8sClient.Get(context.TODO(), namespaceName, &credentialSecret)
if err != nil {
if client.IgnoreNotFound(err) != nil {
return fmt.Errorf("fail to get secret for backend security policy %w", err)
}
err = k8sClient.Create(context.Background(), &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: namespaceName.Name,
Namespace: namespaceName.Namespace,
},
})
if err != nil {
return err
}
}
if credentialSecret.StringData == nil {
credentialSecret.StringData = make(map[string]string)
}
credentialSecret.StringData["credentials"] = fmt.Sprintf("[default]\n"+
"aws_access_key_id = %s\n"+
"aws_secret_access_key = %s\n"+
"aws_session_token = %s\n",
credentials.AccessKeyID, credentials.SecretAccessKey, credentials.SessionToken)

err = k8sClient.Update(context.TODO(), &credentialSecret)
if err != nil {
return fmt.Errorf("fail to refresh find secret for backend security policy %w", err)
}
return nil
}
1 change: 1 addition & 0 deletions internal/controller/oidc/aws_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package oidc
176 changes: 176 additions & 0 deletions internal/controller/oidc/oidc.go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect several of these functions do exist in the JWT or OIDC libraries - please take a look

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which specific functions do you suspect? I might be missing something

Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package oidc

import (
"context"
"fmt"
"log/slog"
"net/url"
"time"

"github.com/coreos/go-oidc/v3/oidc"
egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renaming this seems to lose the readability of the code. is it necessary?
@mathetake is this a pattern used across the code base?

Copy link
Member

@mathetake mathetake Jan 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes and this is the standard pattern and naming consistency is ensured by the linter in CI. Btw, this is not only here but also across the whole bunch of k8s related projects. you see the API package name v1alpha1 exists across tons of projects so not renaming is terrible.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not rename to egv1alpha1?

"github.com/go-logr/logr"
"github.com/golang-jwt/jwt/v4"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"

aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1"
)

// OIDC expects the SecretKey to be "client-secret".
const secretKey = "client-secret"

// IdentityTokenValue is for retrieving an identity token
type IdentityTokenValue string

// GetIdentityToken retrieves the JWT and returns the contents as a []byte
func (j IdentityTokenValue) GetIdentityToken() ([]byte, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider renaming function to something like IdentityTokenAsByteArray to make it more accurate for what the function does.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetIdentityToken is the official function for IdentityTokenRetriever's interface

// IdentityTokenRetriever is an interface for retrieving a JWT
type IdentityTokenRetriever interface {
	GetIdentityToken() ([]byte, error)
}

But I do agree it's not very clear

return []byte(j), nil
}

type oauth2TokenWithExp struct {
token *oauth2.Token
expTime time.Time
}

type Handler struct {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since Handler has meaning in Go and becomes vague in this context. Consider renaming to be more specific, consider something descriptive like TokenRefresher

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TokenRefresher makes sense if the "handler" only managed refreshing OIDC tokens. What do you think about CredentialRefresher since it updates the OIDC token and updates/creates a secret file with new logins

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does both token exchange and refresh, we discussed to have a generic interface to support both aws and gcp which we are going to add very soon. The steps are the same to get the oidc token to exchange the cloud credential, then update the credential file. The token refresher go routine updates the credential before it expires.

logger *logr.Logger
k8sClient client.Client
// awsCredentialCache cache key is backend security policy's namespace + name.
awsCredentialCache map[string]time.Time
// oidcCredentialCache cache key is backend security policy's namespace + name.
oidcCredentialCache map[string]*oauth2TokenWithExp
interval time.Duration
}

func NewOIDCHandler(logger *logr.Logger, k8sClient client.Client) (*Handler, error) {
handler := &Handler{
logger: logger,
k8sClient: k8sClient,
awsCredentialCache: make(map[string]time.Time),
oidcCredentialCache: make(map[string]*oauth2TokenWithExp),
interval: time.Minute,
}
return handler, nil
}

func (o *Handler) UpdateCredentials(ctx context.Context) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like an infite loop, could we spawn a process instead? If we spawn it as a process it can be terminated and should make the code cleaner.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming you're referring to this process?

Correct me if I'm misunderstanding you, but would this mean we create an executable for the OIDC credential update and have the process handle updating?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simple go routine based timer should work, we do not need a process.

for {
backendSecurityPolicies := &aigv1a1.BackendSecurityPolicyList{}
if err := o.k8sClient.List(context.Background(), backendSecurityPolicies); err != nil {
o.logger.Error(err, "Failed to get backend security policies")
}

for _, backendSecurityPolicy := range backendSecurityPolicies.Items {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider making this a separate function that is called if the backendSecurityPolicies is set. putting this for loop a callable function instead will make it more testable and self contained.

// Only AWS Credentials currently supports OIDC
if backendSecurityPolicy.Spec.Type != aigv1a1.BackendSecurityPolicyTypeAWSCredentials {
continue
}
if backendSecurityPolicy.Spec.AWSCredentials.CredentialsFile != nil {
continue
}

cacheKey := fmt.Sprintf("%s.%s", backendSecurityPolicy.Name, backendSecurityPolicy.Namespace)
awsCredentials := backendSecurityPolicy.Spec.AWSCredentials
if o.oidcCredentialCache[cacheKey].token == nil || time.Now().After(o.oidcCredentialCache[cacheKey].expTime.Add(-5*time.Minute)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider making this logic check a function, also consider making values like -5*time.Minute a constant somewhere, as the 5 minutes before expiry should be easily to semantically understand somewhere in the code.

oidcCred := awsCredentials.OIDCExchangeToken.OIDC
oidcAud := awsCredentials.OIDCExchangeToken.Aud
err := o.updateOIDCExpiredToken(ctx, oidcCred, cacheKey, oidcAud, backendSecurityPolicy.Namespace)
if err != nil {
o.logger.Error(err, "Failed to update OIDC token", "BackendSecurityPolicy", backendSecurityPolicy.Name)
}
}

if expiredTime, ok := o.awsCredentialCache[cacheKey]; !ok || time.Now().After(expiredTime.Add(-5*time.Minute)) {
credentials, err := getSTSCredentials(awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, awsCredentials.OIDCExchangeToken.ProxyURL, awsCredentials.OIDCExchangeToken.Aud)
if err != nil {
o.logger.Error(err, "Failed to get sts credentials", "BackendSecurityPolicy", backendSecurityPolicy.Name)
}
err = updateOrCreateAWSSecret(o.k8sClient, credentials, backendSecurityPolicy.Namespace, cacheKey)
if err != nil {
o.logger.Error(err, "Failed to update AWS secret", "BackendSecurityPolicy", backendSecurityPolicy.Name)
}
o.awsCredentialCache[cacheKey] = credentials.Expires
}
}
time.Sleep(o.interval)
}
}

func (o *Handler) extractOauth2Token(ctx context.Context, oidcCreds egv1a1.OIDC, aud, namespace string) (*oauth2.Token, error) {
provider, err := oidc.NewProvider(ctx, oidcCreds.Provider.Issuer)
if err != nil {
return nil, fmt.Errorf("fail to create oidc provider: %w", err)
}
clientSecret, err := o.extractClientSecret(ctx, namespace, string(oidcCreds.ClientSecret.Name))
if err != nil {
return nil, fmt.Errorf("fail to extract client secret: %w", err)
}
oauth2Config := clientcredentials.Config{
ClientID: oidcCreds.ClientID,
ClientSecret: clientSecret,
// Discovery returns the OAuth2 endpoints.
TokenURL: provider.Endpoint().TokenURL,
Scopes: oidcCreds.Scopes,
}
oauth2Config.EndpointParams = url.Values{"audience": []string{aud}}
t, err := oauth2Config.Token(ctx)
if err != nil {
return nil, fmt.Errorf("fail to refresh oauth2 token %w", err)
}
return t, nil
}

func (o *Handler) oauth2TokenExpireTime(accessToken *oauth2.Token) (*time.Time, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be a lot of casting in this function, which seems a bit excessive and unnecessary. Can we get an explanation of why this is needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for the two castings:

  1. We set the "implementation" of Claim to be type MapClaims when parsing the access token via ParseUnverified. We later cast it because we want to use its properties of being a map (we essentially are checking the type this way).
  2. The value of claims is an interface, so we want to type cast to an integer for time.Unix to understand it. This is essentially checking the types + casting all at once.

token, _, err := new(jwt.Parser).ParseUnverified(accessToken.AccessToken, jwt.MapClaims{})
if err != nil {
return nil, fmt.Errorf("fail to parse oauth2 token: %v", slog.Any("error", err))
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("fail to parse oauth2 token claims: %v", slog.Any("error", err))
}
exp, ok := claims["exp"].(float64)
if !ok {
return nil, fmt.Errorf("fail to parse oauth2 token exp: %v", slog.Any("error", err))
}
expTime := time.Unix(int64(exp), 0)
return &expTime, nil
}

func (o *Handler) updateOIDCExpiredToken(ctx context.Context, oidcCreds egv1a1.OIDC, cacheKey, aud, namespace string) error {
if _, ok := o.oidcCredentialCache[cacheKey]; ok {
o.oidcCredentialCache[cacheKey] = &oauth2TokenWithExp{}
}

token, err := o.extractOauth2Token(ctx, oidcCreds, aud, namespace)
if err != nil {
return err
}

expireTime, err := o.oauth2TokenExpireTime(token)
if err != nil {
return err
}

o.oidcCredentialCache[cacheKey].token = token
o.oidcCredentialCache[cacheKey].expTime = *expireTime
return nil
}

func (o *Handler) extractClientSecret(ctx context.Context, ns, secretName string) (string, error) {
secret := &corev1.Secret{}
if err := o.k8sClient.Get(ctx, client.ObjectKey{
Namespace: ns,
Name: secretName,
}, secret); err != nil {
return "", fmt.Errorf("failed to get secret %s.%s: %w", secretName, ns, err)
}
clientSecret, ok := secret.Data[secretKey]
if !ok {
return "", fmt.Errorf("missing '%s' in secret %s.%s", secretKey, secret.Name, secret.Namespace)
}
return string(clientSecret), nil
}
Loading
Loading