Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IMDS fallback with DescribeInstances. #1139

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions internal/metadata/ec2/chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"context"
"errors"
"fmt"
"strings"
)

type chainMetadataProvider struct {
providers []MetadataProvider
}

func newChainMetadataProvider(providers []MetadataProvider) *chainMetadataProvider {
return &chainMetadataProvider{providers: providers}
}

func (p *chainMetadataProvider) ID() string {
var providerIDs []string
for _, provider := range p.providers {
providerIDs = append(providerIDs, provider.ID())
}
return fmt.Sprintf("Chain [%s]", strings.Join(providerIDs, ","))
}

func (p *chainMetadataProvider) Get(ctx context.Context) (*Metadata, error) {
var errs error
for _, provider := range p.providers {
if metadata, err := provider.Get(ctx); err != nil {
errs = errors.Join(errs, fmt.Errorf("unable to get metadata from %s: %w", provider.ID(), err))
} else {
return metadata, nil
}
}
return nil, errs
}

func (p *chainMetadataProvider) Hostname(ctx context.Context) (string, error) {
var errs error
for _, provider := range p.providers {
if hostname, err := provider.Hostname(ctx); err != nil {
errs = errors.Join(errs, fmt.Errorf("unable to get hostname from %s: %w", provider.ID(), err))
} else {
return hostname, nil
}
}
return "", errs
}
122 changes: 122 additions & 0 deletions internal/metadata/ec2/chain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"context"
"errors"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

type mockMetadataProvider struct {
Index int
Metadata *Metadata
Err error
}

func (m *mockMetadataProvider) ID() string {
return fmt.Sprintf("mock/%v", m.Index)
}

func (m *mockMetadataProvider) Get(context.Context) (*Metadata, error) {
if m.Metadata != nil {
return m.Metadata, nil
}
return nil, m.Err
}

func (m *mockMetadataProvider) Hostname(context.Context) (string, error) {
if m.Metadata != nil && m.Metadata.Hostname != "" {
return m.Metadata.Hostname, nil
}
return "", m.Err
}

func TestChainProvider(t *testing.T) {
errFirstTest := errors.New("skip first")
errSecondTest := errors.New("skip second")
testCases := map[string]struct {
providers []MetadataProvider
wantID string
wantMetadata *Metadata
wantHostname string
wantErr error
}{
"WithErrors": {
providers: []MetadataProvider{
&mockMetadataProvider{
Index: 1,
Err: errFirstTest,
},
&mockMetadataProvider{
Index: 2,
Err: errSecondTest,
},
},
wantID: "Chain [mock/1,mock/2]",
wantErr: errSecondTest,
},
"WithEarlyChainSuccess": {
providers: []MetadataProvider{
&mockMetadataProvider{
Index: 1,
Metadata: &Metadata{
Hostname: "hostname-1",
InstanceID: "instance-id-1",
},
},
&mockMetadataProvider{
Index: 2,
Metadata: &Metadata{
Hostname: "hostname-2",
InstanceID: "instance-id-2",
},
},
},
wantID: "Chain [mock/1,mock/2]",
wantHostname: "hostname-1",
wantMetadata: &Metadata{
Hostname: "hostname-1",
InstanceID: "instance-id-1",
},
},
"WithFallback": {
providers: []MetadataProvider{
&mockMetadataProvider{
Index: 1,
Err: errFirstTest,
},
&mockMetadataProvider{
Index: 2,
Metadata: &Metadata{
Hostname: "hostname-2",
InstanceID: "instance-id-2",
},
},
},
wantID: "Chain [mock/1,mock/2]",
wantHostname: "hostname-2",
wantMetadata: &Metadata{
Hostname: "hostname-2",
InstanceID: "instance-id-2",
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
ctx := context.Background()
p := newChainMetadataProvider(testCase.providers)
assert.Equal(t, testCase.wantID, p.ID())
hostname, err := p.Hostname(ctx)
assert.ErrorIs(t, err, testCase.wantErr)
assert.Equal(t, testCase.wantHostname, hostname)
metadata, err := p.Get(ctx)
assert.ErrorIs(t, err, testCase.wantErr)
assert.Equal(t, testCase.wantMetadata, metadata)
})
}
}
43 changes: 43 additions & 0 deletions internal/metadata/ec2/ec2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"context"

"github.com/aws/aws-sdk-go/aws/client"
)

