From a9193b4d2e7d21943bfed1bf88de761708d9c02f Mon Sep 17 00:00:00 2001 From: Michael Shen Date: Tue, 8 Oct 2024 09:15:19 -0400 Subject: [PATCH 1/3] Setup AWS KMS client with aws-sdk-go-v2 Signed-off-by: Michael Shen --- go.mod | 14 +++++++++++ go.sum | 28 +++++++++++++++++++++ pkg/cloud/cloud.go | 49 ++++++++++++++++++------------------ pkg/cloud/mock.go | 10 ++++---- pkg/healthz/healthz_test.go | 10 ++++---- pkg/plugin/plugin.go | 20 +++++++-------- pkg/plugin/plugin_v2.go | 22 ++++++++-------- pkg/plugin/plugin_v2_test.go | 2 +- 8 files changed, 99 insertions(+), 56 deletions(-) diff --git a/go.mod b/go.mod index f6d31756..d9fd7ae2 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,9 @@ go 1.22.2 require ( github.com/aws/aws-sdk-go v1.54.6 + github.com/aws/aws-sdk-go-v2 v1.32.1 + github.com/aws/aws-sdk-go-v2/config v1.27.42 + github.com/aws/aws-sdk-go-v2/service/kms v1.37.1 github.com/prometheus/client_golang v1.14.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.0 @@ -14,6 +17,17 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/credentials v1.17.40 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.32.1 // indirect + github.com/aws/smithy-go v1.22.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index df5f0644..9fee3704 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,34 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.32.1 h1:8WuZ43ytA+TV6QEPT/R23mr7pWyI7bSSiEHdt9BS2Pw= +github.com/aws/aws-sdk-go-v2 v1.32.1/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= +github.com/aws/aws-sdk-go-v2/config v1.27.42 h1:Zsy9coUPuOsCWkjTvHpl2/DB9bptXtv7WeNPxvFr87s= +github.com/aws/aws-sdk-go-v2/config v1.27.42/go.mod h1:FGASs+PuJM2EY+8rt8qyQKLPbbX/S5oY+6WzJ/KE7ko= +github.com/aws/aws-sdk-go-v2/credentials v1.17.40 h1:RjnlA7t0p/IamxAM7FUJ5uS13Vszh4sjVGvsx91tGro= +github.com/aws/aws-sdk-go-v2/credentials v1.17.40/go.mod h1:dgpdnSs1Bp/atS6vLlW83h9xZPP+uSPB/27dFSgC1BM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.16 h1:fwrer1pJeaiia0CcOfWVbZxvj9Adc7rsuaMTwPR0DIA= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.16/go.mod h1:XyEwwp8XI4zMar7MTnJ0Sk7qY/9aN8Hp929XhuX5SF8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.20 h1:OErdlGnt+hg3tTwGYAlKvFkKVUo/TXkoHcxDxuhYYU8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.20/go.mod h1:HsPfuL5gs+407ByRXBMgpYoyrV1sgMrzd18yMXQHJpo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.20 h1:822cE1CYSwY/EZnErlF46pyynuxvf1p+VydHRQW+XNs= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.20/go.mod h1:79/Tn7H7hYC5Gjz6fbnOV4OeBpkao7E8Tv95RO72pMM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 h1:TToQNkvGguu209puTojY/ozlqy2d/SFNcoLIqTFi42g= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.1 h1:5vBMBTakOvtd8aNaicswcrr9qqCYUlasuzyoU6/0g8I= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.1/go.mod h1:WSUbDa5qdg05Q558KXx2Scb+EDvOPXT9gfET0fyrJSk= +github.com/aws/aws-sdk-go-v2/service/kms v1.37.1 h1:XbpPk8TZ8FZ+Q1B2bDiI2/w9nzVoYu22+RHvi2nchVo= +github.com/aws/aws-sdk-go-v2/service/kms v1.37.1/go.mod h1:eA3st65Rlr+sUU5bVJdeMpbRQ+xffVBh7Sx+G4M+NJs= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.1 h1:aAIr0WhAgvKrxZtkBqne87Gjmd7/lJVTFkR2l2yuhL8= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.1/go.mod h1:8XhxGMWUfikJuginPQl5SGZ0LSJuNX3TCEQmFWZwHTM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.1 h1:J6kIsIkgFOaU6aKjigXJoue1XEHtKIIrpSh4vKdmRTs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.1/go.mod h1:2V2JLP7tXOmUbL3Hd1ojq+774t2KUAEQ35//shoNEL0= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.1 h1:q76Ig4OaJzVJGNUSGO3wjSTBS94g+EhHIbpY9rPvkxs= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.1/go.mod h1:664dajZ7uS7JMUMUG0R5bWbtN97KECNCVdFDdQ6Ipu8= +github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= +github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index b8cd489a..7a8f4d44 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -14,40 +14,41 @@ limitations under the License. package cloud import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" - "sigs.k8s.io/aws-encryption-provider/pkg/httputil" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/ratelimit" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/kms" ) -type AWSKMS struct { - kmsiface.KMSAPI +type AWSKMSv2 interface { + Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error) + Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) } -func New(region, kmsEndpoint string, qps, burst int) (*AWSKMS, error) { - sess, err := session.NewSession() +func New(region, kmsEndpoint string, qps, burst int) (AWSKMSv2, error) { + cfg, err := config.LoadDefaultConfig(context.Background()) if err != nil { return nil, fmt.Errorf("failed to create new session: %w", err) } - if region == "" { - region, err = ec2metadata.New(sess).Region() - if err != nil { - return nil, fmt.Errorf("failed to call the metadata server's region API, %v", err) - } - } - cfg := &aws.Config{ - Region: aws.String(region), - CredentialsChainVerboseErrors: aws.Bool(true), - Endpoint: aws.String(kmsEndpoint), - } + if qps > 0 { - if sess.Config.HTTPClient, err = httputil.NewRateLimitedClient(qps, burst); err != nil { - return nil, err + cfg, err = config.LoadDefaultConfig(context.Background(), config.WithRetryer(func() aws.Retryer { + return retry.NewStandard(func(o *retry.StandardOptions) { + o.RateLimiter = ratelimit.NewTokenRateLimit(uint(qps) * uint(burst)) + }) + })) + if err != nil { + return nil, fmt.Errorf("failed to create new session: %w", err) } } - return &AWSKMS{kms.New(sess, cfg)}, nil + + client := kms.NewFromConfig(cfg, func(o *kms.Options) { + o.Region = region + o.BaseEndpoint = aws.String(kmsEndpoint) + }) + return client, nil } diff --git a/pkg/cloud/mock.go b/pkg/cloud/mock.go index 4c694640..3c110598 100644 --- a/pkg/cloud/mock.go +++ b/pkg/cloud/mock.go @@ -14,14 +14,14 @@ limitations under the License. package cloud import ( + "context" "sync" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" + "github.com/aws/aws-sdk-go-v2/service/kms" ) type KMSMock struct { - kmsiface.KMSAPI + AWSKMSv2 mutex sync.Mutex @@ -47,13 +47,13 @@ func (m *KMSMock) SetDecryptResp(dec string, decErr error) *KMSMock { return m } -func (m *KMSMock) Encrypt(input *kms.EncryptInput) (*kms.EncryptOutput, error) { +func (m *KMSMock) Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error) { m.mutex.Lock() defer m.mutex.Unlock() return m.encOut, m.encErr } -func (m *KMSMock) Decrypt(input *kms.DecryptInput) (*kms.DecryptOutput, error) { +func (m *KMSMock) Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) { m.mutex.Lock() defer m.mutex.Unlock() return m.decOut, m.decErr diff --git a/pkg/healthz/healthz_test.go b/pkg/healthz/healthz_test.go index be806203..f9d43a7a 100644 --- a/pkg/healthz/healthz_test.go +++ b/pkg/healthz/healthz_test.go @@ -12,8 +12,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/aws" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" "go.uber.org/zap" "sigs.k8s.io/aws-encryption-provider/pkg/cloud" "sigs.k8s.io/aws-encryption-provider/pkg/plugin" @@ -42,18 +42,18 @@ func TestHealthz(t *testing.T) { { path: "/test-healthz-fail-with-internal-error", - kmsEncryptErr: awserr.New(kms.ErrCodeInternalException, "test", errors.New("fail")), + kmsEncryptErr: &kmstypes.KMSInternalException{Message: aws.String("test")}, shouldSucceed: false, }, // user-induced errors should still fail "/healthz" { path: "/test-healthz-fail-with-user-induced-invalid-key-state", - kmsEncryptErr: awserr.New(kms.ErrCodeInvalidStateException, "test", errors.New("fail")), + kmsEncryptErr: &kmstypes.KMSInvalidStateException{Message: aws.String("test")}, shouldSucceed: false, }, { path: "/test-healthz-fail-with-user-induced-invalid-grant", - kmsEncryptErr: awserr.New(kms.ErrCodeInvalidGrantTokenException, "test", errors.New("fail")), + kmsEncryptErr: &kmstypes.InvalidGrantTokenException{Message: aws.String("test")}, shouldSucceed: false, }, } diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go index 937e4b5c..b1a6d90b 100644 --- a/pkg/plugin/plugin.go +++ b/pkg/plugin/plugin.go @@ -18,12 +18,12 @@ import ( "fmt" "time" + "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" "go.uber.org/zap" "google.golang.org/grpc" pb "k8s.io/kms/apis/v1beta1" + "sigs.k8s.io/aws-encryption-provider/pkg/cloud" "sigs.k8s.io/aws-encryption-provider/pkg/kmsplugin" "sigs.k8s.io/aws-encryption-provider/pkg/version" ) @@ -36,14 +36,14 @@ const ( // Plugin implements the KeyManagementServiceServer type V1Plugin struct { - svc kmsiface.KMSAPI + svc cloud.AWSKMSv2 keyID string - encryptionCtx map[string]*string + encryptionCtx map[string]string healthCheck *SharedHealthCheck } // New returns a new *V1Plugin -func New(key string, svc kmsiface.KMSAPI, encryptionCtx map[string]string, healthCheck *SharedHealthCheck) *V1Plugin { +func New(key string, svc cloud.AWSKMSv2, encryptionCtx map[string]string, healthCheck *SharedHealthCheck) *V1Plugin { return newPlugin( key, svc, @@ -54,7 +54,7 @@ func New(key string, svc kmsiface.KMSAPI, encryptionCtx map[string]string, healt func newPlugin( key string, - svc kmsiface.KMSAPI, + svc cloud.AWSKMSv2, encryptionCtx map[string]string, sharedHealthCheck *SharedHealthCheck, ) *V1Plugin { @@ -64,10 +64,10 @@ func newPlugin( healthCheck: sharedHealthCheck, } if len(encryptionCtx) > 0 { - p.encryptionCtx = make(map[string]*string) + p.encryptionCtx = make(map[string]string) } for k, v := range encryptionCtx { - p.encryptionCtx[k] = aws.String(v) + p.encryptionCtx[k] = v } return p } @@ -138,7 +138,7 @@ func (p *V1Plugin) Encrypt(ctx context.Context, request *pb.EncryptRequest) (*pb input.EncryptionContext = p.encryptionCtx } - result, err := p.svc.Encrypt(input) + result, err := p.svc.Encrypt(ctx, input) if err != nil { select { case p.healthCheck.healthCheckErrc <- err: @@ -173,7 +173,7 @@ func (p *V1Plugin) Decrypt(ctx context.Context, request *pb.DecryptRequest) (*pb input.EncryptionContext = p.encryptionCtx } - result, err := p.svc.Decrypt(input) + result, err := p.svc.Decrypt(ctx, input) if err != nil { select { case p.healthCheck.healthCheckErrc <- err: diff --git a/pkg/plugin/plugin_v2.go b/pkg/plugin/plugin_v2.go index 65e4b22a..2ade4407 100644 --- a/pkg/plugin/plugin_v2.go +++ b/pkg/plugin/plugin_v2.go @@ -18,12 +18,12 @@ import ( "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" "go.uber.org/zap" "google.golang.org/grpc" pb "k8s.io/kms/apis/v2" + "sigs.k8s.io/aws-encryption-provider/pkg/cloud" "sigs.k8s.io/aws-encryption-provider/pkg/kmsplugin" ) @@ -35,14 +35,14 @@ const ( // Plugin implements the KeyManagementServiceServer type V2Plugin struct { - svc kmsiface.KMSAPI + svc cloud.AWSKMSv2 keyID string - encryptionCtx map[string]*string + encryptionCtx map[string]string healthCheck *SharedHealthCheck } // New returns a new *V2Plugin -func NewV2(key string, svc kmsiface.KMSAPI, encryptionCtx map[string]string, healthCheck *SharedHealthCheck) *V2Plugin { +func NewV2(key string, svc cloud.AWSKMSv2, encryptionCtx map[string]string, healthCheck *SharedHealthCheck) *V2Plugin { return newPluginV2( key, svc, @@ -53,7 +53,7 @@ func NewV2(key string, svc kmsiface.KMSAPI, encryptionCtx map[string]string, hea func newPluginV2( key string, - svc kmsiface.KMSAPI, + svc cloud.AWSKMSv2, encryptionCtx map[string]string, healthCheck *SharedHealthCheck, ) *V2Plugin { @@ -63,10 +63,10 @@ func newPluginV2( healthCheck: healthCheck, } if len(encryptionCtx) > 0 { - p.encryptionCtx = make(map[string]*string) + p.encryptionCtx = make(map[string]string) } for k, v := range encryptionCtx { - p.encryptionCtx[k] = aws.String(v) + p.encryptionCtx[k] = v } return p } @@ -138,7 +138,7 @@ func (p *V2Plugin) Encrypt(ctx context.Context, request *pb.EncryptRequest) (*pb input.EncryptionContext = p.encryptionCtx } - result, err := p.svc.Encrypt(input) + result, err := p.svc.Encrypt(ctx, input) if err != nil { select { case p.healthCheck.healthCheckErrc <- err: @@ -181,7 +181,7 @@ func (p *V2Plugin) Decrypt(ctx context.Context, request *pb.DecryptRequest) (*pb input.EncryptionContext = p.encryptionCtx } - result, err := p.svc.Decrypt(input) + result, err := p.svc.Decrypt(ctx, input) if err != nil { select { case p.healthCheck.healthCheckErrc <- err: diff --git a/pkg/plugin/plugin_v2_test.go b/pkg/plugin/plugin_v2_test.go index 2ec6426d..72c2bf37 100644 --- a/pkg/plugin/plugin_v2_test.go +++ b/pkg/plugin/plugin_v2_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/kms" "go.uber.org/zap" pb "k8s.io/kms/apis/v2" "sigs.k8s.io/aws-encryption-provider/pkg/cloud" From ad708cc8141e2a16e2cb08d495dd907354263a58 Mon Sep 17 00:00:00 2001 From: Michael Shen Date: Tue, 8 Oct 2024 10:05:53 -0400 Subject: [PATCH 2/3] Convert KMS error parsing logic to aws-sdk-go-v2 Signed-off-by: Michael Shen --- go.mod | 4 +-- go.sum | 6 ---- pkg/kmsplugin/kms.go | 40 ++++++++++++++------------- pkg/livez/livez_test.go | 10 +++---- pkg/plugin/metrics_test.go | 5 ++-- pkg/plugin/plugin.go | 2 +- pkg/plugin/plugin_test.go | 53 ++++++++++++++++++++++-------------- pkg/plugin/plugin_v2_test.go | 53 ++++++++++++++++++++++-------------- 8 files changed, 97 insertions(+), 76 deletions(-) diff --git a/go.mod b/go.mod index d9fd7ae2..5926f788 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module sigs.k8s.io/aws-encryption-provider go 1.22.2 require ( - github.com/aws/aws-sdk-go v1.54.6 github.com/aws/aws-sdk-go-v2 v1.32.1 github.com/aws/aws-sdk-go-v2/config v1.27.42 github.com/aws/aws-sdk-go-v2/service/kms v1.37.1 + github.com/aws/smithy-go v1.22.0 github.com/prometheus/client_golang v1.14.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.0 @@ -27,13 +27,11 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.24.1 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.32.1 // indirect - github.com/aws/smithy-go v1.22.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect diff --git a/go.sum b/go.sum index 9fee3704..8cbef95c 100644 --- a/go.sum +++ b/go.sum @@ -38,8 +38,6 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= -github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/aws/aws-sdk-go-v2 v1.32.1 h1:8WuZ43ytA+TV6QEPT/R23mr7pWyI7bSSiEHdt9BS2Pw= github.com/aws/aws-sdk-go-v2 v1.32.1/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= github.com/aws/aws-sdk-go-v2/config v1.27.42 h1:Zsy9coUPuOsCWkjTvHpl2/DB9bptXtv7WeNPxvFr87s= @@ -165,10 +163,6 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= -github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= diff --git a/pkg/kmsplugin/kms.go b/pkg/kmsplugin/kms.go index ad13950b..064210f5 100644 --- a/pkg/kmsplugin/kms.go +++ b/pkg/kmsplugin/kms.go @@ -5,10 +5,10 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/request" - awsreq "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/aws/smithy-go" "go.uber.org/zap" ) @@ -51,30 +51,31 @@ func ParseError(err error) (errorType KMSErrorType) { uerr = err } - ev, ok := uerr.(awserr.Error) - if !ok { + var ae smithy.APIError + if !errors.As(uerr, &ae) { return KMSErrorTypeOther } - zap.L().Debug("parsed error", zap.String("code", ev.Code()), zap.String("message", ev.Message())) - if request.IsErrorThrottle(uerr) { + zap.L().Debug("parsed error", zap.String("code", ae.ErrorCode()), zap.String("message", ae.ErrorMessage())) + var defaultCodes retry.IsErrorThrottles = retry.DefaultThrottles + if defaultCodes.IsErrorThrottle(uerr) == aws.TrueTernary { return KMSErrorTypeThrottled } - switch ev.Code() { + switch ae.ErrorCode() { // CMK is disabled or pending deletion - case kms.ErrCodeDisabledException, - kms.ErrCodeInvalidStateException: + case (&kmstypes.DisabledException{}).ErrorCode(), + (&kmstypes.KMSInvalidStateException{}).ErrorCode(): return KMSErrorTypeUserInduced // CMK does not exist, or grant is not valid - case kms.ErrCodeKeyUnavailableException, - kms.ErrCodeInvalidArnException, - kms.ErrCodeInvalidGrantIdException, - kms.ErrCodeInvalidGrantTokenException: + case (&kmstypes.KeyUnavailableException{}).ErrorCode(), + (&kmstypes.InvalidArnException{}).ErrorCode(), + (&kmstypes.InvalidGrantIdException{}).ErrorCode(), + (&kmstypes.InvalidGrantTokenException{}).ErrorCode(): return KMSErrorTypeUserInduced // ref. https://docs.aws.amazon.com/kms/latest/developerguide/requests-per-second.html - case kms.ErrCodeLimitExceededException: + case (&kmstypes.LimitExceededException{}).ErrorCode(): return KMSErrorTypeThrottled // AWS SDK Go for KMS does not "yet" define specific error code for a case where a customer specifies the deleted key @@ -84,8 +85,8 @@ func ParseError(err error) (errorType KMSErrorType) { // e.g., "AccessDeniedException: The ciphertext refers to a customer master key that does not exist, does not exist in this region, or you are not allowed to access." // KMS service may change the error message, so we do the string match. case "AccessDeniedException": - if strings.Contains(ev.Message(), "customer master key that does not exist") || - strings.Contains(ev.Message(), "does not exist in this region") { + if strings.Contains(ae.ErrorMessage(), "customer master key that does not exist") || + strings.Contains(ae.ErrorMessage(), "does not exist in this region") { return KMSErrorTypeUserInduced } } @@ -121,10 +122,11 @@ func GetMillisecondsSince(startTime time.Time) float64 { } func GetStatusLabel(err error) string { + var defaultCodes retry.IsErrorThrottles = retry.DefaultThrottles switch { case err == nil: return StatusSuccess - case awsreq.IsErrorThrottle(err): + case defaultCodes.IsErrorThrottle(err) == aws.TrueTernary: return StatusFailureThrottle default: return StatusFailure diff --git a/pkg/livez/livez_test.go b/pkg/livez/livez_test.go index 3c23bcbe..c1187332 100644 --- a/pkg/livez/livez_test.go +++ b/pkg/livez/livez_test.go @@ -12,8 +12,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/aws" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" "go.uber.org/zap" "sigs.k8s.io/aws-encryption-provider/pkg/cloud" "sigs.k8s.io/aws-encryption-provider/pkg/plugin" @@ -41,19 +41,19 @@ func TestLivez(t *testing.T) { }, { path: "/test-livez-fail-with-internal-error", - kmsEncryptErr: awserr.New(kms.ErrCodeInternalException, "test", errors.New("fail")), + kmsEncryptErr: &kmstypes.KMSInternalException{Message: aws.String("test")}, shouldSucceed: false, }, // user-induced { path: "/test-livez-fail-with-user-induced-invalid-key-state", - kmsEncryptErr: awserr.New(kms.ErrCodeInvalidStateException, "test", errors.New("fail")), + kmsEncryptErr: &kmstypes.KMSInvalidStateException{Message: aws.String("test")}, shouldSucceed: true, }, { path: "/test-livez-fail-with-user-induced-invalid-grant", - kmsEncryptErr: awserr.New(kms.ErrCodeInvalidGrantTokenException, "test", errors.New("fail")), + kmsEncryptErr: &kmstypes.InvalidGrantTokenException{Message: aws.String("test")}, shouldSucceed: true, }, } diff --git a/pkg/plugin/metrics_test.go b/pkg/plugin/metrics_test.go index f72a558a..9d6007e0 100644 --- a/pkg/plugin/metrics_test.go +++ b/pkg/plugin/metrics_test.go @@ -14,7 +14,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go-v2/aws" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/zap" pb "k8s.io/kms/apis/v1beta1" @@ -38,7 +39,7 @@ func TestMetrics(t *testing.T) { }, { key: "test-key-throttle", - encryptErr: awserr.New("RequestLimitExceeded", "test", errors.New("fail")), + encryptErr: &kmstypes.LimitExceededException{Message: aws.String("test")}, expects: `aws_encryption_provider_kms_operations_total{key_arn="test-key-throttle",operation="encrypt",status="failure-throttle",version="v1"} 1`, }, } diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go index b1a6d90b..bca89941 100644 --- a/pkg/plugin/plugin.go +++ b/pkg/plugin/plugin.go @@ -18,8 +18,8 @@ import ( "fmt" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kms" - "github.com/aws/aws-sdk-go/aws" "go.uber.org/zap" "google.golang.org/grpc" pb "k8s.io/kms/apis/v1beta1" diff --git a/pkg/plugin/plugin_test.go b/pkg/plugin/plugin_test.go index b81081cc..f62a02c6 100644 --- a/pkg/plugin/plugin_test.go +++ b/pkg/plugin/plugin_test.go @@ -22,8 +22,9 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/aws" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/aws/smithy-go" "go.uber.org/zap" pb "k8s.io/kms/apis/v1beta1" "sigs.k8s.io/aws-encryption-provider/pkg/cloud" @@ -67,10 +68,14 @@ func TestEncrypt(t *testing.T) { checkErr: true, }, { - input: plainMessage, - ctx: nil, - output: "", - err: awserr.New("RequestLimitExceeded", "test", errors.New("fail")), + input: plainMessage, + ctx: nil, + output: "", + err: &smithy.GenericAPIError{ + Code: "RequestLimitExceeded", + Message: "test", + Fault: 0, + }, errType: kmsplugin.KMSErrorTypeThrottled, healthErr: true, checkErr: true, @@ -79,7 +84,7 @@ func TestEncrypt(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeInternalException, "test", errors.New("fail")), + err: &kmstypes.KMSInternalException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeOther, healthErr: true, checkErr: true, @@ -88,25 +93,33 @@ func TestEncrypt(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeLimitExceededException, "test", errors.New("fail")), + err: &kmstypes.LimitExceededException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeThrottled, healthErr: true, checkErr: true, }, { - input: plainMessage, - ctx: nil, - output: "", - err: awserr.New("AccessDeniedException", "The ciphertext refers to a customer master key that does not exist, does not exist in this region, or you are not allowed to access", errors.New("fail")), + input: plainMessage, + ctx: nil, + output: "", + err: &smithy.GenericAPIError{ + Code: "AccessDeniedException", + Message: "The ciphertext refers to a customer master key that does not exist, does not exist in this region, or you are not allowed to access", + Fault: 0, + }, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, }, { - input: plainMessage, - ctx: nil, - output: "", - err: awserr.New("AccessDeniedException", "Some other error message", errors.New("fail")), + input: plainMessage, + ctx: nil, + output: "", + err: &smithy.GenericAPIError{ + Code: "AccessDeniedException", + Message: "Some other error message", + Fault: 0, + }, errType: kmsplugin.KMSErrorTypeOther, healthErr: true, checkErr: true, @@ -115,7 +128,7 @@ func TestEncrypt(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeDisabledException, "test", errors.New("fail")), + err: &kmstypes.DisabledException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, @@ -124,7 +137,7 @@ func TestEncrypt(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeInvalidStateException, "test", errors.New("fail")), + err: &kmstypes.KMSInvalidStateException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, @@ -133,7 +146,7 @@ func TestEncrypt(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeInvalidGrantIdException, "test", errors.New("fail")), + err: &kmstypes.InvalidGrantIdException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, @@ -142,7 +155,7 @@ func TestEncrypt(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeInvalidGrantTokenException, "test", errors.New("fail")), + err: &kmstypes.InvalidGrantTokenException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, diff --git a/pkg/plugin/plugin_v2_test.go b/pkg/plugin/plugin_v2_test.go index 72c2bf37..f857eafa 100644 --- a/pkg/plugin/plugin_v2_test.go +++ b/pkg/plugin/plugin_v2_test.go @@ -21,8 +21,9 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go-v2/service/kms" - "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go-v2/aws" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/aws/smithy-go" "go.uber.org/zap" pb "k8s.io/kms/apis/v2" "sigs.k8s.io/aws-encryption-provider/pkg/cloud" @@ -58,10 +59,14 @@ func TestEncryptV2(t *testing.T) { checkErr: true, }, { - input: plainMessage, - ctx: nil, - output: "", - err: awserr.New("RequestLimitExceeded", "test", errors.New("fail")), + input: plainMessage, + ctx: nil, + output: "", + err: &smithy.GenericAPIError{ + Code: "RequestLimitExceeded", + Message: "test", + Fault: 0, + }, errType: kmsplugin.KMSErrorTypeThrottled, healthErr: true, checkErr: true, @@ -70,7 +75,7 @@ func TestEncryptV2(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeInternalException, "test", errors.New("fail")), + err: &kmstypes.KMSInternalException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeOther, healthErr: true, checkErr: true, @@ -79,25 +84,33 @@ func TestEncryptV2(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeLimitExceededException, "test", errors.New("fail")), + err: &kmstypes.LimitExceededException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeThrottled, healthErr: true, checkErr: true, }, { - input: plainMessage, - ctx: nil, - output: "", - err: awserr.New("AccessDeniedException", "The ciphertext refers to a customer master key that does not exist, does not exist in this region, or you are not allowed to access", errors.New("fail")), + input: plainMessage, + ctx: nil, + output: "", + err: &smithy.GenericAPIError{ + Code: "AccessDeniedException", + Message: "The ciphertext refers to a customer master key that does not exist, does not exist in this region, or you are not allowed to access", + Fault: 0, + }, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, }, { - input: plainMessage, - ctx: nil, - output: "", - err: awserr.New("AccessDeniedException", "Some other error message", errors.New("fail")), + input: plainMessage, + ctx: nil, + output: "", + err: &smithy.GenericAPIError{ + Code: "AccessDeniedException", + Message: "Some other error message", + Fault: 0, + }, errType: kmsplugin.KMSErrorTypeOther, healthErr: true, checkErr: true, @@ -106,7 +119,7 @@ func TestEncryptV2(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeDisabledException, "test", errors.New("fail")), + err: &kmstypes.DisabledException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, @@ -115,7 +128,7 @@ func TestEncryptV2(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeInvalidStateException, "test", errors.New("fail")), + err: &kmstypes.KMSInvalidStateException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, @@ -124,7 +137,7 @@ func TestEncryptV2(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeInvalidGrantIdException, "test", errors.New("fail")), + err: &kmstypes.InvalidGrantIdException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, @@ -133,7 +146,7 @@ func TestEncryptV2(t *testing.T) { input: plainMessage, ctx: nil, output: "", - err: awserr.New(kms.ErrCodeInvalidGrantTokenException, "test", errors.New("fail")), + err: &kmstypes.InvalidGrantTokenException{Message: aws.String("test")}, errType: kmsplugin.KMSErrorTypeUserInduced, healthErr: true, checkErr: false, From 0d0cbecf4d9d194a1e543b019a38bef77d4ee3bb Mon Sep 17 00:00:00 2001 From: Michael Shen Date: Tue, 8 Oct 2024 10:16:13 -0400 Subject: [PATCH 3/3] Remove httputil package It was providing a client-side rate limiter that is now built-in functionality in aws-sdk-go-v2 Signed-off-by: Michael Shen --- go.mod | 1 - go.sum | 2 - pkg/httputil/client.go | 37 ---------- pkg/httputil/client_test.go | 134 ------------------------------------ 4 files changed, 174 deletions(-) delete mode 100644 pkg/httputil/client.go delete mode 100644 pkg/httputil/client_test.go diff --git a/go.mod b/go.mod index 5926f788..862565e7 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.0 go.uber.org/zap v1.19.0 - golang.org/x/time v0.3.0 google.golang.org/grpc v1.65.0 k8s.io/kms v0.31.0 ) diff --git a/go.sum b/go.sum index 8cbef95c..b11d7cf7 100644 --- a/go.sum +++ b/go.sum @@ -402,8 +402,6 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/pkg/httputil/client.go b/pkg/httputil/client.go deleted file mode 100644 index fc09e43d..00000000 --- a/pkg/httputil/client.go +++ /dev/null @@ -1,37 +0,0 @@ -// Package httputil implements HTTP utilities. -package httputil - -import ( - "fmt" - "net/http" - - "golang.org/x/time/rate" -) - -// NewRateLimitedClient returns a new HTTP client with rate limiter. -func NewRateLimitedClient(qps int, burst int) (*http.Client, error) { - if qps == 0 { - return http.DefaultClient, nil - } - if burst < 1 { - return nil, fmt.Errorf("burst expected >0, got %d", burst) - } - return &http.Client{ - Transport: &rateLimitedRoundTripper{ - rt: http.DefaultTransport, - rl: rate.NewLimiter(rate.Limit(qps), burst), - }, - }, nil -} - -type rateLimitedRoundTripper struct { - rt http.RoundTripper - rl *rate.Limiter -} - -func (rr *rateLimitedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if err := rr.rl.Wait(req.Context()); err != nil { - return nil, err - } - return rr.rt.RoundTrip(req) -} diff --git a/pkg/httputil/client_test.go b/pkg/httputil/client_test.go deleted file mode 100644 index 2e31229d..00000000 --- a/pkg/httputil/client_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package httputil - -import ( - "context" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" -) - -func TestNewRateLimitedClient(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/test", testHandler) - - ts := httptest.NewServer(mux) - defer ts.Close() - - u := ts.URL + "/test" - - // requests are to be throttled if qps+burst < reqs - // estimated time: reqs / (qps+burst) seconds - tbs := []struct { - ctxTimeout time.Duration - qps int - burst int - requests int // concurrent requests - err string - }{ - { - qps: 1, - burst: 1, - requests: 10, - }, - { - qps: 15, - burst: 5, - requests: 100, - }, - { - qps: 8, - burst: 2, - requests: 20, - }, - { - // 20 concurrent ec2 API requests should exceed 1 QPS before 10ms - // thus rate limiter returns an error - ctxTimeout: 10 * time.Millisecond, - qps: 1, - burst: 1, - requests: 20, - err: `context deadline`, - // "Wait(n=1) would exceed context deadline" for requests before timeout - // "context deadline exceeded" for requests after timeout - }, - } - for idx, tt := range tbs { - cli, err := NewRateLimitedClient(tt.qps, tt.burst) - if err != nil { - t.Fatalf("#%d: failed to create a new client (%v)", idx, err) - } - - start := time.Now() - - errc := make(chan error, tt.requests) - for i := 0; i < tt.requests; i++ { - go func() { - var ctx context.Context - if tt.ctxTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.TODO(), tt.ctxTimeout) - defer cancel() - } else { - ctx = context.TODO() - } - req, err := http.NewRequest(http.MethodGet, u, nil) - if err != nil { - errc <- err - return - } - _, err = cli.Do(req.WithContext(ctx)) - errc <- err - }() - } - - failed := false - for i := 0; i < tt.requests; i++ { - err = <-errc - switch { - case tt.err == "": // expects no error - if err != nil { - t.Errorf("#%d-%d: unexpected error %v", idx, i, err) - } - case tt.err != "": // expects error - if err == nil { - // this means that the request did not get throttled. - continue - } - if !strings.Contains(err.Error(), tt.err) && - // TODO: why does this happen even when ctx is not canceled - // ref. https://github.com/golang/go/issues/36848 - !strings.Contains(err.Error(), "i/o timeout") { - t.Errorf("#%d-%d: expected %q, got %v", idx, i, tt.err, err) - } - failed = true - } - } - - if tt.err != "" && !failed { - t.Fatalf("#%d: expected failure %q, got no error", idx, tt.err) - } - - if tt.err == "" { - observedDuration := time.Since(start).Round(time.Second) - expectedDuration := time.Duration(0) - if tt.qps+tt.burst < tt.requests { - expectedDuration = (time.Duration(tt.requests/(tt.qps)) * time.Second) - } - if expectedDuration > 0 && observedDuration > expectedDuration { - t.Fatalf("with rate limit, requests expected duration %v, got %v", expectedDuration, observedDuration) - } - } - } -} - -func testHandler(w http.ResponseWriter, req *http.Request) { - switch req.Method { - case "GET": - fmt.Fprint(w, `test`) - default: - http.Error(w, "Method Not Allowed", 405) - } -}