Skip to content

Commit

Permalink
Attaching session name to cached role credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
null93 committed Jun 20, 2024
1 parent ca03436 commit 4d164a5
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 45 deletions.
2 changes: 1 addition & 1 deletion internal/clean.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var cleanCmd = &cobra.Command{
}
for _, role := range roles {
if role.Credentials.IsExpired() || cleanAll {
err := role.Credentials.DeleteCache(role.CacheKey())
err := role.Credentials.DeleteCache(role.SessionName, role.CacheKey())
if err != nil {
ExitWithError(2, "failed to delete role credentials", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/creds-select.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ var credsSelectCmd = &cobra.Command{
p.WithMaxHeight(10)
p.WithEmptyMessage("No Role Credentials Found")
p.WithTitle("Pick Role Credentials")
p.WithHeaders("Region", "Account ID", "Role Name", "Expires In")
p.WithHeaders("SSO Session", "Region", "Account ID", "Role Name", "Expires In")
for _, role := range roles {
expires := "-"
if role.Credentials != nil && !role.Credentials.IsExpired() {
expires = fmt.Sprintf("%.f mins", role.Credentials.Expiration.Sub(now).Minutes())
}
p.AddOption(role, role.Region, role.AccountId, role.Name, expires)
p.AddOption(role, role.SessionName, role.Region, role.AccountId, role.Name, expires)
}
selection := p.Pick()
if selection == nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ var selectCredentialsCmd = &cobra.Command{
if err != nil {
ExitWithError(16, "failed to get credentials", err)
}
err = role.Credentials.Save(role.CacheKey())
err = role.Credentials.Save(session.Name, role.CacheKey())
if err != nil {
ExitWithError(17, "failed to save credentials", err)
}
Expand Down
83 changes: 45 additions & 38 deletions sdk/credentials/role-credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func findRoleCredentials(r Role) (*RoleCredentials, error) {
if err != nil {
return nil, err
}
cachePath := filepath.Join(homedir, RoleCredentialsCachePath, cacheKey+".json")
cachePath := filepath.Join(homedir, RoleCredentialsCachePath, r.SessionName, cacheKey+".json")
if _, err := os.Stat(cachePath); err == nil {
contents, err := ioutil.ReadFile(cachePath)
if err != nil {
Expand All @@ -56,28 +56,28 @@ func (r *RoleCredentials) IsExpired() bool {
return r.Expiration.Before(time.Now())
}

func (r *RoleCredentials) Save(key string) error {
func (r *RoleCredentials) Save(sessionName, key string) error {
homedir, err := os.UserHomeDir()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Join(homedir, RoleCredentialsCachePath), 0700); err != nil {
if err := os.MkdirAll(filepath.Join(homedir, RoleCredentialsCachePath, sessionName), 0700); err != nil {
return err
}
cachePath := filepath.Join(homedir, RoleCredentialsCachePath, key+".json")
cachePath := filepath.Join(homedir, RoleCredentialsCachePath, sessionName, key+".json")
contents, err := json.Marshal(r)
if err != nil {
return err
}
return ioutil.WriteFile(cachePath, contents, 0600)
}

func (r *RoleCredentials) DeleteCache(key string) error {
func (r *RoleCredentials) DeleteCache(sessionName, key string) error {
homedir, err := os.UserHomeDir()
if err != nil {
return err
}
cachePath := filepath.Join(homedir, RoleCredentialsCachePath, key+".json")
cachePath := filepath.Join(homedir, RoleCredentialsCachePath, sessionName, key+".json")
return os.Remove(cachePath)
}

Expand All @@ -90,7 +90,7 @@ func (r *Role) MarkLastUsed() error {
return err
}
lastUsedPath := filepath.Join(homedir, KnoxPath, "last-used")
return ioutil.WriteFile(lastUsedPath, []byte(r.CacheKey()), 0600)
return ioutil.WriteFile(lastUsedPath, []byte(r.SessionName+"\n"+r.CacheKey()), 0600)
}

func GetLastUsedRole() (Role, error) {
Expand All @@ -103,17 +103,23 @@ func GetLastUsedRole() (Role, error) {
if err != nil {
return Role{}, err
}
parts := strings.Split(string(contents), "_")
lines := strings.Split(string(contents), "\n")
if len(lines) < 2 {
return Role{}, fmt.Errorf("invalid last used role")
}
sessionName := lines[0]
parts := strings.Split(lines[1], "_")
if len(parts) < 3 {
return Role{}, fmt.Errorf("invalid last used role")
}
region := parts[0]
accountId := parts[1]
roleName := strings.Join(parts[2:], "_")
role := Role{
Region: region,
AccountId: accountId,
Name: roleName,
Region: region,
AccountId: accountId,
Name: roleName,
SessionName: sessionName,
}
creds, err := findRoleCredentials(role)
if err != nil {
Expand All @@ -129,37 +135,38 @@ func GetSavedRolesWithCredentials() (Roles, error) {
if err != nil {
return roles, err
}
cacheDir := filepath.Join(homedir, RoleCredentialsCachePath)
files, err := os.ReadDir(cacheDir)
pattern := filepath.Join(homedir, RoleCredentialsCachePath, "*", "*.json")
files, err := filepath.Glob(pattern)
if err != nil {
return roles, err
}
for _, file := range files {
filename := file.Name()
if !file.IsDir() && filepath.Ext(filename) == ".json" {
contents, err := os.ReadFile(filepath.Join(cacheDir, filename))
parts := strings.Split(filename, "_")
if len(parts) < 3 {
continue
}
region := parts[0]
accountId := parts[1]
roleName := strings.TrimSuffix(strings.Join(parts[2:], "_"), ".json")
if err != nil {
return nil, err
}
cred := RoleCredentials{}
if err := json.Unmarshal(contents, &cred); err != nil {
return nil, err
}
role := Role{
Region: region,
AccountId: accountId,
Name: roleName,
Credentials: &cred,
}
roles = append(roles, role)
for _, foundPath := range files {
fmt.Println(foundPath)
fileName := filepath.Base(foundPath)
sessionName := filepath.Base(filepath.Dir(foundPath))
contents, err := os.ReadFile(foundPath)
parts := strings.Split(fileName, "_")
if len(parts) < 3 {
continue
}
region := parts[0]
accountId := parts[1]
roleName := strings.TrimSuffix(strings.Join(parts[2:], "_"), ".json")
if err != nil {
return nil, err
}
cred := RoleCredentials{}
if err := json.Unmarshal(contents, &cred); err != nil {
return nil, err
}
role := Role{
Region: region,
AccountId: accountId,
Name: roleName,
SessionName: sessionName,
Credentials: &cred,
}
roles = append(roles, role)
}
return roles, nil
}
8 changes: 5 additions & 3 deletions sdk/credentials/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Role struct {
Name string
AccountId string
Region string
SessionName string
Credentials *RoleCredentials
}

Expand Down Expand Up @@ -269,9 +270,10 @@ func (s *Session) GetRoles(accountId string) (Roles, error) {
for _, details := range page.RoleList {
roleName := aws.ToString(details.RoleName)
role := Role{
Name: roleName,
AccountId: accountId,
Region: s.Region,
Name: roleName,
AccountId: accountId,
Region: s.Region,
SessionName: s.Name,
}
creds, err := findRoleCredentials(role)
if err != nil {
Expand Down

0 comments on commit 4d164a5

Please sign in to comment.