Skip to content

Commit

Permalink
Add IMDS fallback with DescribeInstances.
Browse files Browse the repository at this point in the history
  • Loading branch information
jefchien committed Apr 16, 2024
1 parent be89919 commit 508d3c9
Show file tree
Hide file tree
Showing 15 changed files with 869 additions and 303 deletions.
51 changes: 51 additions & 0 deletions internal/metadata/ec2/chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"context"
"errors"
"fmt"
"strings"
)

type chainMetadataProvider struct {
providers []MetadataProvider
}

func newChainMetadataProvider(providers []MetadataProvider) *chainMetadataProvider {
return &chainMetadataProvider{providers: providers}
}

func (p *chainMetadataProvider) ID() string {
var providerIDs []string
for _, provider := range p.providers {
providerIDs = append(providerIDs, provider.ID())
}
return fmt.Sprintf("Chain [%s]", strings.Join(providerIDs, ","))
}

func (p *chainMetadataProvider) Get(ctx context.Context) (*Metadata, error) {
var errs error
for _, provider := range p.providers {
if metadata, err := provider.Get(ctx); err != nil {
errs = errors.Join(errs, fmt.Errorf("unable to get metadata from %s: %w", provider.ID(), err))
} else {
return metadata, nil
}
}
return nil, errs
}

func (p *chainMetadataProvider) Hostname(ctx context.Context) (string, error) {
var errs error
for _, provider := range p.providers {
if hostname, err := provider.Hostname(ctx); err != nil {
errs = errors.Join(errs, fmt.Errorf("unable to get hostname from %s: %w", provider.ID(), err))
} else {
return hostname, nil
}
}
return "", errs
}
122 changes: 122 additions & 0 deletions internal/metadata/ec2/chain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"context"
"errors"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

type mockMetadataProvider struct {
Index int
Metadata *Metadata
Err error
}

func (m *mockMetadataProvider) ID() string {
return fmt.Sprintf("mock/%v", m.Index)
}

func (m *mockMetadataProvider) Get(context.Context) (*Metadata, error) {
if m.Metadata != nil {
return m.Metadata, nil
}
return nil, m.Err
}

func (m *mockMetadataProvider) Hostname(context.Context) (string, error) {
if m.Metadata != nil && m.Metadata.Hostname != "" {
return m.Metadata.Hostname, nil
}
return "", m.Err
}

func TestChainProvider(t *testing.T) {
errFirstTest := errors.New("skip first")
errSecondTest := errors.New("skip second")
testCases := map[string]struct {
providers []MetadataProvider
wantID string
wantMetadata *Metadata
wantHostname string
wantErr error
}{
"WithErrors": {
providers: []MetadataProvider{
&mockMetadataProvider{
Index: 1,
Err: errFirstTest,
},
&mockMetadataProvider{
Index: 2,
Err: errSecondTest,
},
},
wantID: "Chain [mock/1,mock/2]",
wantErr: errSecondTest,
},
"WithEarlyChainSuccess": {
providers: []MetadataProvider{
&mockMetadataProvider{
Index: 1,
Metadata: &Metadata{
Hostname: "hostname-1",
InstanceID: "instance-id-1",
},
},
&mockMetadataProvider{
Index: 2,
Metadata: &Metadata{
Hostname: "hostname-2",
InstanceID: "instance-id-2",
},
},
},
wantID: "Chain [mock/1,mock/2]",
wantHostname: "hostname-1",
wantMetadata: &Metadata{
Hostname: "hostname-1",
InstanceID: "instance-id-1",
},
},
"WithFallback": {
providers: []MetadataProvider{
&mockMetadataProvider{
Index: 1,
Err: errFirstTest,
},
&mockMetadataProvider{
Index: 2,
Metadata: &Metadata{
Hostname: "hostname-2",
InstanceID: "instance-id-2",
},
},
},
wantID: "Chain [mock/1,mock/2]",
wantHostname: "hostname-2",
wantMetadata: &Metadata{
Hostname: "hostname-2",
InstanceID: "instance-id-2",
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
ctx := context.Background()
p := newChainMetadataProvider(testCase.providers)
assert.Equal(t, testCase.wantID, p.ID())
hostname, err := p.Hostname(ctx)
assert.ErrorIs(t, err, testCase.wantErr)
assert.Equal(t, testCase.wantHostname, hostname)
metadata, err := p.Get(ctx)
assert.ErrorIs(t, err, testCase.wantErr)
assert.Equal(t, testCase.wantMetadata, metadata)
})
}
}
43 changes: 43 additions & 0 deletions internal/metadata/ec2/ec2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"context"

