Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use aws-sdk-v2 for ssm backend #886

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 46 additions & 32 deletions backends/ssm/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,53 @@ package ssm

import (
"os"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ssm"
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/kelseyhightower/confd/log"
)

type Client struct {
client *ssm.SSM
client *ssm.Client
}

func New() (*Client, error) {
// Create a session to share configuration, and load external configuration.
sess := session.Must(session.NewSession())
var region string

// get region from metadata service unless provided by env
if os.Getenv("AWS_REGION") != "" {
region = os.Getenv("AWS_REGION")
} else {
cfg, err := config.LoadDefaultConfig(context.TODO())

if err != nil {
return nil, err
}

imds_client := imds.NewFromConfig(cfg)
response, err := imds_client.GetRegion(context.TODO(), &imds.GetRegionInput{})
if err != nil {
return nil, err
}
region = response.Region
}

// Fail early, if no credentials can be found
_, err := sess.Config.Credentials.Get()
// Create the service's client with the config.
ssm_cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region))
if err != nil {
return nil, err
}

var c *aws.Config
if os.Getenv("SSM_LOCAL") != "" {
log.Debug("SSM_LOCAL is set")
endpoint := "http://localhost:8001"
c = &aws.Config{
Endpoint: &endpoint,
// if SSM_LOCAL is set override the endpoint configuration
ssm_client := ssm.NewFromConfig(ssm_cfg, func (o *ssm.Options) {
if os.Getenv("SSM_LOCAL") != "" {
log.Debug("SSM_LOCAL is set")
o.BaseEndpoint = aws.String("http://localhost:8001/")
}
} else {
c = nil
}

// Create the service's client with the session.
svc := ssm.New(sess, c)
return &Client{svc}, nil
})
return &Client{ssm_client}, nil
}

// GetValues retrieves the values for the given keys from AWS SSM Parameter Store
Expand All @@ -53,7 +64,7 @@ func (c *Client) GetValues(keys []string) (map[string]string, error) {
}
if len(resp) == 0 {
resp, err = c.getParameter(key)
if err != nil && err.(awserr.Error).Code() != ssm.ErrCodeParameterNotFound {
if err != nil {
return vars, err
}
}
Expand All @@ -72,13 +83,16 @@ func (c *Client) getParametersWithPrefix(prefix string) (map[string]string, erro
Recursive: aws.Bool(true),
WithDecryption: aws.Bool(true),
}
c.client.GetParametersByPathPages(params,
func(page *ssm.GetParametersByPathOutput, lastPage bool) bool {
for _, p := range page.Parameters {
parameters[*p.Name] = *p.Value
}
return !lastPage
})
paginator := ssm.NewGetParametersByPathPaginator(c.client, params)
for paginator.HasMorePages() {
page, err := paginator.NextPage(context.TODO())
if err != nil {
return parameters, err
}
for _, p := range page.Parameters {
parameters[*p.Name] = *p.Value
}
}
return parameters, err
}

Expand All @@ -88,7 +102,7 @@ func (c *Client) getParameter(name string) (map[string]string, error) {
Name: aws.String(name),
WithDecryption: aws.Bool(true),
}
resp, err := c.client.GetParameter(params)
resp, err := c.client.GetParameter(context.TODO(), params)
if err != nil {
return parameters, err
}
Expand Down