From 9698e41b1b4e5b1a85aaac833a009c19f1afe447 Mon Sep 17 00:00:00 2001 From: Mark Wolfe Date: Thu, 27 Jul 2017 20:55:13 +1000 Subject: [PATCH] feat(storage) Keychain storage for password. * Refactor all login related flags. * Added role command line option. * Added suggestions by @hoegertn to cover a couple of edge cases. fixes #28 #37 #29 --- aws_account.go | 32 +++ aws_account_test.go | 42 ++++ cmd/saml2aws/commands/exec.go | 8 +- cmd/saml2aws/commands/login.go | 154 ++++++++---- cmd/saml2aws/commands/login_darwin.go | 13 + cmd/saml2aws/commands/login_test.go | 18 ++ cmd/saml2aws/main.go | 32 ++- config.go | 46 ++-- helper/credentials/credentials.go | 58 +++++ helper/credentials/saml.go | 23 ++ helper/osxkeychain/osxkeychain_darwin.c | 228 ++++++++++++++++++ helper/osxkeychain/osxkeychain_darwin.go | 170 +++++++++++++ helper/osxkeychain/osxkeychain_darwin.h | 14 ++ helper/osxkeychain/osxkeychain_darwin_test.go | 64 +++++ input.go | 14 +- 15 files changed, 832 insertions(+), 84 deletions(-) create mode 100644 cmd/saml2aws/commands/login_darwin.go create mode 100644 cmd/saml2aws/commands/login_test.go create mode 100644 helper/credentials/credentials.go create mode 100644 helper/credentials/saml.go create mode 100644 helper/osxkeychain/osxkeychain_darwin.c create mode 100644 helper/osxkeychain/osxkeychain_darwin.go create mode 100644 helper/osxkeychain/osxkeychain_darwin.h create mode 100644 helper/osxkeychain/osxkeychain_darwin_test.go diff --git a/aws_account.go b/aws_account.go index ccf3306f0..f485a93ab 100644 --- a/aws_account.go +++ b/aws_account.go @@ -6,15 +6,19 @@ import ( "net/http" "net/url" + "fmt" + "github.com/PuerkitoBio/goquery" "github.com/pkg/errors" ) +// AWSAccount holds the AWS account name and roles type AWSAccount struct { Name string Roles []*AWSRole } +// ParseAWSAccounts extract the aws accounts from the saml assertion func ParseAWSAccounts(samlAssertion string) ([]*AWSAccount, error) { awsURL := "https://signin.aws.amazon.com/saml" @@ -31,6 +35,7 @@ func ParseAWSAccounts(samlAssertion string) ([]*AWSAccount, error) { return ExtractAWSAccounts(data) } +// ExtractAWSAccounts extract the accounts from the AWS html page func ExtractAWSAccounts(data []byte) ([]*AWSAccount, error) { accounts := []*AWSAccount{} @@ -53,3 +58,30 @@ func ExtractAWSAccounts(data []byte) ([]*AWSAccount, error) { return accounts, nil } + +// AssignPrincipals assign principal from roles +func AssignPrincipals(awsRoles []*AWSRole, awsAccounts []*AWSAccount) { + + awsPrincipalARNs := make(map[string]string) + for _, awsRole := range awsRoles { + awsPrincipalARNs[awsRole.RoleARN] = awsRole.PrincipalARN + } + + for _, awsAccount := range awsAccounts { + for _, awsRole := range awsAccount.Roles { + awsRole.PrincipalARN = awsPrincipalARNs[awsRole.RoleARN] + } + } + +} + +// LocateRole locate role by name +func LocateRole(awsRoles []*AWSRole, roleName string) (*AWSRole, error) { + for _, awsRole := range awsRoles { + if awsRole.RoleARN == roleName { + return awsRole, nil + } + } + + return nil, fmt.Errorf("Supplied RoleArn not found in saml assertion: %s", roleName) +} diff --git a/aws_account_test.go b/aws_account_test.go index d4159f036..a45229e79 100644 --- a/aws_account_test.go +++ b/aws_account_test.go @@ -34,3 +34,45 @@ func TestExtractAWSAccounts(t *testing.T) { assert.Equal(t, role.RoleARN, "arn:aws:iam::000000000002:role/Production") assert.Equal(t, role.Name, "Production") } + +func TestAssignPrincipals(t *testing.T) { + awsRoles := []*AWSRole{ + &AWSRole{ + PrincipalARN: "arn:aws:iam::000000000001:saml-provider/test-idp", + RoleARN: "arn:aws:iam::000000000001:role/Development", + }, + } + + awsAccounts := []*AWSAccount{ + &AWSAccount{ + Roles: []*AWSRole{ + &AWSRole{ + RoleARN: "arn:aws:iam::000000000001:role/Development", + }, + }, + }, + } + + AssignPrincipals(awsRoles, awsAccounts) + + assert.Equal(t, "arn:aws:iam::000000000001:saml-provider/test-idp", awsAccounts[0].Roles[0].PrincipalARN) +} + +func TestLocateRole(t *testing.T) { + awsRoles := []*AWSRole{ + &AWSRole{ + PrincipalARN: "arn:aws:iam::000000000001:saml-provider/test-idp", + RoleARN: "arn:aws:iam::000000000001:role/Development", + }, + &AWSRole{ + PrincipalARN: "arn:aws:iam::000000000002:saml-provider/test-idp", + RoleARN: "arn:aws:iam::000000000002:role/Development", + }, + } + + role, err := LocateRole(awsRoles, "arn:aws:iam::000000000001:role/Development") + + assert.Empty(t, err) + + assert.Equal(t, "arn:aws:iam::000000000001:role/Development", role.RoleARN) +} diff --git a/cmd/saml2aws/commands/exec.go b/cmd/saml2aws/commands/exec.go index 42c6f3c23..bb24e114b 100644 --- a/cmd/saml2aws/commands/exec.go +++ b/cmd/saml2aws/commands/exec.go @@ -15,25 +15,25 @@ import ( ) // Exec execute the supplied command after seeding the environment -func Exec(profile string, providerName string, skipVerify bool, cmdline []string) error { +func Exec(loginFlags *LoginFlags, cmdline []string) error { if len(cmdline) < 1 { return fmt.Errorf("Command to execute required.") } - ok, err := checkToken(profile) + ok, err := checkToken(loginFlags.Profile) if err != nil { return errors.Wrap(err, "error validating token") } if !ok { - err = Login(profile, providerName, skipVerify) + err = Login(loginFlags) } if err != nil { return errors.Wrap(err, "error logging in") } - sharedCreds := saml2aws.NewSharedCredentials(profile) + sharedCreds := saml2aws.NewSharedCredentials(loginFlags.Profile) id, secret, token, err := sharedCreds.Load() if err != nil { diff --git a/cmd/saml2aws/commands/login.go b/cmd/saml2aws/commands/login.go index bc1ba849e..efe3a052f 100644 --- a/cmd/saml2aws/commands/login.go +++ b/cmd/saml2aws/commands/login.go @@ -10,33 +10,46 @@ import ( "github.com/aws/aws-sdk-go/service/sts" "github.com/pkg/errors" "github.com/versent/saml2aws" + "github.com/versent/saml2aws/helper/credentials" ) -// Login login to ADFS -func Login(profile, providerName string, skipVerify bool) error { +// LoginFlags login specific command flags +type LoginFlags struct { + Provider string + Profile string + Hostname string + Username string + Password string + RoleArn string + SkipVerify bool + SkipPrompt bool +} - config := saml2aws.NewConfigLoader(providerName) +// RoleSupplied role arn has been passed as a flag +func (lf *LoginFlags) RoleSupplied() bool { + return lf.RoleArn != "" +} - username, err := config.LoadUsername() - if err != nil { - return errors.Wrap(err, "error loading config file") - } +// Login login to ADFS +func Login(loginFlags *LoginFlags) error { + + config := saml2aws.NewConfigLoader(loginFlags.Provider) hostname, err := config.LoadHostname() if err != nil { return errors.Wrap(err, "error loading config file") } - loginDetails, err := saml2aws.PromptForLoginDetails(username, hostname) + fmt.Println("LookupCredentials", hostname) + + loginDetails, err := resolveLoginDetails(hostname, loginFlags) if err != nil { return errors.Wrap(err, "error accepting password") } - fmt.Printf("%s https://%s\n", providerName, loginDetails.Hostname) + fmt.Printf("Authenticating to %s with URL https://%s\n", loginFlags.Provider, loginDetails.Hostname) - fmt.Printf("Authenticating to %s...\n", providerName) - - opts := &saml2aws.SAMLOptions{Provider: providerName, SkipVerify: skipVerify} + opts := &saml2aws.SAMLOptions{Provider: loginFlags.Provider, SkipVerify: loginFlags.SkipVerify} provider, err := saml2aws.NewSAMLClient(opts) if err != nil { @@ -55,6 +68,11 @@ func Login(profile, providerName string, skipVerify bool) error { os.Exit(1) } + err = credentials.SaveCredentials(loginDetails.Hostname, loginDetails.Username, loginDetails.Password) + if err != nil { + return errors.Wrap(err, "error storing password in keychain") + } + data, err := base64.StdEncoding.DecodeString(samlAssertion) if err != nil { return errors.Wrap(err, "error decoding saml assertion") @@ -76,34 +94,7 @@ func Login(profile, providerName string, skipVerify bool) error { return errors.Wrap(err, "error parsing aws roles") } - var role = new(saml2aws.AWSRole) - - if len(awsRoles) == 1 { - role = awsRoles[0] - } else if len(awsRoles) == 0 { - return errors.Wrap(err, "no roles available") - } else { - awsPrincipalARNs := make(map[string]string) - for _, awsRole := range awsRoles { - awsPrincipalARNs[awsRole.RoleARN] = awsRole.PrincipalARN - } - - awsAccounts, err := saml2aws.ParseAWSAccounts(samlAssertion) - if err != nil { - return errors.Wrap(err, "error parsing aws role accounts") - } - - for _, awsAccount := range awsAccounts { - for _, awsRole := range awsAccount.Roles { - awsRole.PrincipalARN = awsPrincipalARNs[awsRole.RoleARN] - } - } - - role, err = saml2aws.PromptForAWSRoleSelection(awsAccounts) - if err != nil { - return errors.Wrap(err, "error selecting role") - } - } + role, err := resolveRole(awsRoles, samlAssertion, loginFlags) fmt.Println("Selected role:", role.RoleARN) @@ -125,12 +116,12 @@ func Login(profile, providerName string, skipVerify bool) error { resp, err := svc.AssumeRoleWithSAML(params) if err != nil { - return errors.Wrap(err, "error retieving sts credentials using SAML") + return errors.Wrap(err, "error retrieving STS credentials using SAML") } fmt.Println("Saving credentials") - sharedCreds := saml2aws.NewSharedCredentials(profile) + sharedCreds := saml2aws.NewSharedCredentials(loginFlags.Profile) err = sharedCreds.Save(aws.StringValue(resp.Credentials.AccessKeyId), aws.StringValue(resp.Credentials.SecretAccessKey), aws.StringValue(resp.Credentials.SessionToken)) if err != nil { @@ -141,7 +132,7 @@ func Login(profile, providerName string, skipVerify bool) error { fmt.Println("") fmt.Println("Your new access key pair has been stored in the AWS configuration") fmt.Printf("Note that it will expire at %v\n", resp.Credentials.Expiration.Local()) - fmt.Println("To use this credential, call the AWS CLI with the --profile option (e.g. aws --profile", profile, "ec2 describe-instances).") + fmt.Println("To use this credential, call the AWS CLI with the --profile option (e.g. aws --profile", loginFlags.Profile, "ec2 describe-instances).") fmt.Println("Saving config:", config.Filename) config.SaveUsername(loginDetails.Username) @@ -149,3 +140,80 @@ func Login(profile, providerName string, skipVerify bool) error { return nil } + +func resolveLoginDetails(hostname string, loginFlags *LoginFlags) (*saml2aws.LoginDetails, error) { + + loginDetails := new(saml2aws.LoginDetails) + + fmt.Println("hostname", hostname) + + savedUsername, savedPassword, err := credentials.LookupCredentials(hostname) + if err != nil { + if !credentials.IsErrCredentialsNotFound(err) { + return nil, errors.Wrap(err, "error loading saved password") + } + } + + // if you supply a username in a flag it takes precedence + if loginFlags.Username != "" { + loginDetails.Username = loginFlags.Username + } else { + fmt.Println("Using saved username") + loginDetails.Username = savedUsername + } + + // if you supply a password in a flag it takes precedence + if loginFlags.Password != "" { + loginDetails.Password = loginFlags.Password + } else { + fmt.Println("Using saved password") + loginDetails.Password = savedPassword + } + + fmt.Println("savedUsername", savedUsername) + + // if skip prompt was passed just pass back the flag values + if loginFlags.SkipPrompt { + return &saml2aws.LoginDetails{ + Username: loginDetails.Username, + Password: loginDetails.Password, + Hostname: loginFlags.Hostname, + }, nil + } + + return saml2aws.PromptForLoginDetails(savedUsername, hostname, savedPassword) +} + +func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string, loginFlags *LoginFlags) (*saml2aws.AWSRole, error) { + var role = new(saml2aws.AWSRole) + + if len(awsRoles) == 1 { + if loginFlags.RoleSupplied() { + return saml2aws.LocateRole(awsRoles, loginFlags.RoleArn) + } + role = awsRoles[0] + } else if len(awsRoles) == 0 { + return nil, errors.New("no roles available") + } + + awsAccounts, err := saml2aws.ParseAWSAccounts(samlAssertion) + if err != nil { + return nil, errors.Wrap(err, "error parsing aws role accounts") + } + + saml2aws.AssignPrincipals(awsRoles, awsAccounts) + + if loginFlags.RoleSupplied() { + return saml2aws.LocateRole(awsRoles, loginFlags.RoleArn) + } + + for { + role, err = saml2aws.PromptForAWSRoleSelection(awsAccounts) + if err == nil { + break + } + fmt.Println("error selecting role") + } + + return role, nil +} diff --git a/cmd/saml2aws/commands/login_darwin.go b/cmd/saml2aws/commands/login_darwin.go new file mode 100644 index 000000000..895c4623b --- /dev/null +++ b/cmd/saml2aws/commands/login_darwin.go @@ -0,0 +1,13 @@ +package commands + +import ( + "fmt" + + "github.com/versent/saml2aws/helper/credentials" + "github.com/versent/saml2aws/helper/osxkeychain" +) + +func init() { + fmt.Println("adding osx helper") + credentials.CurrentHelper = &osxkeychain.Osxkeychain{} +} diff --git a/cmd/saml2aws/commands/login_test.go b/cmd/saml2aws/commands/login_test.go new file mode 100644 index 000000000..de386151f --- /dev/null +++ b/cmd/saml2aws/commands/login_test.go @@ -0,0 +1,18 @@ +package commands + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/versent/saml2aws" +) + +func TestResolveLoginDetails(t *testing.T) { + + loginFlags := &LoginFlags{Hostname: "id.example.com", Username: "wolfeidau", Password: "testtestlol", SkipPrompt: true} + + loginDetails, err := resolveLoginDetails("id.example.com", loginFlags) + + assert.Empty(t, err) + assert.Equal(t, loginDetails, &saml2aws.LoginDetails{Username: "wolfeidau", Password: "testtestlol", Hostname: "id.example.com"}) +} diff --git a/cmd/saml2aws/main.go b/cmd/saml2aws/main.go index 2909954f6..e78cf7d1a 100644 --- a/cmd/saml2aws/main.go +++ b/cmd/saml2aws/main.go @@ -12,15 +12,9 @@ import ( var ( app = kingpin.New("saml2aws", "A command line tool to help with SAML access to the AWS token service.") - // /verbose = kingpin.Flag("verbose", "Verbose mode.").Short('v').Bool() - profileName = app.Flag("profile", "The AWS profile to save the temporary credentials").Short('p').Default("saml").String() - skipVerify = app.Flag("skip-verify", "Skip verification of server certificate.").Short('s').Bool() - providerName = app.Flag("provider", "The type of SAML IDP provider.").Short('i').Default("ADFS").Enum("ADFS", "ADFS2", "Ping", "JumpCloud", "Okta", "KeyCloak") - cmdLogin = app.Command("login", "Login to a SAML 2.0 IDP and convert the SAML assertion to an STS token.") - - cmdExec = app.Command("exec", "Exec the supplied command with env vars from STS token.") - cmdLine = buildCmdList(cmdExec.Arg("command", "The command to execute.")) + cmdExec = app.Command("exec", "Exec the supplied command with env vars from STS token.") + cmdLine = buildCmdList(cmdExec.Arg("command", "The command to execute.")) // Version app version Version = "1.0.0" @@ -48,19 +42,37 @@ func buildCmdList(s kingpin.Settings) (target *[]string) { return } +func configureLoginFlags(app *kingpin.Application) *commands.LoginFlags { + c := &commands.LoginFlags{} + + app.Flag("profile", "The AWS profile to save the temporary credentials").Short('p').Default("saml").StringVar(&c.Profile) + app.Flag("skip-verify", "Skip verification of server certificate.").Short('s').BoolVar(&c.SkipVerify) + app.Flag("provider", "The type of SAML IDP provider.").Short('i').Default("ADFS").EnumVar(&c.Provider, "ADFS", "ADFS2", "Ping", "JumpCloud", "Okta", "KeyCloak") + app.Flag("hostname", "The hostname of the SAML IDP server used to login.").StringVar(&c.Hostname) + app.Flag("username", "The username used to login.").StringVar(&c.Username) + app.Flag("password", "The password used to login.").Envar("SAML2AWS_PASSWORD").StringVar(&c.Password) + app.Flag("role", "The ARN of the role to assume.").StringVar(&c.RoleArn) + app.Flag("skip-prompt", "Skip prompting for parameters during login.").BoolVar(&c.SkipPrompt) + + return c +} + func main() { log.SetFlags(log.Lshortfile) app.Version(Version) + + lc := configureLoginFlags(app) + command := kingpin.MustParse(app.Parse(os.Args[1:])) var err error switch command { case cmdLogin.FullCommand(): - err = commands.Login(*profileName, *providerName, *skipVerify) + err = commands.Login(lc) case cmdExec.FullCommand(): - err = commands.Exec(*profileName, *providerName, *skipVerify, *cmdLine) + err = commands.Exec(lc, *cmdLine) } if err != nil { diff --git a/config.go b/config.go index b47fa6727..bc565602a 100644 --- a/config.go +++ b/config.go @@ -30,29 +30,6 @@ func NewConfigLoader(profile string) *ConfigLoader { } } -// ensureConfigExists verify that the config file exists -func (p *ConfigLoader) ensureConfigExists() error { - filename, err := p.filename() - if err != nil { - return err - } - - if _, err := os.Stat(filename); err != nil { - if os.IsNotExist(err) { - - // create an base config file - err = ioutil.WriteFile(filename, []byte("["+p.Profile+"]"), 0600) - if err != nil { - return err - } - - } - return err - } - - return nil -} - // SaveUsername persist the username func (p *ConfigLoader) SaveUsername(username string) error { filename, err := p.filename() @@ -133,6 +110,29 @@ func (p *ConfigLoader) LoadProvider(defaultValue string) (string, error) { return loadConfig(filename, p.Profile, "provider") } +// ensureConfigExists verify that the config file exists +func (p *ConfigLoader) ensureConfigExists() error { + filename, err := p.filename() + if err != nil { + return err + } + + if _, err := os.Stat(filename); err != nil { + if os.IsNotExist(err) { + + // create an base config file + err = ioutil.WriteFile(filename, []byte("["+p.Profile+"]"), 0600) + if err != nil { + return err + } + + } + return err + } + + return nil +} + func (p *ConfigLoader) filename() (string, error) { if p.Filename == "" { if p.Filename = os.Getenv("AWS2SAML_CONFIG_FILE"); p.Filename != "" { diff --git a/helper/credentials/credentials.go b/helper/credentials/credentials.go new file mode 100644 index 000000000..cb75b9f95 --- /dev/null +++ b/helper/credentials/credentials.go @@ -0,0 +1,58 @@ +package credentials + +import "errors" + +// CurrentHelper the currently configured credentials helper +var CurrentHelper Helper = &defaultHelper{} + +// ErrCredentialsNotFound returned when the credential can't be located in the native store. +var ErrCredentialsNotFound = errors.New("credentials not found in native keychain") + +// Credentials holds the information shared between saml2aws and the credentials store. +type Credentials struct { + ServerURL string + Username string + Secret string +} + +// CredsLabel saml2aws credentials should be labeled as such in credentials stores that allow labelling. +// That label allows to filter out non-Docker credentials too at lookup/search in macOS keychain, +// Windows credentials manager and Linux libsecret. Default value is "saml2aws Credentials" +var CredsLabel = "saml2aws Credentials" + +// Helper is the interface a credentials store helper must implement. +type Helper interface { + // Add appends credentials to the store. + Add(*Credentials) error + // Delete removes credentials from the store. + Delete(serverURL string) error + // Get retrieves credentials from the store. + // It returns username and secret as strings. + Get(serverURL string) (string, string, error) + // List returns the stored serverURLs and their associated usernames. + List() (map[string]string, error) +} + +// IsErrCredentialsNotFound returns true if the error +// was caused by not having a set of credentials in a store. +func IsErrCredentialsNotFound(err error) bool { + return err == ErrCredentialsNotFound +} + +type defaultHelper struct{} + +func (defaultHelper) Add(*Credentials) error { + return nil +} + +func (defaultHelper) Delete(serverURL string) error { + return nil +} + +func (defaultHelper) Get(serverURL string) (string, string, error) { + return "", "", nil +} + +func (defaultHelper) List() (map[string]string, error) { + return map[string]string{}, nil +} diff --git a/helper/credentials/saml.go b/helper/credentials/saml.go new file mode 100644 index 000000000..f5e7bd0e2 --- /dev/null +++ b/helper/credentials/saml.go @@ -0,0 +1,23 @@ +package credentials + +import "fmt" + +// LookupCredentials lookup an existing set of credentials and validate it. +func LookupCredentials(hostname string) (string, string, error) { + + username, password, err := CurrentHelper.Get(fmt.Sprintf("https://%s", hostname)) + + return username, password, err +} + +// SaveCredentials save the user credentials. +func SaveCredentials(hostname, username, password string) error { + + creds := &Credentials{ + ServerURL: fmt.Sprintf("https://%s", hostname), + Username: username, + Secret: password, + } + + return CurrentHelper.Add(creds) +} diff --git a/helper/osxkeychain/osxkeychain_darwin.c b/helper/osxkeychain/osxkeychain_darwin.c new file mode 100644 index 000000000..f84d61ee5 --- /dev/null +++ b/helper/osxkeychain/osxkeychain_darwin.c @@ -0,0 +1,228 @@ +#include "osxkeychain_darwin.h" +#include +#include +#include +#include + +char *get_error(OSStatus status) { + char *buf = malloc(128); + CFStringRef str = SecCopyErrorMessageString(status, NULL); + int success = CFStringGetCString(str, buf, 128, kCFStringEncodingUTF8); + if (!success) { + strncpy(buf, "Unknown error", 128); + } + return buf; +} + +char *keychain_add(struct Server *server, char *label, char *username, char *secret) { + SecKeychainItemRef item; + + OSStatus status = SecKeychainAddInternetPassword( + NULL, + strlen(server->host), server->host, + 0, NULL, + strlen(username), username, + strlen(server->path), server->path, + server->port, + server->proto, + kSecAuthenticationTypeDefault, + strlen(secret), secret, + &item + ); + + if (status) { + return get_error(status); + } + + SecKeychainAttribute attribute; + SecKeychainAttributeList attrs; + attribute.tag = kSecLabelItemAttr; + attribute.data = label; + attribute.length = strlen(label); + attrs.count = 1; + attrs.attr = &attribute; + + status = SecKeychainItemModifyContent(item, &attrs, 0, NULL); + + if (status) { + return get_error(status); + } + + return NULL; +} + +char *keychain_get(struct Server *server, unsigned int *username_l, char **username, unsigned int *secret_l, char **secret) { + char *tmp; + SecKeychainItemRef item; + + OSStatus status = SecKeychainFindInternetPassword( + NULL, + strlen(server->host), server->host, + 0, NULL, + 0, NULL, + strlen(server->path), server->path, + server->port, + server->proto, + kSecAuthenticationTypeDefault, + secret_l, (void **)&tmp, + &item); + + if (status) { + return get_error(status); + } + + *secret = strdup(tmp); + SecKeychainItemFreeContent(NULL, tmp); + + SecKeychainAttributeList list; + SecKeychainAttribute attr; + + list.count = 1; + list.attr = &attr; + attr.tag = kSecAccountItemAttr; + + status = SecKeychainItemCopyContent(item, NULL, &list, NULL, NULL); + if (status) { + return get_error(status); + } + + *username = strdup(attr.data); + *username_l = attr.length; + SecKeychainItemFreeContent(&list, NULL); + + return NULL; +} + +char *keychain_delete(struct Server *server) { + SecKeychainItemRef item; + + OSStatus status = SecKeychainFindInternetPassword( + NULL, + strlen(server->host), server->host, + 0, NULL, + 0, NULL, + strlen(server->path), server->path, + server->port, + server->proto, + kSecAuthenticationTypeDefault, + 0, NULL, + &item); + + if (status) { + return get_error(status); + } + + status = SecKeychainItemDelete(item); + if (status) { + return get_error(status); + } + return NULL; +} + +char * CFStringToCharArr(CFStringRef aString) { + if (aString == NULL) { + return NULL; + } + CFIndex length = CFStringGetLength(aString); + CFIndex maxSize = + CFStringGetMaximumSizeForEncoding(length, kCFStringEncodingUTF8) + 1; + char *buffer = (char *)malloc(maxSize); + if (CFStringGetCString(aString, buffer, maxSize, + kCFStringEncodingUTF8)) { + return buffer; + } + return NULL; +} + +char *keychain_list(char *credsLabel, char *** paths, char *** accts, unsigned int *list_l) { + CFStringRef credsLabelCF = CFStringCreateWithCString(NULL, credsLabel, kCFStringEncodingUTF8); + CFMutableDictionaryRef query = CFDictionaryCreateMutable (NULL, 1, NULL, NULL); + CFDictionaryAddValue(query, kSecClass, kSecClassInternetPassword); + CFDictionaryAddValue(query, kSecReturnAttributes, kCFBooleanTrue); + CFDictionaryAddValue(query, kSecMatchLimit, kSecMatchLimitAll); + CFDictionaryAddValue(query, kSecAttrLabel, credsLabelCF); + //Use this query dictionary + CFTypeRef result= NULL; + OSStatus status = SecItemCopyMatching( + query, + &result); + + CFRelease(credsLabelCF); + + //Ran a search and store the results in result + if (status) { + return get_error(status); + } + CFIndex numKeys = CFArrayGetCount(result); + *paths = (char **) malloc((int)sizeof(char *)*numKeys); + *accts = (char **) malloc((int)sizeof(char *)*numKeys); + //result is of type CFArray + for(CFIndex i=0; i +*/ +import "C" +import ( + "errors" + "net/url" + "strconv" + "strings" + "unsafe" + + "github.com/versent/saml2aws/helper/credentials" +) + +// errCredentialsNotFound is the specific error message returned by OS X +// when the credentials are not in the keychain. +const errCredentialsNotFound = "The specified item could not be found in the keychain." + +// Osxkeychain handles secrets using the OS X Keychain as store. +type Osxkeychain struct{} + +// Add adds new credentials to the keychain. +func (h Osxkeychain) Add(creds *credentials.Credentials) error { + h.Delete(creds.ServerURL) + + s, err := splitServer(creds.ServerURL) + if err != nil { + return err + } + defer freeServer(s) + + label := C.CString(credentials.CredsLabel) + defer C.free(unsafe.Pointer(label)) + username := C.CString(creds.Username) + defer C.free(unsafe.Pointer(username)) + secret := C.CString(creds.Secret) + defer C.free(unsafe.Pointer(secret)) + + errMsg := C.keychain_add(s, label, username, secret) + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + return errors.New(C.GoString(errMsg)) + } + + return nil +} + +// Delete removes credentials from the keychain. +func (h Osxkeychain) Delete(serverURL string) error { + s, err := splitServer(serverURL) + if err != nil { + return err + } + defer freeServer(s) + + errMsg := C.keychain_delete(s) + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + return errors.New(C.GoString(errMsg)) + } + + return nil +} + +// Get returns the username and secret to use for a given registry server URL. +func (h Osxkeychain) Get(serverURL string) (string, string, error) { + s, err := splitServer(serverURL) + if err != nil { + return "", "", err + } + defer freeServer(s) + + var usernameLen C.uint + var username *C.char + var secretLen C.uint + var secret *C.char + defer C.free(unsafe.Pointer(username)) + defer C.free(unsafe.Pointer(secret)) + + errMsg := C.keychain_get(s, &usernameLen, &username, &secretLen, &secret) + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + goMsg := C.GoString(errMsg) + if goMsg == errCredentialsNotFound { + return "", "", credentials.ErrCredentialsNotFound + } + + return "", "", errors.New(goMsg) + } + + user := C.GoStringN(username, C.int(usernameLen)) + pass := C.GoStringN(secret, C.int(secretLen)) + return user, pass, nil +} + +// List returns the stored URLs and corresponding usernames. +func (h Osxkeychain) List() (map[string]string, error) { + credsLabelC := C.CString(credentials.CredsLabel) + defer C.free(unsafe.Pointer(credsLabelC)) + + var pathsC **C.char + defer C.free(unsafe.Pointer(pathsC)) + var acctsC **C.char + defer C.free(unsafe.Pointer(acctsC)) + var listLenC C.uint + errMsg := C.keychain_list(credsLabelC, &pathsC, &acctsC, &listLenC) + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + goMsg := C.GoString(errMsg) + return nil, errors.New(goMsg) + } + + defer C.freeListData(&pathsC, listLenC) + defer C.freeListData(&acctsC, listLenC) + + var listLen int + listLen = int(listLenC) + pathTmp := (*[1 << 30]*C.char)(unsafe.Pointer(pathsC))[:listLen:listLen] + acctTmp := (*[1 << 30]*C.char)(unsafe.Pointer(acctsC))[:listLen:listLen] + //taking the array of c strings into go while ignoring all the stuff irrelevant to credentials-helper + resp := make(map[string]string) + for i := 0; i < listLen; i++ { + if C.GoString(pathTmp[i]) == "0" { + continue + } + resp[C.GoString(pathTmp[i])] = C.GoString(acctTmp[i]) + } + return resp, nil +} + +func splitServer(serverURL string) (*C.struct_Server, error) { + u, err := url.Parse(serverURL) + if err != nil { + return nil, err + } + + hostAndPort := strings.Split(u.Host, ":") + host := hostAndPort[0] + var port int + if len(hostAndPort) == 2 { + p, err := strconv.Atoi(hostAndPort[1]) + if err != nil { + return nil, err + } + port = p + } + + proto := C.kSecProtocolTypeHTTPS + if u.Scheme != "https" { + proto = C.kSecProtocolTypeHTTP + } + + return &C.struct_Server{ + proto: C.SecProtocolType(proto), + host: C.CString(host), + port: C.uint(port), + path: C.CString(u.Path), + }, nil +} + +func freeServer(s *C.struct_Server) { + C.free(unsafe.Pointer(s.host)) + C.free(unsafe.Pointer(s.path)) +} diff --git a/helper/osxkeychain/osxkeychain_darwin.h b/helper/osxkeychain/osxkeychain_darwin.h new file mode 100644 index 000000000..c54e7d728 --- /dev/null +++ b/helper/osxkeychain/osxkeychain_darwin.h @@ -0,0 +1,14 @@ +#include + +struct Server { + SecProtocolType proto; + char *host; + char *path; + unsigned int port; +}; + +char *keychain_add(struct Server *server, char *label, char *username, char *secret); +char *keychain_get(struct Server *server, unsigned int *username_l, char **username, unsigned int *secret_l, char **secret); +char *keychain_delete(struct Server *server); +char *keychain_list(char *credsLabel, char *** data, char *** accts, unsigned int *list_l); +void freeListData(char *** data, unsigned int length); \ No newline at end of file diff --git a/helper/osxkeychain/osxkeychain_darwin_test.go b/helper/osxkeychain/osxkeychain_darwin_test.go new file mode 100644 index 000000000..ba694abb7 --- /dev/null +++ b/helper/osxkeychain/osxkeychain_darwin_test.go @@ -0,0 +1,64 @@ +package osxkeychain + +import ( + "testing" + + "github.com/versent/saml2aws/helper/credentials" +) + +func TestOSXKeychainHelper(t *testing.T) { + creds := &credentials.Credentials{ + ServerURL: "https://foobar.docker.io:2376/v1", + Username: "foobar", + Secret: "foobarbaz", + } + creds1 := &credentials.Credentials{ + ServerURL: "https://foobar.docker.io:2376/v2", + Username: "foobarbaz", + Secret: "foobar", + } + helper := Osxkeychain{} + if err := helper.Add(creds); err != nil { + t.Fatal(err) + } + + username, secret, err := helper.Get(creds.ServerURL) + if err != nil { + t.Fatal(err) + } + + if username != "foobar" { + t.Fatalf("expected %s, got %s\n", "foobar", username) + } + + if secret != "foobarbaz" { + t.Fatalf("expected %s, got %s\n", "foobarbaz", secret) + } + + auths, err := helper.List() + if err != nil || len(auths) == 0 { + t.Fatal(err) + } + + helper.Add(creds1) + defer helper.Delete(creds1.ServerURL) + newauths, err := helper.List() + if len(newauths)-len(auths) != 1 { + if err == nil { + t.Fatalf("Error: len(newauths): %d, len(auths): %d", len(newauths), len(auths)) + } + t.Fatalf("Error: len(newauths): %d, len(auths): %d\n Error= %v", len(newauths), len(auths), err) + } + + if err := helper.Delete(creds.ServerURL); err != nil { + t.Fatal(err) + } +} + +func TestMissingCredentials(t *testing.T) { + helper := Osxkeychain{} + _, _, err := helper.Get("https://adsfasdf.wrewerwer.com/asdfsdddd") + if !credentials.IsErrCredentialsNotFound(err) { + t.Fatalf("expected ErrCredentialsNotFound, got %v", err) + } +} diff --git a/input.go b/input.go index 8b4869e36..f8d4e18a2 100644 --- a/input.go +++ b/input.go @@ -18,11 +18,17 @@ type LoginDetails struct { } // PromptForLoginDetails prompt the user to present their username, password and hostname -func PromptForLoginDetails(username, hostname string) (*LoginDetails, error) { +func PromptForLoginDetails(username, hostname, password string) (*LoginDetails, error) { hostname = promptFor("Hostname [%s]", hostname) + + fmt.Println("To use saved username and password just hit enter.") + username = promptFor("Username [%s]", username) - password := prompt.PasswordMasked("Password") + + if enteredPassword := prompt.PasswordMasked("Password"); enteredPassword != "" { + password = enteredPassword + } fmt.Println("") @@ -52,9 +58,9 @@ func PromptForAWSRoleSelection(accounts []*AWSAccount) (*AWSRole, error) { } fmt.Print("Selection: ") - selectedroleindex, _ := reader.ReadString('\n') + selectedRoleIndex, _ := reader.ReadString('\n') - v, err := strconv.Atoi(strings.TrimSpace(selectedroleindex)) + v, err := strconv.Atoi(strings.TrimSpace(selectedRoleIndex)) if err != nil { return nil, fmt.Errorf("Unrecognised role index")