"github.com/aws/aws-sdk-go/aws/client"
)

// Metadata is a set of information about the EC2 instance.
type Metadata struct {
AccountID string
Hostname string
ImageID string
InstanceID string
InstanceType string
PrivateIP string
Region string
}

type MetadataProviderConfig struct {
// IMDSv2Retries is the number of retries the IMDSv2 MetadataProvider will make before it errors out.
IMDSv2Retries int
}

// MetadataProvider provides functions to get EC2 Metadata and the hostname.
type MetadataProvider interface {
Get(ctx context.Context) (*Metadata, error)
Hostname(ctx context.Context) (string, error)
ID() string
}

func NewMetadataProvider(configProvider client.ConfigProvider, config MetadataProviderConfig) MetadataProvider {
return newChainMetadataProvider(
[]MetadataProvider{
newIMDSv2MetadataProvider(configProvider, config.IMDSv2Retries),
newIMDSv1MetadataProvider(configProvider),
newDescribeInstancesMetadataProvider(configProvider),
},
)
}
21 changes: 21 additions & 0 deletions internal/metadata/ec2/ec2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"testing"

awsmock "github.com/aws/aws-sdk-go/awstesting/mock"
"github.com/stretchr/testify/assert"
)

func TestNewMetadataProvider(t *testing.T) {
mp := NewMetadataProvider(
awsmock.Session,
MetadataProviderConfig{IMDSv2Retries: 0},
)
cmp, ok := mp.(*chainMetadataProvider)
assert.True(t, ok)
assert.Len(t, cmp.providers, 3)
}
99 changes: 99 additions & 0 deletions internal/metadata/ec2/imds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package ec2

import (
"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/ec2metadata"

configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws"
"github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent"
"github.com/aws/amazon-cloudwatch-agent/internal/retryer"
)

type imdsVersion string

const (
IMDSv1 imdsVersion = "IMDSv1"
IMDSv2 imdsVersion = "IMDSv2"

metadataKeyHostname = "hostname"
)

type ec2MetadataClient interface {
GetMetadataWithContext(ctx aws.Context, p string) (string, error)
GetInstanceIdentityDocumentWithContext(ctx aws.Context) (ec2metadata.EC2InstanceIdentityDocument, error)
}

type imdsMetadataProvider struct {
version imdsVersion
svc ec2MetadataClient
}

var _ MetadataProvider = (*imdsMetadataProvider)(nil)

func newIMDSv2MetadataProvider(configProvider client.ConfigProvider, retries int) *imdsMetadataProvider {
return newIMDSProvider(IMDSv2, configProvider, &aws.Config{
LogLevel: configaws.SDKLogLevel(),
Logger: configaws.SDKLogger{},
Retryer: retryer.NewIMDSRetryer(retries),
EC2MetadataEnableFallback: aws.Bool(false),
})
}

func newIMDSv1MetadataProvider(configProvider client.ConfigProvider) *imdsMetadataProvider {
return newIMDSProvider(IMDSv1, configProvider, &aws.Config{
LogLevel: configaws.SDKLogLevel(),
Logger: configaws.SDKLogger{},
})
}

func newIMDSProvider(version imdsVersion, configProvider client.ConfigProvider, config *aws.Config) *imdsMetadataProvider {
return &imdsMetadataProvider{
svc: ec2metadata.New(configProvider, config),
version: version,
}
}

func (p *imdsMetadataProvider) ID() string {
return string(p.version)
}

// Hostname more information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html#instance-metadata-ex-2
func (p *imdsMetadataProvider) Hostname(ctx context.Context) (string, error) {
hostname, err := p.svc.GetMetadataWithContext(ctx, metadataKeyHostname)
if err != nil {
return "", err
}
if p.version == IMDSv1 {
agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess)
}
return hostname, nil
}

// Get more information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
func (p *imdsMetadataProvider) Get(ctx context.Context) (*Metadata, error) {
instanceDocument, err := p.svc.GetInstanceIdentityDocumentWithContext(ctx)
if err != nil {
return nil, err
}
if p.version == IMDSv1 {
agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess)
}
return fromInstanceIdentityDocument(instanceDocument), nil
}

func fromInstanceIdentityDocument(document ec2metadata.EC2InstanceIdentityDocument) *Metadata {
return &Metadata{
AccountID: document.AccountID,
ImageID: document.ImageID,
InstanceID: document.InstanceID,
InstanceType: document.InstanceType,
PrivateIP: document.PrivateIP,
Region: document.Region,
}
}
Loading

0 comments on commit 508d3c9

Please sign in to comment.