From e8906618a77e442d5ad93afbc7d66ebb78113e0e Mon Sep 17 00:00:00 2001 From: Rick Rossi Date: Tue, 21 Jan 2025 09:22:22 -0500 Subject: [PATCH] Restarting rework to address comments --- cfg/aws/credentials.go | 15 +++++++++------ cfg/aws/credentials_test.go | 30 +++++++++--------------------- cfg/envconfig/envconfig.go | 4 ++-- 3 files changed, 20 insertions(+), 29 deletions(-) diff --git a/cfg/aws/credentials.go b/cfg/aws/credentials.go index f1e22209ef..e8bd986dc1 100644 --- a/cfg/aws/credentials.go +++ b/cfg/aws/credentials.go @@ -205,9 +205,14 @@ func newStsCredentials(c client.ConfigProvider, roleARN string, region string) * return credentials.NewCredentials(&stsCredentialProvider{regional: regional, partitional: partitional}) } +const ( + SourceArnHeaderKey = "x-amz-source-arn" + SourceAccountHeaderKey = "x-amz-source-account" +) + var ( - sourceAccount = os.Getenv(envconfig.AMZ_SOURCE_ACCOUNT) // populates the "x-amz-source-account" header - sourceArn = os.Getenv(envconfig.AMZ_SOURCE_ARN) // populates the "x-amz-source-arn" header + sourceAccount = os.Getenv(envconfig.AmzSourceAccount) // populates the "x-amz-source-account" header + sourceArn = os.Getenv(envconfig.AmzSourceArn) // populates the "x-amz-source-arn" header ) // newStsClient creates a new STS client with the provided config and options. @@ -221,10 +226,8 @@ func newStsClient(p client.ConfigProvider, cfgs ...*aws.Config) *sts.STS { client := sts.New(p, cfgs...) if sourceAccount != "" && sourceArn != "" { client.Handlers.Sign.PushFront(func(r *request.Request) { - r.ApplyOptions(request.WithSetRequestHeaders(map[string]string{ - "x-amz-source-arn": sourceArn, - "x-amz-source-account": sourceAccount, - })) + r.HTTPRequest.Header.Set(SourceArnHeaderKey, sourceArn) + r.HTTPRequest.Header.Set(SourceAccountHeaderKey, sourceAccount) }) log.Printf("I! Found confused deputy header environment variables: source account: %q, source arn: %q", sourceAccount, sourceArn) diff --git a/cfg/aws/credentials_test.go b/cfg/aws/credentials_test.go index b7c999dae3..39aa0188de 100644 --- a/cfg/aws/credentials_test.go +++ b/cfg/aws/credentials_test.go @@ -7,29 +7,14 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/awstesting/mock" "github.com/aws/aws-sdk-go/service/sts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type mockConfigProvider struct{} - -func (m mockConfigProvider) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config { - return client.Config{ - Config: &aws.Config{ - // These are examples credentials pulled from: - // https://docs.aws.amazon.com/STS/latest/APIReference/API_GetAccessKeyInfo.html - Credentials: credentials.NewStaticCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", ""), - Region: aws.String("us-east-1"), - }, - } -} - func TestConfusedDeputyHeaders(t *testing.T) { - mockProvider := mockConfigProvider{} - tests := []struct { name string envSourceArn string @@ -72,16 +57,19 @@ func TestConfusedDeputyHeaders(t *testing.T) { sourceArn = tt.envSourceArn sourceAccount = tt.envSourceAccount - client := newStsClient(mockProvider) + client := newStsClient(mock.Session, &aws.Config{ + // These are examples credentials pulled from: + // https://docs.aws.amazon.com/STS/latest/APIReference/API_GetAccessKeyInfo.html + Credentials: credentials.NewStaticCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", ""), + Region: aws.String("us-east-1"), + }) - // Generate the assume role request, but do not actually send it - // We don't need this unit test making real AWS calls request, _ := client.AssumeRoleRequest(&sts.AssumeRoleInput{ // We aren't going to actually make the assume role call, we are just going // to verify the headers are present once signed so the RoleArn and RoleSessionName // arguments are irrelevant. Fill them out with something so the request is valid. - RoleArn: aws.String("XXXXXXX"), - RoleSessionName: aws.String("XXXXXXX"), + RoleArn: aws.String("arn:aws:iam::012345678912:role/XXXXXXXX"), + RoleSessionName: aws.String("MockSession"), }) // Headers are generated after the request is signed (but before it's sent) diff --git a/cfg/envconfig/envconfig.go b/cfg/envconfig/envconfig.go index c1e5d4adeb..afbf4918de 100644 --- a/cfg/envconfig/envconfig.go +++ b/cfg/envconfig/envconfig.go @@ -34,8 +34,8 @@ const ( CWAgentMergedOtelConfig = "CWAGENT_MERGED_OTEL_CONFIG" // confused deputy prevention related headers - AMZ_SOURCE_ACCOUNT = "AMZ_SOURCE_ACCOUNT" // populates the "x-amz-source-account" header - AMZ_SOURCE_ARN = "AMZ_SOURCE_ARN" // populates the "x-amz-source-arn" header + AmzSourceAccount = "AMZ_SOURCE_ACCOUNT" // populates the "x-amz-source-account" header + AmzSourceArn = "AMZ_SOURCE_ARN" // populates the "x-amz-source-arn" header ) const (