Skip to content

Commit

Permalink
Restarting rework to address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dricross committed Jan 28, 2025
1 parent 39a5884 commit e890661
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 29 deletions.
15 changes: 9 additions & 6 deletions cfg/aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,14 @@ func newStsCredentials(c client.ConfigProvider, roleARN string, region string) *
return credentials.NewCredentials(&stsCredentialProvider{regional: regional, partitional: partitional})
}

const (
SourceArnHeaderKey = "x-amz-source-arn"
SourceAccountHeaderKey = "x-amz-source-account"
)

var (
sourceAccount = os.Getenv(envconfig.AMZ_SOURCE_ACCOUNT) // populates the "x-amz-source-account" header
sourceArn = os.Getenv(envconfig.AMZ_SOURCE_ARN) // populates the "x-amz-source-arn" header
sourceAccount = os.Getenv(envconfig.AmzSourceAccount) // populates the "x-amz-source-account" header
sourceArn = os.Getenv(envconfig.AmzSourceArn) // populates the "x-amz-source-arn" header
)

// newStsClient creates a new STS client with the provided config and options.
Expand All @@ -221,10 +226,8 @@ func newStsClient(p client.ConfigProvider, cfgs ...*aws.Config) *sts.STS {
client := sts.New(p, cfgs...)
if sourceAccount != "" && sourceArn != "" {
client.Handlers.Sign.PushFront(func(r *request.Request) {
r.ApplyOptions(request.WithSetRequestHeaders(map[string]string{
"x-amz-source-arn": sourceArn,
"x-amz-source-account": sourceAccount,
}))
r.HTTPRequest.Header.Set(SourceArnHeaderKey, sourceArn)
r.HTTPRequest.Header.Set(SourceAccountHeaderKey, sourceAccount)
})

log.Printf("I! Found confused deputy header environment variables: source account: %q, source arn: %q", sourceAccount, sourceArn)
Expand Down
30 changes: 9 additions & 21 deletions cfg/aws/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,14 @@ import (
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/awstesting/mock"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type mockConfigProvider struct{}

func (m mockConfigProvider) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
return client.Config{
Config: &aws.Config{
// These are examples credentials pulled from:
// https://docs.aws.amazon.com/STS/latest/APIReference/API_GetAccessKeyInfo.html
Credentials: credentials.NewStaticCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", ""),
Region: aws.String("us-east-1"),
},
}
}

func TestConfusedDeputyHeaders(t *testing.T) {
mockProvider := mockConfigProvider{}

tests := []struct {
name string
envSourceArn string
Expand Down Expand Up @@ -72,16 +57,19 @@ func TestConfusedDeputyHeaders(t *testing.T) {
sourceArn = tt.envSourceArn
sourceAccount = tt.envSourceAccount

client := newStsClient(mockProvider)
client := newStsClient(mock.Session, &aws.Config{
// These are examples credentials pulled from:
// https://docs.aws.amazon.com/STS/latest/APIReference/API_GetAccessKeyInfo.html
Credentials: credentials.NewStaticCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", ""),
Region: aws.String("us-east-1"),
})

// Generate the assume role request, but do not actually send it
// We don't need this unit test making real AWS calls
request, _ := client.AssumeRoleRequest(&sts.AssumeRoleInput{
// We aren't going to actually make the assume role call, we are just going
// to verify the headers are present once signed so the RoleArn and RoleSessionName
// arguments are irrelevant. Fill them out with something so the request is valid.
RoleArn: aws.String("XXXXXXX"),
RoleSessionName: aws.String("XXXXXXX"),
RoleArn: aws.String("arn:aws:iam::012345678912:role/XXXXXXXX"),
RoleSessionName: aws.String("MockSession"),
})

// Headers are generated after the request is signed (but before it's sent)
Expand Down
4 changes: 2 additions & 2 deletions cfg/envconfig/envconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ const (
CWAgentMergedOtelConfig = "CWAGENT_MERGED_OTEL_CONFIG"

// confused deputy prevention related headers
AMZ_SOURCE_ACCOUNT = "AMZ_SOURCE_ACCOUNT" // populates the "x-amz-source-account" header
AMZ_SOURCE_ARN = "AMZ_SOURCE_ARN" // populates the "x-amz-source-arn" header
AmzSourceAccount = "AMZ_SOURCE_ACCOUNT" // populates the "x-amz-source-account" header
AmzSourceArn = "AMZ_SOURCE_ARN" // populates the "x-amz-source-arn" header
)

const (
Expand Down

0 comments on commit e890661

Please sign in to comment.