diff --git a/internal/metadata/ec2/chain.go b/internal/metadata/ec2/chain.go new file mode 100644 index 0000000000..46f2b39e15 --- /dev/null +++ b/internal/metadata/ec2/chain.go @@ -0,0 +1,51 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package ec2 + +import ( + "context" + "errors" + "fmt" + "strings" +) + +type chainMetadataProvider struct { + providers []MetadataProvider +} + +func newChainMetadataProvider(providers []MetadataProvider) *chainMetadataProvider { + return &chainMetadataProvider{providers: providers} +} + +func (p *chainMetadataProvider) ID() string { + var providerIDs []string + for _, provider := range p.providers { + providerIDs = append(providerIDs, provider.ID()) + } + return fmt.Sprintf("Chain [%s]", strings.Join(providerIDs, ",")) +} + +func (p *chainMetadataProvider) Get(ctx context.Context) (*Metadata, error) { + var errs error + for _, provider := range p.providers { + if metadata, err := provider.Get(ctx); err != nil { + errs = errors.Join(errs, fmt.Errorf("unable to get metadata from %s: %w", provider.ID(), err)) + } else { + return metadata, nil + } + } + return nil, errs +} + +func (p *chainMetadataProvider) Hostname(ctx context.Context) (string, error) { + var errs error + for _, provider := range p.providers { + if hostname, err := provider.Hostname(ctx); err != nil { + errs = errors.Join(errs, fmt.Errorf("unable to get hostname from %s: %w", provider.ID(), err)) + } else { + return hostname, nil + } + } + return "", errs +} diff --git a/internal/metadata/ec2/chain_test.go b/internal/metadata/ec2/chain_test.go new file mode 100644 index 0000000000..255aab9e97 --- /dev/null +++ b/internal/metadata/ec2/chain_test.go @@ -0,0 +1,122 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package ec2 + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockMetadataProvider struct { + Index int + Metadata *Metadata + Err error +} + +func (m *mockMetadataProvider) ID() string { + return fmt.Sprintf("mock/%v", m.Index) +} + +func (m *mockMetadataProvider) Get(context.Context) (*Metadata, error) { + if m.Metadata != nil { + return m.Metadata, nil + } + return nil, m.Err +} + +func (m *mockMetadataProvider) Hostname(context.Context) (string, error) { + if m.Metadata != nil && m.Metadata.Hostname != "" { + return m.Metadata.Hostname, nil + } + return "", m.Err +} + +func TestChainProvider(t *testing.T) { + errFirstTest := errors.New("skip first") + errSecondTest := errors.New("skip second") + testCases := map[string]struct { + providers []MetadataProvider + wantID string + wantMetadata *Metadata + wantHostname string + wantErr error + }{ + "WithErrors": { + providers: []MetadataProvider{ + &mockMetadataProvider{ + Index: 1, + Err: errFirstTest, + }, + &mockMetadataProvider{ + Index: 2, + Err: errSecondTest, + }, + }, + wantID: "Chain [mock/1,mock/2]", + wantErr: errSecondTest, + }, + "WithEarlyChainSuccess": { + providers: []MetadataProvider{ + &mockMetadataProvider{ + Index: 1, + Metadata: &Metadata{ + Hostname: "hostname-1", + InstanceID: "instance-id-1", + }, + }, + &mockMetadataProvider{ + Index: 2, + Metadata: &Metadata{ + Hostname: "hostname-2", + InstanceID: "instance-id-2", + }, + }, + }, + wantID: "Chain [mock/1,mock/2]", + wantHostname: "hostname-1", + wantMetadata: &Metadata{ + Hostname: "hostname-1", + InstanceID: "instance-id-1", + }, + }, + "WithFallback": { + providers: []MetadataProvider{ + &mockMetadataProvider{ + Index: 1, + Err: errFirstTest, + }, + &mockMetadataProvider{ + Index: 2, + Metadata: &Metadata{ + Hostname: "hostname-2", + InstanceID: "instance-id-2", + }, + }, + }, + wantID: "Chain [mock/1,mock/2]", + wantHostname: "hostname-2", + wantMetadata: &Metadata{ + Hostname: "hostname-2", + InstanceID: "instance-id-2", + }, + }, + } + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + p := newChainMetadataProvider(testCase.providers) + assert.Equal(t, testCase.wantID, p.ID()) + hostname, err := p.Hostname(ctx) + assert.ErrorIs(t, err, testCase.wantErr) + assert.Equal(t, testCase.wantHostname, hostname) + metadata, err := p.Get(ctx) + assert.ErrorIs(t, err, testCase.wantErr) + assert.Equal(t, testCase.wantMetadata, metadata) + }) + } +} diff --git a/internal/metadata/ec2/ec2.go b/internal/metadata/ec2/ec2.go new file mode 100644 index 0000000000..aff475f46a --- /dev/null +++ b/internal/metadata/ec2/ec2.go @@ -0,0 +1,43 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package ec2 + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws/client" +) + +// Metadata is a set of information about the EC2 instance. +type Metadata struct { + AccountID string + Hostname string + ImageID string + InstanceID string + InstanceType string + PrivateIP string + Region string +} + +type MetadataProviderConfig struct { + // IMDSv2Retries is the number of retries the IMDSv2 MetadataProvider will make before it errors out. + IMDSv2Retries int +} + +// MetadataProvider provides functions to get EC2 Metadata and the hostname. +type MetadataProvider interface { + Get(ctx context.Context) (*Metadata, error) + Hostname(ctx context.Context) (string, error) + ID() string +} + +func NewMetadataProvider(configProvider client.ConfigProvider, config MetadataProviderConfig) MetadataProvider { + return newChainMetadataProvider( + []MetadataProvider{ + newIMDSv2MetadataProvider(configProvider, config.IMDSv2Retries), + newIMDSv1MetadataProvider(configProvider), + newDescribeInstancesMetadataProvider(configProvider), + }, + ) +} diff --git a/internal/metadata/ec2/ec2_test.go b/internal/metadata/ec2/ec2_test.go new file mode 100644 index 0000000000..2b095607da --- /dev/null +++ b/internal/metadata/ec2/ec2_test.go @@ -0,0 +1,21 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package ec2 + +import ( + "testing" + + awsmock "github.com/aws/aws-sdk-go/awstesting/mock" + "github.com/stretchr/testify/assert" +) + +func TestNewMetadataProvider(t *testing.T) { + mp := NewMetadataProvider( + awsmock.Session, + MetadataProviderConfig{IMDSv2Retries: 0}, + ) + cmp, ok := mp.(*chainMetadataProvider) + assert.True(t, ok) + assert.Len(t, cmp.providers, 3) +} diff --git a/internal/metadata/ec2/imds.go b/internal/metadata/ec2/imds.go new file mode 100644 index 0000000000..317c7e23f3 --- /dev/null +++ b/internal/metadata/ec2/imds.go @@ -0,0 +1,99 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package ec2 + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/aws/ec2metadata" + + configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws" + "github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent" + "github.com/aws/amazon-cloudwatch-agent/internal/retryer" +) + +type imdsVersion string + +const ( + IMDSv1 imdsVersion = "IMDSv1" + IMDSv2 imdsVersion = "IMDSv2" + + metadataKeyHostname = "hostname" +) + +type ec2MetadataClient interface { + GetMetadataWithContext(ctx aws.Context, p string) (string, error) + GetInstanceIdentityDocumentWithContext(ctx aws.Context) (ec2metadata.EC2InstanceIdentityDocument, error) +} + +type imdsMetadataProvider struct { + version imdsVersion + svc ec2MetadataClient +} + +var _ MetadataProvider = (*imdsMetadataProvider)(nil) + +func newIMDSv2MetadataProvider(configProvider client.ConfigProvider, retries int) *imdsMetadataProvider { + return newIMDSProvider(IMDSv2, configProvider, &aws.Config{ + LogLevel: configaws.SDKLogLevel(), + Logger: configaws.SDKLogger{}, + Retryer: retryer.NewIMDSRetryer(retries), + EC2MetadataEnableFallback: aws.Bool(false), + }) +} + +func newIMDSv1MetadataProvider(configProvider client.ConfigProvider) *imdsMetadataProvider { + return newIMDSProvider(IMDSv1, configProvider, &aws.Config{ + LogLevel: configaws.SDKLogLevel(), + Logger: configaws.SDKLogger{}, + }) +} + +func newIMDSProvider(version imdsVersion, configProvider client.ConfigProvider, config *aws.Config) *imdsMetadataProvider { + return &imdsMetadataProvider{ + svc: ec2metadata.New(configProvider, config), + version: version, + } +} + +func (p *imdsMetadataProvider) ID() string { + return string(p.version) +} + +// Hostname more information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html#instance-metadata-ex-2 +func (p *imdsMetadataProvider) Hostname(ctx context.Context) (string, error) { + hostname, err := p.svc.GetMetadataWithContext(ctx, metadataKeyHostname) + if err != nil { + return "", err + } + if p.version == IMDSv1 { + agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) + } + return hostname, nil +} + +// Get more information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html +func (p *imdsMetadataProvider) Get(ctx context.Context) (*Metadata, error) { + instanceDocument, err := p.svc.GetInstanceIdentityDocumentWithContext(ctx) + if err != nil { + return nil, err + } + if p.version == IMDSv1 { + agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) + } + return fromInstanceIdentityDocument(instanceDocument), nil +} + +func fromInstanceIdentityDocument(document ec2metadata.EC2InstanceIdentityDocument) *Metadata { + return &Metadata{ + AccountID: document.AccountID, + ImageID: document.ImageID, + InstanceID: document.InstanceID, + InstanceType: document.InstanceType, + PrivateIP: document.PrivateIP, + Region: document.Region, + } +} diff --git a/internal/metadata/ec2/imds_test.go b/internal/metadata/ec2/imds_test.go new file mode 100644 index 0000000000..a3c5029556 --- /dev/null +++ b/internal/metadata/ec2/imds_test.go @@ -0,0 +1,95 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package ec2 + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/awstesting/mock" + "github.com/stretchr/testify/assert" +) + +type mockIMDSClient struct { + metadata map[string]string + document ec2metadata.EC2InstanceIdentityDocument + err error +} + +func (m *mockIMDSClient) GetMetadataWithContext(_ aws.Context, key string) (string, error) { + if m.metadata == nil { + return "", m.err + } + return m.metadata[key], m.err +} + +func (m *mockIMDSClient) GetInstanceIdentityDocumentWithContext(aws.Context) (ec2metadata.EC2InstanceIdentityDocument, error) { + return m.document, m.err +} + +func TestIMDSProvider(t *testing.T) { + testErr := errors.New("test") + testCases := map[string]struct { + provider *imdsMetadataProvider + metadata map[string]string + document ec2metadata.EC2InstanceIdentityDocument + clientErr error + wantID string + wantHostname string + wantMetadata *Metadata + wantErr error + }{ + "WithSuccess": { + provider: newIMDSv1MetadataProvider(mock.Session), + metadata: map[string]string{ + metadataKeyHostname: "test.hostname", + }, + document: ec2metadata.EC2InstanceIdentityDocument{ + AccountID: "account-id", + ImageID: "image-id", + InstanceID: "instance-id", + InstanceType: "instance-type", + PrivateIP: "private-ip", + Region: "region", + }, + wantID: string(IMDSv1), + wantHostname: "test.hostname", + wantMetadata: &Metadata{ + AccountID: "account-id", + ImageID: "image-id", + InstanceID: "instance-id", + InstanceType: "instance-type", + PrivateIP: "private-ip", + Region: "region", + }, + }, + "WithError": { + provider: newIMDSv2MetadataProvider(mock.Session, 0), + clientErr: testErr, + wantID: string(IMDSv2), + wantErr: testErr, + }, + } + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + p := testCase.provider + p.svc = &mockIMDSClient{ + metadata: testCase.metadata, + document: testCase.document, + err: testCase.clientErr, + } + assert.Equal(t, testCase.wantID, p.ID()) + hostname, err := p.Hostname(ctx) + assert.ErrorIs(t, err, testCase.wantErr) + assert.Equal(t, testCase.wantHostname, hostname) + metadata, err := p.Get(ctx) + assert.ErrorIs(t, err, testCase.wantErr) + assert.Equal(t, testCase.wantMetadata, metadata) + }) + } +} diff --git a/internal/metadata/ec2/non_imds.go b/internal/metadata/ec2/non_imds.go new file mode 100644 index 0000000000..38aa602d8e --- /dev/null +++ b/internal/metadata/ec2/non_imds.go @@ -0,0 +1,164 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package ec2 + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + + configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws" + "github.com/aws/amazon-cloudwatch-agent/cfg/envconfig" +) + +const ( + filterKeyInstanceID = "instance-id" + filterKeyPrivateIpAddress = "private-ip-address" + prefixInstanceID = "i-" + prefixPrivateIpAddress = "ip-" + suffixDefault = ".ec2.internal" + suffixRegional = ".compute.internal" +) + +var ( + errUnsupportedHostname = errors.New("unable to parse non-fixed format hostname") + errUnsupportedFilter = errors.New("unable to determine EC2 filter") + errReservationCount = errors.New("invalid number of reservations found") + errInstanceCount = errors.New("invalid number of instances found") +) + +type ec2ClientProvider func(client.ConfigProvider, ...*aws.Config) ec2iface.EC2API + +type describeInstancesMetadataProvider struct { + configProvider client.ConfigProvider + newEC2Client ec2ClientProvider + osHostname func() (string, error) +} + +var _ MetadataProvider = (*describeInstancesMetadataProvider)(nil) + +func newDescribeInstancesMetadataProvider(configProvider client.ConfigProvider) *describeInstancesMetadataProvider { + return &describeInstancesMetadataProvider{ + configProvider: configProvider, + newEC2Client: func(provider client.ConfigProvider, configs ...*aws.Config) ec2iface.EC2API { + return ec2.New(provider, configs...) + }, + osHostname: os.Hostname, + } +} + +func (p *describeInstancesMetadataProvider) ID() string { + return "DescribeInstances" +} + +func (p *describeInstancesMetadataProvider) Get(ctx context.Context) (*Metadata, error) { + filter, region, err := p.getEC2FilterAndRegion(ctx) + if err != nil { + return nil, err + } + input := &ec2.DescribeInstancesInput{Filters: []*ec2.Filter{filter}} + cfg := &aws.Config{ + LogLevel: configaws.SDKLogLevel(), + Logger: configaws.SDKLogger{}, + CredentialsChainVerboseErrors: aws.Bool(true), + } + if region != "" { + cfg = cfg.WithRegion(region) + } + svc := p.newEC2Client(p.configProvider, cfg) + output, err := svc.DescribeInstances(input) + if err != nil { + return nil, err + } + reservationCount := len(output.Reservations) + if reservationCount == 0 || reservationCount > 1 { + return nil, fmt.Errorf("%w: %v", errReservationCount, reservationCount) + } + metadata, err := fromReservation(*output.Reservations[0]) + if err != nil { + return nil, err + } + metadata.Region = region + return metadata, nil +} + +func (p *describeInstancesMetadataProvider) Hostname(context.Context) (string, error) { + hostname := os.Getenv(envconfig.HostName) + if hostname == "" { + return p.osHostname() + } + return hostname, nil +} + +func (p *describeInstancesMetadataProvider) getEC2FilterAndRegion(ctx context.Context) (*ec2.Filter, string, error) { + hostname, err := p.Hostname(ctx) + if err != nil { + return nil, "", err + } + prefix, region, err := splitHostname(hostname) + if region == "" { + return nil, "", err + } + filter, err := filterFromHostnamePrefix(prefix) + if err != nil { + return nil, "", err + } + return filter, region, nil +} + +func fromReservation(reservation ec2.Reservation) (*Metadata, error) { + instanceCount := len(reservation.Instances) + if instanceCount == 0 || instanceCount > 1 { + return nil, fmt.Errorf("%w: %v", errInstanceCount, instanceCount) + } + instance := reservation.Instances[0] + return &Metadata{ + AccountID: aws.StringValue(reservation.OwnerId), + ImageID: aws.StringValue(instance.ImageId), + InstanceID: aws.StringValue(instance.InstanceId), + InstanceType: aws.StringValue(instance.InstanceType), + PrivateIP: aws.StringValue(instance.PrivateIpAddress), + }, nil +} + +func filterFromHostnamePrefix(prefix string) (*ec2.Filter, error) { + // i-0123456789abcdef + if strings.HasPrefix(prefix, prefixInstanceID) { + return &ec2.Filter{ + Name: aws.String(filterKeyInstanceID), + Values: aws.StringSlice([]string{prefix}), + }, nil + } + // ip-10-24-34-0 -> 10.24.34.0 + if ipAddress, ok := strings.CutPrefix(prefix, prefixPrivateIpAddress); ok { + return &ec2.Filter{ + Name: aws.String(filterKeyPrivateIpAddress), + Values: aws.StringSlice([]string{strings.ReplaceAll(ipAddress, "-", ".")}), + }, nil + } + return nil, fmt.Errorf("%w from hostname prefix: %s", errUnsupportedFilter, prefix) +} + +// splitHostname extracts the prefix and region based on https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-naming.html +func splitHostname(hostname string) (prefix string, region string, err error) { + before, ok := strings.CutSuffix(hostname, suffixRegional) + if ok { + parts := strings.Split(before, ".") + if len(parts) == 2 { + return parts[0], parts[1], nil + } + } + before, ok = strings.CutSuffix(hostname, suffixDefault) + if ok { + return before, "us-east-1", nil + } + return hostname, "", fmt.Errorf("%w: %s", errUnsupportedHostname, hostname) +} diff --git a/internal/metadata/ec2/non_imds_test.go b/internal/metadata/ec2/non_imds_test.go new file mode 100644 index 0000000000..b9e54b9309 --- /dev/null +++ b/internal/metadata/ec2/non_imds_test.go @@ -0,0 +1,175 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package ec2 + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" + awsmock "github.com/aws/aws-sdk-go/awstesting/mock" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/stretchr/testify/assert" + + "github.com/aws/amazon-cloudwatch-agent/cfg/envconfig" +) + +type mockEC2Client struct { + ec2iface.EC2API + reservations []*ec2.Reservation + err error +} + +func (m *mockEC2Client) DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { + if m.err != nil { + return nil, m.err + } + if m.reservations == nil { + return nil, errors.New("no reservations") + } + return &ec2.DescribeInstancesOutput{ + Reservations: m.reservations, + }, nil +} + +func TestDescribeInstanceProvider(t *testing.T) { + t.Setenv(envconfig.HostName, "") + testErr := errors.New("test") + testCases := map[string]struct { + hostnameFn func() (string, error) + reservations []*ec2.Reservation + clientErr error + wantHostname string + wantMetadata *Metadata + wantHostnameErr error + wantGetErr error + }{ + "WithHostname/PrivateIP": { + hostnameFn: func() (string, error) { + return "ip-10-24-34-0.ec2.internal", nil + }, + reservations: []*ec2.Reservation{ + { + Instances: []*ec2.Instance{ + { + ImageId: aws.String("image-id"), + InstanceId: aws.String("instance-id"), + InstanceType: aws.String("instance-type"), + PrivateIpAddress: aws.String("10.24.34.0"), + }, + }, + OwnerId: aws.String("owner-id"), + }, + }, + wantMetadata: &Metadata{ + AccountID: "owner-id", + ImageID: "image-id", + InstanceID: "instance-id", + InstanceType: "instance-type", + PrivateIP: "10.24.34.0", + Region: "us-east-1", + }, + wantHostname: "ip-10-24-34-0.ec2.internal", + }, + "WithHostname/ResourceName": { + hostnameFn: func() (string, error) { + return "i-0123456789abcdef.us-west-2.compute.internal", nil + }, + reservations: []*ec2.Reservation{ + { + Instances: []*ec2.Instance{ + { + ImageId: aws.String("image-id"), + InstanceId: aws.String("i-0123456789abcdef"), + InstanceType: aws.String("instance-type"), + PrivateIpAddress: aws.String("private-ip"), + }, + }, + OwnerId: aws.String("owner-id"), + }, + }, + wantMetadata: &Metadata{ + AccountID: "owner-id", + ImageID: "image-id", + InstanceID: "i-0123456789abcdef", + InstanceType: "instance-type", + PrivateIP: "private-ip", + Region: "us-west-2", + }, + wantHostname: "i-0123456789abcdef.us-west-2.compute.internal", + }, + "WithHostname/Unsupported": { + hostnameFn: func() (string, error) { + return "hello.us-east-1.amazon.com", nil + }, + wantHostname: "hello.us-east-1.amazon.com", + wantGetErr: errUnsupportedHostname, + }, + "WithHostname/InvalidPrefix": { + hostnameFn: func() (string, error) { + return "other-prefix.us-west-2.compute.internal", nil + }, + wantHostname: "other-prefix.us-west-2.compute.internal", + wantGetErr: errUnsupportedFilter, + }, + "WithHostname/Error": { + hostnameFn: func() (string, error) { + return "", testErr + }, + wantHostname: "", + wantHostnameErr: testErr, + wantGetErr: testErr, + }, + "WithClient/Error": { + hostnameFn: func() (string, error) { + return "i-0123456789abcdef.us-west-2.compute.internal", nil + }, + clientErr: testErr, + wantHostname: "i-0123456789abcdef.us-west-2.compute.internal", + wantGetErr: testErr, + }, + "WithClient/NoReservations": { + hostnameFn: func() (string, error) { + return "i-0123456789abcdef.us-west-2.compute.internal", nil + }, + reservations: []*ec2.Reservation{}, + wantHostname: "i-0123456789abcdef.us-west-2.compute.internal", + wantGetErr: errReservationCount, + }, + "WithClient/NoInstances": { + hostnameFn: func() (string, error) { + return "i-0123456789abcdef.us-west-2.compute.internal", nil + }, + reservations: []*ec2.Reservation{ + {OwnerId: aws.String("owner-id")}, + }, + wantHostname: "i-0123456789abcdef.us-west-2.compute.internal", + wantGetErr: errInstanceCount, + }, + } + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + p := newDescribeInstancesMetadataProvider(awsmock.Session) + assert.Equal(t, "DescribeInstances", p.ID()) + mockClient := &mockEC2Client{ + reservations: testCase.reservations, + err: testCase.clientErr, + } + p.newEC2Client = func(_ client.ConfigProvider, configs ...*aws.Config) ec2iface.EC2API { + return mockClient + } + p.osHostname = testCase.hostnameFn + hostname, err := p.Hostname(ctx) + assert.ErrorIs(t, err, testCase.wantHostnameErr) + assert.Equal(t, testCase.wantHostname, hostname) + metadata, err := p.Get(ctx) + assert.ErrorIs(t, err, testCase.wantGetErr) + assert.Equal(t, testCase.wantMetadata, metadata) + }) + } +} diff --git a/internal/retryer/imdsretryer.go b/internal/retryer/imdsretryer.go index 29dec2976f..ba9bc290c3 100644 --- a/internal/retryer/imdsretryer.go +++ b/internal/retryer/imdsretryer.go @@ -27,7 +27,7 @@ type IMDSRetryer struct { // otel component layer retries should come from aws config settings // translator layer should come from env vars see GetDefaultRetryNumber() func NewIMDSRetryer(imdsRetries int) IMDSRetryer { - fmt.Printf("I! imds retry client will retry %d times", imdsRetries) + fmt.Printf("I! imds retry client will retry %d times\n", imdsRetries) return IMDSRetryer{ DefaultRetryer: client.DefaultRetryer{ NumMaxRetries: imdsRetries, @@ -43,7 +43,7 @@ func (r IMDSRetryer) ShouldRetry(req *request.Request) bool { if awsError, ok := req.Error.(awserr.Error); r.DefaultRetryer.ShouldRetry(req) || (ok && awsError != nil && awsError.Code() == "EC2MetadataError") { shouldRetry = true } - fmt.Printf("D! should retry %t for imds error : %v", shouldRetry, req.Error) + fmt.Printf("D! should retry %t for imds error : %v\n", shouldRetry, req.Error) return shouldRetry } diff --git a/plugins/processors/ec2tagger/ec2metadataprovider.go b/plugins/processors/ec2tagger/ec2metadataprovider.go deleted file mode 100644 index 6278f69dff..0000000000 --- a/plugins/processors/ec2tagger/ec2metadataprovider.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: MIT - -package ec2tagger - -import ( - "context" - "log" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - - configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws" - "github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent" - "github.com/aws/amazon-cloudwatch-agent/internal/retryer" -) - -type MetadataProvider interface { - Get(ctx context.Context) (ec2metadata.EC2InstanceIdentityDocument, error) - Hostname(ctx context.Context) (string, error) - InstanceID(ctx context.Context) (string, error) -} - -type metadataClient struct { - metadataFallbackDisabled *ec2metadata.EC2Metadata - metadataFallbackEnabled *ec2metadata.EC2Metadata -} - -var _ MetadataProvider = (*metadataClient)(nil) - -func NewMetadataProvider(p client.ConfigProvider, retries int) MetadataProvider { - disableFallbackConfig := &aws.Config{ - LogLevel: configaws.SDKLogLevel(), - Logger: configaws.SDKLogger{}, - Retryer: retryer.NewIMDSRetryer(retries), - EC2MetadataEnableFallback: aws.Bool(false), - } - enableFallbackConfig := &aws.Config{ - LogLevel: configaws.SDKLogLevel(), - Logger: configaws.SDKLogger{}, - } - return &metadataClient{ - metadataFallbackDisabled: ec2metadata.New(p, disableFallbackConfig), - metadataFallbackEnabled: ec2metadata.New(p, enableFallbackConfig), - } -} - -func (c *metadataClient) InstanceID(ctx context.Context) (string, error) { - instanceId, err := c.metadataFallbackDisabled.GetMetadataWithContext(ctx, "instance-id") - if err != nil { - log.Printf("D! could not get instance id without imds v1 fallback enable thus enable fallback") - instanceInner, errorInner := c.metadataFallbackEnabled.GetMetadataWithContext(ctx, "instance-id") - if errorInner == nil { - agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) - } - return instanceInner, errorInner - } - return instanceId, err -} - -func (c *metadataClient) Hostname(ctx context.Context) (string, error) { - hostname, err := c.metadataFallbackDisabled.GetMetadataWithContext(ctx, "hostname") - if err != nil { - log.Printf("D! could not get hostname without imds v1 fallback enable thus enable fallback") - hostnameInner, errorInner := c.metadataFallbackEnabled.GetMetadataWithContext(ctx, "hostname") - if errorInner == nil { - agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) - } - return hostnameInner, errorInner - } - return hostname, err -} - -func (c *metadataClient) Get(ctx context.Context) (ec2metadata.EC2InstanceIdentityDocument, error) { - instanceDocument, err := c.metadataFallbackDisabled.GetInstanceIdentityDocumentWithContext(ctx) - if err != nil { - log.Printf("D! could not get instance document without imds v1 fallback enable thus enable fallback") - instanceDocumentInner, errorInner := c.metadataFallbackEnabled.GetInstanceIdentityDocumentWithContext(ctx) - if errorInner == nil { - agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) - } - return instanceDocumentInner, errorInner - } - return instanceDocument, err -} diff --git a/plugins/processors/ec2tagger/ec2metadataprovider_test.go b/plugins/processors/ec2tagger/ec2metadataprovider_test.go deleted file mode 100644 index 619b7d18e4..0000000000 --- a/plugins/processors/ec2tagger/ec2metadataprovider_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: MIT - -package ec2tagger - -import ( - "context" - "os" - "reflect" - "testing" - - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/awstesting/mock" - "github.com/stretchr/testify/assert" -) - -func TestMetadataProvider_Get(t *testing.T) { - tests := []struct { - name string - ctx context.Context - sess *session.Session - expectDoc ec2metadata.EC2InstanceIdentityDocument - }{ - { - name: "mock session", - ctx: context.Background(), - sess: mock.Session, - expectDoc: ec2metadata.EC2InstanceIdentityDocument{}, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - c := NewMetadataProvider(tc.sess, 0) - gotDoc, err := c.Get(tc.ctx) - assert.NotNil(t, err) - assert.Truef(t, reflect.DeepEqual(gotDoc, tc.expectDoc), "get() gotDoc: %v, expected: %v", gotDoc, tc.expectDoc) - }) - } -} - -func TestMetadataProvider_available(t *testing.T) { - tests := []struct { - name string - ctx context.Context - sess *session.Session - want error - }{ - { - name: "mock session", - ctx: context.Background(), - sess: mock.Session, - want: nil, - }, - } - - // For build environments where IMDS is disabled via environment variable, explicitly re-enable it. Otherwise the - // call to c.InstanceId() fails before even contacting the mock session. - // See https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html#envvars-list-AWS_EC2_METADATA_DISABLED - const awsEc2MetadataDisabledEnvVar = "AWS_EC2_METADATA_DISABLED" - val := os.Getenv(awsEc2MetadataDisabledEnvVar) - defer func() { assert.NoError(t, os.Setenv(awsEc2MetadataDisabledEnvVar, val)) }() - assert.NoError(t, os.Setenv(awsEc2MetadataDisabledEnvVar, "false")) - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - c := NewMetadataProvider(tc.sess, 0) - _, err := c.InstanceID(tc.ctx) - assert.ErrorIs(t, err, tc.want) - }) - } -} diff --git a/plugins/processors/ec2tagger/ec2tagger.go b/plugins/processors/ec2tagger/ec2tagger.go index 5d49ce5cf7..b3f1b3373c 100644 --- a/plugins/processors/ec2tagger/ec2tagger.go +++ b/plugins/processors/ec2tagger/ec2tagger.go @@ -19,7 +19,8 @@ import ( "go.uber.org/zap" configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws" - translatorCtx "github.com/aws/amazon-cloudwatch-agent/translator/context" + ec2metadata "github.com/aws/amazon-cloudwatch-agent/internal/metadata/ec2" + translatorcontext "github.com/aws/amazon-cloudwatch-agent/translator/context" ) type ec2MetadataLookupType struct { @@ -42,7 +43,7 @@ type Tagger struct { logger *zap.Logger cancelFunc context.CancelFunc - metadataProvider MetadataProvider + metadataProvider ec2metadata.MetadataProvider ec2Provider ec2ProviderType shutdownC chan bool @@ -59,15 +60,26 @@ type Tagger struct { // newTagger returns a new EC2 Tagger processor. func newTagger(config *Config, logger *zap.Logger) *Tagger { - _, cancel := context.WithCancel(context.Background()) - mdCredentialConfig := &configaws.CredentialConfig{} + mdCredentialConfig := &configaws.CredentialConfig{ + AccessKey: config.AccessKey, + SecretKey: config.SecretKey, + RoleARN: config.RoleARN, + Profile: config.Profile, + Filename: config.Filename, + Token: config.Token, + } p := &Tagger{ - Config: config, - logger: logger, - cancelFunc: cancel, - metadataProvider: NewMetadataProvider(mdCredentialConfig.Credentials(), config.IMDSRetries), + Config: config, + logger: logger, + cancelFunc: cancel, + metadataProvider: ec2metadata.NewMetadataProvider( + mdCredentialConfig.Credentials(), + ec2metadata.MetadataProviderConfig{ + IMDSv2Retries: config.IMDSRetries, + }, + ), ec2Provider: func(ec2CredentialConfig *configaws.CredentialConfig) ec2iface.EC2API { return ec2.New( ec2CredentialConfig.Credentials(), @@ -413,7 +425,7 @@ func (t *Tagger) setStarted() { /* Retrieve metadata from IMDS and use these metadata to: -* Extract InstanceID, ImageID, InstanceType to create custom dimension for collected metrics +* Extract InstanceID, imageID, InstanceType to create custom dimension for collected metrics * Extract InstanceID to retrieve Instance's Volume and Tags * Extract Region to create aws session with custom configuration For more information on IMDS, please follow this document https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html @@ -436,7 +448,7 @@ func (t *Tagger) deriveEC2MetadataFromIMDS(ctx context.Context) error { doc, err := t.metadataProvider.Get(ctx) if err != nil { t.logger.Error("ec2tagger: Unable to retrieve EC2 Metadata. This plugin must only be used on an EC2 instance.") - if translatorCtx.CurrentContext().RunInContainer() { + if translatorcontext.CurrentContext().RunInContainer() { t.logger.Warn("ec2tagger: Timeout may have occurred because hop limit is too small. Please increase hop limit to 2 by following this document https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-options.html#configuring-IMDS-existing-instances.") } return err diff --git a/plugins/processors/ec2tagger/ec2tagger_test.go b/plugins/processors/ec2tagger/ec2tagger_test.go index c9af7c2f1a..752483aee0 100644 --- a/plugins/processors/ec2tagger/ec2tagger_test.go +++ b/plugins/processors/ec2tagger/ec2tagger_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/stretchr/testify/assert" @@ -20,6 +19,7 @@ import ( "go.opentelemetry.io/collector/processor/processortest" configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws" + ec2metadata "github.com/aws/amazon-cloudwatch-agent/internal/metadata/ec2" ) type mockEC2Client struct { @@ -96,7 +96,7 @@ func (m *mockEC2Client) DescribeTags(*ec2.DescribeTagsInput) (*ec2.DescribeTagsO //when tags are not ready or customer doesn't have permission to call the api if m.tagsCallCount <= m.tagsFailLimit { m.tagsCallCount++ - return nil, errors.New("No tags available now") + return nil, errors.New("no tags available now") } //return partial tags to simulate the case @@ -177,7 +177,7 @@ func (m *mockEC2Client) DescribeVolumes(*ec2.DescribeVolumesInput) (*ec2.Describ //when the volumes are not ready or customer doesn't have permission to call the api if m.volumesCallCount <= m.volumesFailLimit { m.volumesCallCount++ - return nil, errors.New("No volumes available now") + return nil, errors.New("no volumes available now") } //return partial volumes to simulate the case @@ -200,26 +200,28 @@ func (m *mockEC2Client) DescribeVolumes(*ec2.DescribeVolumesInput) (*ec2.Describ return nil, nil } +var _ ec2metadata.MetadataProvider = (*mockMetadataProvider)(nil) + type mockMetadataProvider struct { - InstanceIdentityDocument *ec2metadata.EC2InstanceIdentityDocument + Metadata *ec2metadata.Metadata } -func (m *mockMetadataProvider) Get(ctx context.Context) (ec2metadata.EC2InstanceIdentityDocument, error) { - if m.InstanceIdentityDocument != nil { - return *m.InstanceIdentityDocument, nil - } - return ec2metadata.EC2InstanceIdentityDocument{}, errors.New("No instance identity document") +func (m *mockMetadataProvider) ID() string { + return "mock" } -func (m *mockMetadataProvider) Hostname(ctx context.Context) (string, error) { - return "MockHostName", nil +func (m *mockMetadataProvider) Get(context.Context) (*ec2metadata.Metadata, error) { + if m.Metadata != nil { + return m.Metadata, nil + } + return nil, errors.New("no metadata") } -func (m *mockMetadataProvider) InstanceID(ctx context.Context) (string, error) { - return "MockInstanceID", nil +func (m *mockMetadataProvider) Hostname(context.Context) (string, error) { + return "MockHostName", nil } -var mockedInstanceIdentityDoc = &ec2metadata.EC2InstanceIdentityDocument{ +var mockedMetadata = &ec2metadata.Metadata{ InstanceID: "i-01d2417c27a396e44", Region: "us-east-1", InstanceType: "m5ad.large", @@ -232,7 +234,7 @@ var mockedInstanceIdentityDoc = &ec2metadata.EC2InstanceIdentityDocument{ // pm.ResourceMetrics().At(0).ScopeMetrics().Len() == 1 // pm.ResourceMetrics().At(0).ScopeMetrics().At(0).Metrics().Len() == len(metrics) // -// and for each metric from metrics it create one single datapoint that appy all tags/attributes from metric +// and for each metric from metrics it creates one single datapoint that applies all tags/attributes from metric func createTestMetrics(metrics []map[string]string) pmetric.Metrics { pm := pmetric.NewMetrics() rm := pm.ResourceMetrics().AppendEmpty() @@ -297,12 +299,12 @@ func TestStartFailWithNoMetadata(t *testing.T) { Config: cfg, logger: processortest.NewNopCreateSettings().Logger, cancelFunc: cancel, - metadataProvider: &mockMetadataProvider{InstanceIdentityDocument: nil}, + metadataProvider: &mockMetadataProvider{Metadata: nil}, } err := tagger.Start(context.Background(), componenttest.NewNopHost()) assert.NotNil(t, err) - assert.Contains(t, err.Error(), "No instance identity document") + assert.Contains(t, err.Error(), "no metadata") } // run Start() and check all tags/volumes are retrieved and saved @@ -333,7 +335,7 @@ func TestStartSuccessWithNoTagsVolumesUpdate(t *testing.T) { Config: cfg, logger: processortest.NewNopCreateSettings().Logger, cancelFunc: cancel, - metadataProvider: &mockMetadataProvider{InstanceIdentityDocument: mockedInstanceIdentityDoc}, + metadataProvider: &mockMetadataProvider{Metadata: mockedMetadata}, ec2Provider: ec2Provider, } err := tagger.Start(context.Background(), componenttest.NewNopHost()) @@ -378,7 +380,7 @@ func TestStartSuccessWithTagsVolumesUpdate(t *testing.T) { Config: cfg, logger: processortest.NewNopCreateSettings().Logger, cancelFunc: cancel, - metadataProvider: &mockMetadataProvider{InstanceIdentityDocument: mockedInstanceIdentityDoc}, + metadataProvider: &mockMetadataProvider{Metadata: mockedMetadata}, ec2Provider: ec2Provider, } @@ -433,7 +435,7 @@ func TestStartSuccessWithWildcardTagVolumeKey(t *testing.T) { Config: cfg, logger: processortest.NewNopCreateSettings().Logger, cancelFunc: cancel, - metadataProvider: &mockMetadataProvider{InstanceIdentityDocument: mockedInstanceIdentityDoc}, + metadataProvider: &mockMetadataProvider{Metadata: mockedMetadata}, ec2Provider: ec2Provider, } @@ -480,7 +482,7 @@ func TestApplyWithTagsVolumesUpdate(t *testing.T) { Config: cfg, logger: processortest.NewNopCreateSettings().Logger, cancelFunc: cancel, - metadataProvider: &mockMetadataProvider{InstanceIdentityDocument: mockedInstanceIdentityDoc}, + metadataProvider: &mockMetadataProvider{Metadata: mockedMetadata}, ec2Provider: ec2Provider, } err := tagger.Start(context.Background(), componenttest.NewNopHost()) @@ -490,24 +492,20 @@ func TestApplyWithTagsVolumesUpdate(t *testing.T) { //so that all tags/volumes are retrieved time.Sleep(time.Second) md := createTestMetrics([]map[string]string{ - map[string]string{ - "host": "example.org", - }, - map[string]string{ - "device": device2, - }, + {"host": "example.org"}, + {"device": device2}, }) output, err := tagger.processMetrics(context.Background(), md) assert.Nil(t, err) expectedOutput := createTestMetrics([]map[string]string{ - map[string]string{ + { "AutoScalingGroupName": tagVal3, "InstanceId": "i-01d2417c27a396e44", "InstanceType": "m5ad.large", tagKey1: tagVal1, tagKey2: tagVal2, }, - map[string]string{ + { "AutoScalingGroupName": tagVal3, "EBSVolumeId": volumeAttachmentId2, "InstanceId": "i-01d2417c27a396e44", @@ -528,14 +526,14 @@ func TestApplyWithTagsVolumesUpdate(t *testing.T) { updatedOutput, err := tagger.processMetrics(context.Background(), md) assert.Nil(t, err) expectedUpdatedOutput := createTestMetrics([]map[string]string{ - map[string]string{ + { "AutoScalingGroupName": tagVal3, "InstanceId": "i-01d2417c27a396e44", "InstanceType": "m5ad.large", tagKey1: tagVal1, tagKey2: updatedTagVal2, }, - map[string]string{ + { "AutoScalingGroupName": tagVal3, "EBSVolumeId": volumeAttachmentUpdatedId2, "InstanceId": "i-01d2417c27a396e44", @@ -575,20 +573,14 @@ func TestMetricsDroppedBeforeStarted(t *testing.T) { Config: cfg, logger: processortest.NewNopCreateSettings().Logger, cancelFunc: cancel, - metadataProvider: &mockMetadataProvider{InstanceIdentityDocument: mockedInstanceIdentityDoc}, + metadataProvider: &mockMetadataProvider{Metadata: mockedMetadata}, ec2Provider: ec2Provider, } md := createTestMetrics([]map[string]string{ - map[string]string{ - "host": "example.org", - }, - map[string]string{ - "device": device1, - }, - map[string]string{ - "device": device2, - }, + {"host": "example.org"}, + {"device": device1}, + {"device": device2}, }) err := tagger.Start(context.Background(), componenttest.NewNopHost()) assert.Nil(t, err) @@ -643,7 +635,7 @@ func TestTaggerStartDoesNotBlock(t *testing.T) { Config: cfg, logger: processortest.NewNopCreateSettings().Logger, cancelFunc: cancel, - metadataProvider: &mockMetadataProvider{InstanceIdentityDocument: mockedInstanceIdentityDoc}, + metadataProvider: &mockMetadataProvider{Metadata: mockedMetadata}, ec2Provider: ec2Provider, } @@ -673,7 +665,7 @@ func TestTaggerStartsWithoutTagOrVolume(t *testing.T) { Config: cfg, logger: processortest.NewNopCreateSettings().Logger, cancelFunc: cancel, - metadataProvider: &mockMetadataProvider{InstanceIdentityDocument: mockedInstanceIdentityDoc}, + metadataProvider: &mockMetadataProvider{Metadata: mockedMetadata}, } deadline := time.NewTimer(1 * time.Second) diff --git a/tool/util/util.go b/tool/util/util.go index 0441fd8efc..699df06086 100644 --- a/tool/util/util.go +++ b/tool/util/util.go @@ -4,10 +4,10 @@ package util import ( + "context" "encoding/json" "fmt" "io" - "log" "os" "path" "path/filepath" @@ -16,12 +16,10 @@ import ( "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" - configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws" + ec2metadata "github.com/aws/amazon-cloudwatch-agent/internal/metadata/ec2" "github.com/aws/amazon-cloudwatch-agent/internal/retryer" "github.com/aws/amazon-cloudwatch-agent/tool/data/interfaze" "github.com/aws/amazon-cloudwatch-agent/tool/runtime" @@ -209,32 +207,22 @@ func SDKCredentials() (accessKey, secretKey string, creds *credentials.Credentia } func DefaultEC2Region() (region string) { - fmt.Println("Trying to fetch the default region based on ec2 metadata...") + fmt.Println("Trying to fetch the default region from EC2...") // imds should by the time user can run the wizard - sesFallBackDisabled, err := session.NewSession(&aws.Config{ - LogLevel: configaws.SDKLogLevel(), - Logger: configaws.SDKLogger{}, - EC2MetadataEnableFallback: aws.Bool(false), - Retryer: retryer.NewIMDSRetryer(retryer.GetDefaultRetryNumber()), - }) - sesFallBackEnabled, err := session.NewSession(&aws.Config{ - LogLevel: configaws.SDKLogLevel(), - Logger: configaws.SDKLogger{}, - }) + ses, err := session.NewSession() if err != nil { return } - md := ec2metadata.New(sesFallBackDisabled) - if info, errOuter := md.Region(); errOuter == nil { - region = info + metadataProvider := ec2metadata.NewMetadataProvider( + ses, + ec2metadata.MetadataProviderConfig{ + IMDSv2Retries: retryer.GetDefaultRetryNumber(), + }, + ) + if metadata, err := metadataProvider.Get(context.Background()); err != nil { + fmt.Printf("W! could not get region from EC2... %v", err) } else { - log.Printf("D! could not get region from imds v2 thus enable fallback") - mdInner := ec2metadata.New(sesFallBackEnabled) - if infoInner, errInner := mdInner.Region(); errInner == nil { - region = infoInner - } else { - fmt.Printf("W! could not get region from ec2 metadata... %v", errInner) - } + region = metadata.Region } return } diff --git a/translator/util/ec2util/ec2util.go b/translator/util/ec2util/ec2util.go index 480a08c8b8..aa991b2397 100644 --- a/translator/util/ec2util/ec2util.go +++ b/translator/util/ec2util/ec2util.go @@ -4,49 +4,38 @@ package ec2util import ( + "context" "fmt" "net" "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" - configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws" - "github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent" + ec2metadata "github.com/aws/amazon-cloudwatch-agent/internal/metadata/ec2" "github.com/aws/amazon-cloudwatch-agent/internal/retryer" "github.com/aws/amazon-cloudwatch-agent/translator/config" - "github.com/aws/amazon-cloudwatch-agent/translator/context" + translatorcontext "github.com/aws/amazon-cloudwatch-agent/translator/context" ) -// this is a singleton struct -type ec2Util struct { - Region string - PrivateIP string - InstanceID string - Hostname string - AccountID string -} - var ( - ec2UtilInstance *ec2Util + ec2UtilInstance *ec2metadata.Metadata once sync.Once ) const allowedRetries = 5 -func GetEC2UtilSingleton() *ec2Util { +func GetEC2UtilSingleton() *ec2metadata.Metadata { once.Do(func() { ec2UtilInstance = initEC2UtilSingleton() }) return ec2UtilInstance } -func initEC2UtilSingleton() (newInstance *ec2Util) { - newInstance = &ec2Util{Region: "", PrivateIP: ""} +func initEC2UtilSingleton() (newInstance *ec2metadata.Metadata) { + newInstance = &ec2metadata.Metadata{Region: "", PrivateIP: ""} - if (context.CurrentContext().Mode() == config.ModeOnPrem) || (context.CurrentContext().Mode() == config.ModeOnPremise) { + if (translatorcontext.CurrentContext().Mode() == config.ModeOnPrem) || (translatorcontext.CurrentContext().Mode() == config.ModeOnPremise) { return } @@ -80,67 +69,40 @@ func initEC2UtilSingleton() (newInstance *ec2Util) { fmt.Println("E! [EC2] No available network interface") } - err := newInstance.deriveEC2MetadataFromIMDS() - - if err != nil { - fmt.Println("E! [EC2] Cannot get EC2 Metadata from IMDS:", err) + if err := populateEC2Metadata(newInstance); err != nil { + fmt.Println("E! [EC2] Cannot get EC2 Metadata", err) } return } -func (e *ec2Util) deriveEC2MetadataFromIMDS() error { +func populateEC2Metadata(metadata *ec2metadata.Metadata) error { ses, err := session.NewSession() - if err != nil { return err } - mdDisableFallback := ec2metadata.New(ses, &aws.Config{ - LogLevel: configaws.SDKLogLevel(), - Logger: configaws.SDKLogger{}, - Retryer: retryer.NewIMDSRetryer(retryer.GetDefaultRetryNumber()), - EC2MetadataEnableFallback: aws.Bool(false), - }) - mdEnableFallback := ec2metadata.New(ses, &aws.Config{ - LogLevel: configaws.SDKLogLevel(), - Logger: configaws.SDKLogger{}, - }) + ctx := context.Background() + metadataProvider := ec2metadata.NewMetadataProvider( + ses, + ec2metadata.MetadataProviderConfig{ + IMDSv2Retries: retryer.GetDefaultRetryNumber(), + }, + ) - // ec2 and ecs treats retries for getting host name differently - // More information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html#instance-metadata-ex-2 - if hostname, err := mdDisableFallback.GetMetadata("hostname"); err == nil { - e.Hostname = hostname + if hostname, err := metadataProvider.Hostname(ctx); err != nil { + fmt.Println("E! [EC2] Fetch hostname from EC2 metadata fail:", err) } else { - fmt.Println("D! could not get hostname without imds v1 fallback enable thus enable fallback") - hostnameInner, errInner := mdEnableFallback.GetMetadata("hostname") - if errInner == nil { - e.Hostname = hostnameInner - agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) - } else { - fmt.Println("E! [EC2] Fetch hostname from EC2 metadata fail:", errInner) - } + metadata.Hostname = hostname } - // More information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html - if instanceIdentityDocument, err := mdDisableFallback.GetInstanceIdentityDocument(); err == nil { - e.Region = instanceIdentityDocument.Region - e.AccountID = instanceIdentityDocument.AccountID - e.PrivateIP = instanceIdentityDocument.PrivateIP - e.InstanceID = instanceIdentityDocument.InstanceID + if md, err := metadataProvider.Get(ctx); err != nil { + fmt.Println("E! [EC2] Fetch identity document from EC2 metadata fail:", err) } else { - fmt.Println("D! could not get instance document without imds v1 fallback enable thus enable fallback") - instanceIdentityDocumentInner, errInner := mdEnableFallback.GetInstanceIdentityDocument() - if errInner == nil { - e.Region = instanceIdentityDocumentInner.Region - e.AccountID = instanceIdentityDocumentInner.AccountID - e.PrivateIP = instanceIdentityDocumentInner.PrivateIP - e.InstanceID = instanceIdentityDocumentInner.InstanceID - agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) - } else { - fmt.Println("E! [EC2] Fetch identity document from EC2 metadata fail:", errInner) - } + metadata.AccountID = md.AccountID + metadata.InstanceID = md.InstanceID + metadata.PrivateIP = md.PrivateIP + metadata.Region = md.Region } - return nil }