diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 4970fd67bc..2abe6ac3df 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -47,6 +47,14 @@ aws: ARN of the AWS IAM role to assume in the trusting AWS account. example: arn:aws:iam:012345678910:role/exampleIAMRoleName required: true + - name: sessionDuration + type: duration + description: | + Duration of the session using AWS IAM Roles Anywhere. + If set to 0m, temporary credentials will automatically rotate. + default: '15m' + example: '0m' + required: true azuread: - title: "Azure AD: Managed identity" diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index 540b33a5ff..6d1d54d483 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -29,6 +29,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" @@ -108,6 +109,26 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding { return &AWSS3{logger: logger} } +func (s *AWSS3) getAWSConfig(awsA *awsAuth.AWS) *aws.Config { + cfg := awsA.GetConfig().WithS3ForcePathStyle(s.metadata.ForcePathStyle).WithDisableSSL(s.metadata.DisableSSL) + + // Use a custom HTTP client to allow self-signed certs + if s.metadata.InsecureSSL { + customTransport := http.DefaultTransport.(*http.Transport).Clone() + customTransport.TLSClientConfig = &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, + } + client := &http.Client{ + Transport: customTransport, + } + cfg = cfg.WithHTTPClient(client) + + s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") + } + return cfg +} + // Init does metadata parsing and connection creation. func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := s.parseMetadata(metadata) @@ -116,46 +137,40 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { } if s.s3Client == nil { + awsA, err := awsAuth.New(awsAuth.Options{ Logger: s.logger, Properties: metadata.Properties, Region: m.Region, + Endpoint: m.Endpoint, AccessKey: m.AccessKey, SecretKey: m.SecretKey, SessionToken: m.SessionToken, - Endpoint: m.Endpoint, }) if err != nil { return err } - - session, err := awsA.GetClient(ctx) + // initiate clients, before refreshing if needed + sess, err := awsA.GetClient(ctx) if err != nil { return err } - cfg := aws.NewConfig(). - WithS3ForcePathStyle(m.ForcePathStyle). - WithDisableSSL(m.DisableSSL) - - // Use a custom HTTP client to allow self-signed certs - if m.InsecureSSL { - customTransport := http.DefaultTransport.(*http.Transport).Clone() - customTransport.TLSClientConfig = &tls.Config{ - //nolint:gosec - InsecureSkipVerify: true, - } - client := &http.Client{ - Transport: customTransport, - } - cfg = cfg.WithHTTPClient(client) - - s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") - } - - s.s3Client = s3.New(session, cfg) + s.metadata = m + s.s3Client = s3.New(sess, s.getAWSConfig(awsA)) s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) s.uploader = s3manager.NewUploaderWithClient(s.s3Client) + + go func() { + for { + select { + case refreshSession := <-awsA.GetSessionUpdateChannel(): + s.updateAWSClients(refreshSession, s.getAWSConfig(awsA)) + case <-ctx.Done(): + return + } + } + }() } s.metadata = m @@ -163,6 +178,12 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { return nil } +func (s *AWSS3) updateAWSClients(session *session.Session, cfgs *aws.Config) { + s.s3Client = s3.New(session, cfgs) + s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) + s.uploader = s3manager.NewUploaderWithClient(s.s3Client) +} + func (s *AWSS3) Close() error { return nil } diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index f88fd8e912..f55eba5930 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -37,6 +37,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" @@ -51,6 +52,90 @@ type EnvironmentSettings struct { Metadata map[string]string } +type AWS struct { + mu sync.RWMutex + logger logger.Logger + + x509Auth *x509Auth + + region string + endpoint string + accessKey string + secretKey string + sessionToken string +} + +type AWSIAM struct { + // Ignored by metadata parser because included in built-in authentication profile + // Access key to use for accessing PostgreSQL. + AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` + // Secret key to use for accessing PostgreSQL. + AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` + // AWS region in which PostgreSQL is deployed. + AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` +} + +type Options struct { + Logger logger.Logger + Properties map[string]string + + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` + Region string `json:"region" mapstructure:"region"` + AccessKey string `json:"accessKey" mapstructure:"accessKey"` + SecretKey string `json:"secretKey" mapstructure:"secretKey"` + SessionToken string `json:"sessionToken" mapstructure:"sessionToken"` + Endpoint string `json:"endpoint" mapstructure:"endpoint"` +} + +type x509Auth struct { + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` + SessionDuration *time.Duration `json:"sessionDuration" mapstructure:"sessionDuration"` + sessionExpiration time.Time + + chainPEM []byte + keyPEM []byte + + sessionUpdateChannel chan *session.Session + + rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI +} + +func New(opts Options) (*AWS, error) { + var x509AuthConfig x509Auth + if err := kitmd.DecodeMetadata(opts.Properties, &x509AuthConfig); err != nil { + return nil, err + } + + return &AWS{ + x509Auth: &x509AuthConfig, + logger: opts.Logger, + region: opts.Region, + accessKey: opts.AccessKey, + secretKey: opts.SecretKey, + sessionToken: opts.SessionToken, + endpoint: opts.Endpoint, + }, nil +} + +func (a *AWS) GetConfig() *aws.Config { + cfg := aws.NewConfig() + + if a.region != "" { + cfg.WithRegion(a.region) + } + + return cfg +} + +func (a *AWS) GetSessionUpdateChannel() chan *session.Session { + a.mu.Lock() + defer a.mu.Unlock() + return a.x509Auth.sessionUpdateChannel +} + func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { optFns := []func(*config.LoadOptions) error{} if region != "" { @@ -75,47 +160,116 @@ func GetConfigV2(accessKey string, secretKey string, sessionToken string, region } func (a *AWS) GetClient(ctx context.Context) (*session.Session, error) { - a.lock.Lock() - defer a.lock.Unlock() + a.mu.Lock() + defer a.mu.Unlock() switch { // IAM Roles Anywhere option case a.x509Auth.TrustAnchorArn != nil && a.x509Auth.AssumeRoleArn != nil: a.logger.Debug("using X.509 RolesAnywhere authentication using Dapr SVID") - return a.getX509Client(ctx) + session, err := a.getX509Client(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create X.509 RolesAnywhere client") + } + // start a session refresher background goroutine to keep rotating the temporary creds + // use background context to keep alive + go a.startSessionRefresher(context.Background()) + + return session, nil default: a.logger.Debugf("using AWS session client...") - return a.getSessionClient() + return a.getTokenClient() } } -func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { +func (a *AWS) getTokenClient() (*session.Session, error) { + awsConfig := aws.NewConfig() + + if a.region != "" { + awsConfig = awsConfig.WithRegion(a.region) + } + + if a.accessKey != "" && a.secretKey != "" { + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(a.accessKey, a.secretKey, a.sessionToken)) + } + + if a.endpoint != "" { + awsConfig = awsConfig.WithEndpoint(a.endpoint) + } + + awsSession, err := session.NewSessionWithOptions(session.Options{ + Config: *awsConfig, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, err + } + + userAgentHandler := request.NamedHandler{ + Name: "UserAgentHandler", + Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), + } + awsSession.Handlers.Build.PushBackNamed(userAgentHandler) + + return awsSession, nil +} + +func (a *AWS) getCertPEM(ctx context.Context) error { // retrieve svid from spiffe context svid, ok := spiffecontext.From(ctx) if !ok { - return nil, errors.New("no SVID found in context") + return errors.New("no SVID found in context") } // get x.509 svid svidx, err := svid.GetX509SVID() if err != nil { - return nil, err + return err } // marshal x.509 svid to pem format chainPEM, keyPEM, err := svidx.Marshal() if err != nil { - return nil, fmt.Errorf("failed to marshal SVID: %w", err) + return fmt.Errorf("failed to marshal SVID: %w", err) + } + + a.x509Auth.chainPEM = chainPEM + a.x509Auth.keyPEM = keyPEM + return nil +} + +func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { + // retrieve svid from spiffe context + err := a.getCertPEM(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) + } + + if err := a.initializeTrustAnchors(); err != nil { + return nil, err + } + + if err := a.initializeRolesAnywhereClient(); err != nil { + return nil, err + } + + err = a.createOrRefreshSession(ctx) + if err != nil { + return nil, fmt.Errorf("failed to refresh token for new session client") } + return a.getTokenClient() +} + +func (a *AWS) initializeTrustAnchors() error { var ( trustAnchor arn.ARN profile arn.ARN + err error ) - if a.x509Auth.TrustAnchorArn != nil { trustAnchor, err = arn.Parse(*a.x509Auth.TrustAnchorArn) if err != nil { - return nil, err + return err } a.region = trustAnchor.Region } @@ -123,104 +277,176 @@ func (a *AWS) getX509Client(ctx context.Context) (*session.Session, error) { if a.x509Auth.TrustProfileArn != nil { profile, err = arn.Parse(*a.x509Auth.TrustProfileArn) if err != nil { - return nil, err + return err } + if profile.Region != "" && trustAnchor.Region != profile.Region { - return nil, fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", trustAnchor.Region, profile.Region) } } + return nil +} - mySession, err := session.NewSession() - if err != nil { - return nil, err +func (a *AWS) initializeRolesAnywhereClient() error { + if a.x509Auth.rolesAnywhereClient == nil { + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + mySession, err := session.NewSession() + if err != nil { + return err + } + config := aws.NewConfig().WithRegion(a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) + rolesAnywhereClient := rolesanywhere.New(mySession, config) + + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return err + } + a.x509Auth.rolesAnywhereClient = rolesAnywhereClient } - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, - }} - config := aws.NewConfig().WithRegion(trustAnchor.Region).WithHTTPClient(client).WithLogLevel(aws.LogOff) - rolesAnywhereClient := rolesanywhere.New(mySession, config) - certs, err := cryptopem.DecodePEMCertificatesChain(chainPEM) + return nil + +} + +func (a *AWS) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { + certs, err := cryptopem.DecodePEMCertificatesChain(a.x509Auth.chainPEM) if err != nil { - return nil, err + return err } - ints := make([]x509.Certificate, len(certs)-1) + var ints []x509.Certificate for i := range certs[1:] { - ints[i] = *certs[i+1] + ints = append(ints, *certs[i+1]) } - key, err := cryptopem.DecodePEMPrivateKey(keyPEM) + key, err := cryptopem.DecodePEMPrivateKey(a.x509Auth.keyPEM) if err != nil { - return nil, err + return err } keyECDSA := key.(*ecdsa.PrivateKey) signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) - agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) rolesAnywhereClient.Handlers.Sign.Clear() rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) - // TODO: make metadata field? - var duration int64 = 10000 - createSessionRequest := rolesanywhere.CreateSessionInput{ - Cert: ptr.Of(string(chainPEM)), - ProfileArn: a.x509Auth.TrustProfileArn, - TrustAnchorArn: a.x509Auth.TrustAnchorArn, - RoleArn: a.x509Auth.AssumeRoleArn, - DurationSeconds: &duration, - InstanceProperties: nil, - SessionName: nil, - } - output, err := rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) - if err != nil { - return nil, fmt.Errorf("failed to create session using dapr app dentity: %w", err) + return nil +} + +func (a *AWS) createOrRefreshSession(ctx context.Context) error { + var ( + duration int64 + createSessionRequest rolesanywhere.CreateSessionInput + ) + + if *a.x509Auth.SessionDuration != 0 { + duration = int64(a.x509Auth.SessionDuration.Seconds()) + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.x509Auth.chainPEM)), + ProfileArn: a.x509Auth.TrustProfileArn, + TrustAnchorArn: a.x509Auth.TrustAnchorArn, + RoleArn: a.x509Auth.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } + } else { + duration = 900 // 15 minutes in seconds by default and be autorefreshed + + createSessionRequest = rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.x509Auth.chainPEM)), + ProfileArn: a.x509Auth.TrustProfileArn, + TrustAnchorArn: a.x509Auth.TrustAnchorArn, + RoleArn: a.x509Auth.AssumeRoleArn, + DurationSeconds: aws.Int64(duration), + InstanceProperties: nil, + SessionName: nil, + } } + output, err := a.x509Auth.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return fmt.Errorf("failed to create session using dapr app identity: %w", err) + } if len(output.CredentialSet) != 1 { - return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + return fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) } a.accessKey = *output.CredentialSet[0].Credentials.AccessKeyId a.secretKey = *output.CredentialSet[0].Credentials.SecretAccessKey a.sessionToken = *output.CredentialSet[0].Credentials.SessionToken - return a.getSessionClient() -} - -func (a *AWS) getSessionClient() (*session.Session, error) { - awsConfig := aws.NewConfig() - - if a.region != "" { - awsConfig = awsConfig.WithRegion(a.region) + // convert expiration time from *string to time.Time + expirationStr := output.CredentialSet[0].Credentials.Expiration + if expirationStr == nil { + return fmt.Errorf("expiration time is nil") } - if a.accessKey != "" && a.secretKey != "" { - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(a.accessKey, a.secretKey, a.sessionToken)) + expirationTime, err := time.Parse(time.RFC3339, *expirationStr) + if err != nil { + return fmt.Errorf("failed to parse expiration time: %w", err) } - if a.endpoint != "" { - awsConfig = awsConfig.WithEndpoint(a.endpoint) - } + a.x509Auth.sessionExpiration = expirationTime - awsSession, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, err - } + return nil +} - userAgentHandler := request.NamedHandler{ - Name: "UserAgentHandler", - Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), +func (a *AWS) startSessionRefresher(ctx context.Context) error { + a.logger.Debugf("starting session refresher for x509 auth") + // if there is a set session duration, then exit bc we will not auto refresh the session. + if *a.x509Auth.SessionDuration != 0 { + return nil } - awsSession.Handlers.Build.PushBackNamed(userAgentHandler) - return awsSession, nil + errChan := make(chan error, 1) + go func() { + ticker := time.NewTicker(2 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + remaining := time.Until(a.x509Auth.sessionExpiration) + if remaining <= 8*time.Minute { + a.logger.Infof("Refreshing session as expiration is within %v", remaining) + + // Refresh the session + err := a.createOrRefreshSession(ctx) + if err != nil { + errChan <- fmt.Errorf("failed to refresh session: %w", err) + return + } + + a.logger.Debugf("AWS IAM Roles Anywhere session refreshed successfully") + refreshedSession, err := a.getTokenClient() + if err != nil { + errChan <- fmt.Errorf("failed to get token client with refreshed credentials: %v", err) + return + } + a.x509Auth.sessionUpdateChannel <- refreshedSession + + } + case <-ctx.Done(): + a.logger.Infof("Session refresher stopped due to context cancellation") + errChan <- nil + return + } + } + }() + + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return ctx.Err() + } } // NewEnvironmentSettings returns a new EnvironmentSettings configured for a given AWS resource. @@ -232,70 +458,6 @@ func NewEnvironmentSettings(md map[string]string) (EnvironmentSettings, error) { return es, nil } -type AWS struct { - lock sync.RWMutex - logger logger.Logger - - x509Auth *x509Auth - - region string - endpoint string - accessKey string - secretKey string - sessionToken string -} - -type AWSIAM struct { - // Ignored by metadata parser because included in built-in authentication profile - // Access key to use for accessing PostgreSQL. - AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` - // Secret key to use for accessing PostgreSQL. - AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` - // AWS region in which PostgreSQL is deployed. - AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` -} - -type Options struct { - Logger logger.Logger - Properties map[string]string - - PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` - ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - Region string `json:"region" mapstructure:"region"` - AccessKey string `json:"accessKey" mapstructure:"accessKey"` - SecretKey string `json:"secretKey" mapstructure:"secretKey"` - SessionToken string `json:"sessionToken" mapstructure:"sessionToken"` - Endpoint string `json:"endpoint" mapstructure:"endpoint"` -} - -type x509Auth struct { - TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` - TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` - AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` -} - -func New(opts Options) (*AWS, error) { - var x509AuthConfig x509Auth - if err := kitmd.DecodeMetadata(opts.Properties, &x509AuthConfig); err != nil { - return nil, err - } - if x509AuthConfig.AssumeRoleArn != nil { - opts.Logger.Infof("sam x509 fields %s %s ", *x509AuthConfig.AssumeRoleArn, *x509AuthConfig.TrustAnchorArn) - } else { - opts.Logger.Infof("sam still nil somehow...") - } - - return &AWS{ - x509Auth: &x509AuthConfig, - logger: opts.Logger, - region: opts.Region, - accessKey: opts.AccessKey, - secretKey: opts.SecretKey, - sessionToken: opts.SessionToken, - endpoint: opts.Endpoint, - }, nil -} - func (opts *Options) GetAccessToken(ctx context.Context) (string, error) { dbEndpoint := opts.PoolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(opts.PoolConfig.ConnConfig.Port)) var authenticationToken string