-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add IMDS fallback with DescribeInstances.
- Loading branch information
Showing
15 changed files
with
869 additions
and
303 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: MIT | ||
|
||
package ec2 | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
type chainMetadataProvider struct { | ||
providers []MetadataProvider | ||
} | ||
|
||
func newChainMetadataProvider(providers []MetadataProvider) *chainMetadataProvider { | ||
return &chainMetadataProvider{providers: providers} | ||
} | ||
|
||
func (p *chainMetadataProvider) ID() string { | ||
var providerIDs []string | ||
for _, provider := range p.providers { | ||
providerIDs = append(providerIDs, provider.ID()) | ||
} | ||
return fmt.Sprintf("Chain [%s]", strings.Join(providerIDs, ",")) | ||
} | ||
|
||
func (p *chainMetadataProvider) Get(ctx context.Context) (*Metadata, error) { | ||
var errs error | ||
for _, provider := range p.providers { | ||
if metadata, err := provider.Get(ctx); err != nil { | ||
errs = errors.Join(errs, fmt.Errorf("unable to get metadata from %s: %w", provider.ID(), err)) | ||
} else { | ||
return metadata, nil | ||
} | ||
} | ||
return nil, errs | ||
} | ||
|
||
func (p *chainMetadataProvider) Hostname(ctx context.Context) (string, error) { | ||
var errs error | ||
for _, provider := range p.providers { | ||
if hostname, err := provider.Hostname(ctx); err != nil { | ||
errs = errors.Join(errs, fmt.Errorf("unable to get hostname from %s: %w", provider.ID(), err)) | ||
} else { | ||
return hostname, nil | ||
} | ||
} | ||
return "", errs | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: MIT | ||
|
||
package ec2 | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
type mockMetadataProvider struct { | ||
Index int | ||
Metadata *Metadata | ||
Err error | ||
} | ||
|
||
func (m *mockMetadataProvider) ID() string { | ||
return fmt.Sprintf("mock/%v", m.Index) | ||
} | ||
|
||
func (m *mockMetadataProvider) Get(context.Context) (*Metadata, error) { | ||
if m.Metadata != nil { | ||
return m.Metadata, nil | ||
} | ||
return nil, m.Err | ||
} | ||
|
||
func (m *mockMetadataProvider) Hostname(context.Context) (string, error) { | ||
if m.Metadata != nil && m.Metadata.Hostname != "" { | ||
return m.Metadata.Hostname, nil | ||
} | ||
return "", m.Err | ||
} | ||
|
||
func TestChainProvider(t *testing.T) { | ||
errFirstTest := errors.New("skip first") | ||
errSecondTest := errors.New("skip second") | ||
testCases := map[string]struct { | ||
providers []MetadataProvider | ||
wantID string | ||
wantMetadata *Metadata | ||
wantHostname string | ||
wantErr error | ||
}{ | ||
"WithErrors": { | ||
providers: []MetadataProvider{ | ||
&mockMetadataProvider{ | ||
Index: 1, | ||
Err: errFirstTest, | ||
}, | ||
&mockMetadataProvider{ | ||
Index: 2, | ||
Err: errSecondTest, | ||
}, | ||
}, | ||
wantID: "Chain [mock/1,mock/2]", | ||
wantErr: errSecondTest, | ||
}, | ||
"WithEarlyChainSuccess": { | ||
providers: []MetadataProvider{ | ||
&mockMetadataProvider{ | ||
Index: 1, | ||
Metadata: &Metadata{ | ||
Hostname: "hostname-1", | ||
InstanceID: "instance-id-1", | ||
}, | ||
}, | ||
&mockMetadataProvider{ | ||
Index: 2, | ||
Metadata: &Metadata{ | ||
Hostname: "hostname-2", | ||
InstanceID: "instance-id-2", | ||
}, | ||
}, | ||
}, | ||
wantID: "Chain [mock/1,mock/2]", | ||
wantHostname: "hostname-1", | ||
wantMetadata: &Metadata{ | ||
Hostname: "hostname-1", | ||
InstanceID: "instance-id-1", | ||
}, | ||
}, | ||
"WithFallback": { | ||
providers: []MetadataProvider{ | ||
&mockMetadataProvider{ | ||
Index: 1, | ||
Err: errFirstTest, | ||
}, | ||
&mockMetadataProvider{ | ||
Index: 2, | ||
Metadata: &Metadata{ | ||
Hostname: "hostname-2", | ||
InstanceID: "instance-id-2", | ||
}, | ||
}, | ||
}, | ||
wantID: "Chain [mock/1,mock/2]", | ||
wantHostname: "hostname-2", | ||
wantMetadata: &Metadata{ | ||
Hostname: "hostname-2", | ||
InstanceID: "instance-id-2", | ||
}, | ||
}, | ||
} | ||
for name, testCase := range testCases { | ||
t.Run(name, func(t *testing.T) { | ||
ctx := context.Background() | ||
p := newChainMetadataProvider(testCase.providers) | ||
assert.Equal(t, testCase.wantID, p.ID()) | ||
hostname, err := p.Hostname(ctx) | ||
assert.ErrorIs(t, err, testCase.wantErr) | ||
assert.Equal(t, testCase.wantHostname, hostname) | ||
metadata, err := p.Get(ctx) | ||
assert.ErrorIs(t, err, testCase.wantErr) | ||
assert.Equal(t, testCase.wantMetadata, metadata) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: MIT | ||
|
||
package ec2 | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/aws/aws-sdk-go/aws/client" | ||
) | ||
|
||
// Metadata is a set of information about the EC2 instance. | ||
type Metadata struct { | ||
AccountID string | ||
Hostname string | ||
ImageID string | ||
InstanceID string | ||
InstanceType string | ||
PrivateIP string | ||
Region string | ||
} | ||
|
||
type MetadataProviderConfig struct { | ||
// IMDSv2Retries is the number of retries the IMDSv2 MetadataProvider will make before it errors out. | ||
IMDSv2Retries int | ||
} | ||
|
||
// MetadataProvider provides functions to get EC2 Metadata and the hostname. | ||
type MetadataProvider interface { | ||
Get(ctx context.Context) (*Metadata, error) | ||
Hostname(ctx context.Context) (string, error) | ||
ID() string | ||
} | ||
|
||
func NewMetadataProvider(configProvider client.ConfigProvider, config MetadataProviderConfig) MetadataProvider { | ||
return newChainMetadataProvider( | ||
[]MetadataProvider{ | ||
newIMDSv2MetadataProvider(configProvider, config.IMDSv2Retries), | ||
newIMDSv1MetadataProvider(configProvider), | ||
newDescribeInstancesMetadataProvider(configProvider), | ||
}, | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: MIT | ||
|
||
package ec2 | ||
|
||
import ( | ||
"testing" | ||
|
||
awsmock "github.com/aws/aws-sdk-go/awstesting/mock" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestNewMetadataProvider(t *testing.T) { | ||
mp := NewMetadataProvider( | ||
awsmock.Session, | ||
MetadataProviderConfig{IMDSv2Retries: 0}, | ||
) | ||
cmp, ok := mp.(*chainMetadataProvider) | ||
assert.True(t, ok) | ||
assert.Len(t, cmp.providers, 3) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: MIT | ||
|
||
package ec2 | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/client" | ||
"github.com/aws/aws-sdk-go/aws/ec2metadata" | ||
|
||
configaws "github.com/aws/amazon-cloudwatch-agent/cfg/aws" | ||
"github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent" | ||
"github.com/aws/amazon-cloudwatch-agent/internal/retryer" | ||
) | ||
|
||
type imdsVersion string | ||
|
||
const ( | ||
IMDSv1 imdsVersion = "IMDSv1" | ||
IMDSv2 imdsVersion = "IMDSv2" | ||
|
||
metadataKeyHostname = "hostname" | ||
) | ||
|
||
type ec2MetadataClient interface { | ||
GetMetadataWithContext(ctx aws.Context, p string) (string, error) | ||
GetInstanceIdentityDocumentWithContext(ctx aws.Context) (ec2metadata.EC2InstanceIdentityDocument, error) | ||
} | ||
|
||
type imdsMetadataProvider struct { | ||
version imdsVersion | ||
svc ec2MetadataClient | ||
} | ||
|
||
var _ MetadataProvider = (*imdsMetadataProvider)(nil) | ||
|
||
func newIMDSv2MetadataProvider(configProvider client.ConfigProvider, retries int) *imdsMetadataProvider { | ||
return newIMDSProvider(IMDSv2, configProvider, &aws.Config{ | ||
LogLevel: configaws.SDKLogLevel(), | ||
Logger: configaws.SDKLogger{}, | ||
Retryer: retryer.NewIMDSRetryer(retries), | ||
EC2MetadataEnableFallback: aws.Bool(false), | ||
}) | ||
} | ||
|
||
func newIMDSv1MetadataProvider(configProvider client.ConfigProvider) *imdsMetadataProvider { | ||
return newIMDSProvider(IMDSv1, configProvider, &aws.Config{ | ||
LogLevel: configaws.SDKLogLevel(), | ||
Logger: configaws.SDKLogger{}, | ||
}) | ||
} | ||
|
||
func newIMDSProvider(version imdsVersion, configProvider client.ConfigProvider, config *aws.Config) *imdsMetadataProvider { | ||
return &imdsMetadataProvider{ | ||
svc: ec2metadata.New(configProvider, config), | ||
version: version, | ||
} | ||
} | ||
|
||
func (p *imdsMetadataProvider) ID() string { | ||
return string(p.version) | ||
} | ||
|
||
// Hostname more information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html#instance-metadata-ex-2 | ||
func (p *imdsMetadataProvider) Hostname(ctx context.Context) (string, error) { | ||
hostname, err := p.svc.GetMetadataWithContext(ctx, metadataKeyHostname) | ||
if err != nil { | ||
return "", err | ||
} | ||
if p.version == IMDSv1 { | ||
agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) | ||
} | ||
return hostname, nil | ||
} | ||
|
||
// Get more information on API: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html | ||
func (p *imdsMetadataProvider) Get(ctx context.Context) (*Metadata, error) { | ||
instanceDocument, err := p.svc.GetInstanceIdentityDocumentWithContext(ctx) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if p.version == IMDSv1 { | ||
agent.UsageFlags().Set(agent.FlagIMDSFallbackSuccess) | ||
} | ||
return fromInstanceIdentityDocument(instanceDocument), nil | ||
} | ||
|
||
func fromInstanceIdentityDocument(document ec2metadata.EC2InstanceIdentityDocument) *Metadata { | ||
return &Metadata{ | ||
AccountID: document.AccountID, | ||
ImageID: document.ImageID, | ||
InstanceID: document.InstanceID, | ||
InstanceType: document.InstanceType, | ||
PrivateIP: document.PrivateIP, | ||
Region: document.Region, | ||
} | ||
} |
Oops, something went wrong.