// Metadata is a set of information about the EC2 instance.
type Metadata struct {
AccountID string
Hostname string
ImageID string
InstanceID string
InstanceType string
PrivateIP string
Region string
}

type MetadataProviderConfig struct {
// IMDSv2Retries is the number of retries the IMDSv2 MetadataProvider will make before it errors out.
IMDSv2Retries int
}

// MetadataProvider provides functions to get EC2 Metadata and the hostname.
type MetadataProvider interface {
Get(ctx context.Context) (*Metadata, error)
Hostname(ctx context.Context) (string, error)
ID() string
}

func NewMetadataProvider(configProvider client.ConfigProvider, config MetadataProviderConfig) MetadataProvider {
return newChainMetadataProvider(
[]MetadataProvider{
newIMDSv2MetadataProvider(configProvider, config.IMDSv2Retries),
newIMDSv1MetadataProvider(configProvider),
newDescribeInstancesMetadataProvider(configProvider),
},
)
}
21 changes: 21 additions & 0 deletions internal/metadata/ec2/ec2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"testing"

awsmock "github.com/aws/aws-sdk-go/awstesting/mock"
"github.com/stretchr/testify/assert"
)

func TestNewMetadataProvider(t *testing.T) {
mp := NewMetadataProvider(
awsmock.Session,
MetadataProviderConfig{IMDSv2Retries: 0},
)
cmp, ok := mp.(*chainMetadataProvider)
assert.True(t, ok)
assert.Len(t, cmp.providers, 3)
}
99 changes: 99 additions & 0 deletions internal/metadata/ec2/imds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/ec2metadata"

configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws"
"github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent"
"github.com/aws/amazon-cloudwatch-agent/internal/retryer"
)

type imdsVersion string

const (
IMDSv1 imdsVersion = "IMDSv1"
IMDSv2 imdsVersion = "IMDSv2"

metadataKeyHostname = "hostname"
)

type ec2MetadataClient interface {
GetMetadataWithContext(ctx aws.Context, p string) (string, error)
GetInstanceIdentityDocumentWithContext(ctx aws.Context) (ec2metadata.EC2InstanceIdentityDocument, error)
}

type imdsMetadataProvider struct {
version imdsVersion
svc ec2MetadataClient
}

var _ MetadataProvider = (*imdsMetadataProvider)(nil)

func newIMDSv2MetadataProvider(configProvider client.ConfigProvider, retries int) *imdsMetadataProvider {
return newIMDSProvider(IMDSv2, configProvider, &aws.Config{
LogLevel: configaws.SDKLogLevel(),
Logger: configaws.SDKLogger{},
Retryer: retryer.NewIMDSRetryer(retries),
EC2MetadataEnableFallback: aws.Bool(false),
})
}

func newIMDSv1MetadataProvider(configProvider client.ConfigProvider) *imdsMetadataProvider {
return newIMDSProvider(IMDSv1, configProvider, &aws.Config{
LogLevel: configaws.SDKLogLevel(),
Logger: configaws.SDKLogger{},
})
}

func newIMDSProvider(version imdsVersion, configProvider client.ConfigProvider, config *aws.Config) *imdsMetadataProvider {
return &imdsMetadataProvider{
svc: ec2metadata.New(configProvider, config),
version: version,
}
}

func (p *imdsMetadataProvider) ID() string {
return string(p.version)
}

// Hostname more information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html#instance-metadata-ex-2
func (p *imdsMetadataProvider) Hostname(ctx context.Context) (string, error) {
hostname, err := p.svc.GetMetadataWithContext(ctx, metadataKeyHostname)
if err != nil {
return "", err
}
if p.version == IMDSv1 {
agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess)
}
return hostname, nil
}

// Get more information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
func (p *imdsMetadataProvider) Get(ctx context.Context) (*Metadata, error) {
instanceDocument, err := p.svc.GetInstanceIdentityDocumentWithContext(ctx)
if err != nil {
return nil, err
}
if p.version == IMDSv1 {
agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess)
}
return fromInstanceIdentityDocument(instanceDocument), nil
}

func fromInstanceIdentityDocument(document ec2metadata.EC2InstanceIdentityDocument) *Metadata {
return &Metadata{
AccountID: document.AccountID,
ImageID: document.ImageID,
InstanceID: document.InstanceID,
InstanceType: document.InstanceType,
PrivateIP: document.PrivateIP,
Region: document.Region,
}
}
Loading