Skip to content

Commit

Permalink
Added --last-used for connect and sync commands, also accepting one a…
Browse files Browse the repository at this point in the history
…rg for instance search term
  • Loading branch information
null93 committed Sep 6, 2024
1 parent 4a6b0b2 commit 019cf86
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 45 deletions.
74 changes: 55 additions & 19 deletions internal/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,71 @@ import (
)

var connectCmd = &cobra.Command{
Use: "connect",
Use: "connect <instance-search-term>",
Short: "Connect to an EC2 instance using session-manager-plugin",
Args: cobra.ExactArgs(0),
Args: cobra.MaximumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
searchTerm := ""
if len(args) > 0 {
searchTerm = args[0]
}
var err error
var role *credentials.Role
var action string
var binaryPath string
for {
if !selectCachedFirst {
action, role = SelectRoleCredentialsStartingFromSession()
} else {
action, role = SelectRoleCredentialsStartingFromCache()
}
if action == "toggle-view" {
toggleView()
continue
if lastUsed {
var err error
var sessions credentials.Sessions
var session *credentials.Session
var roleTemp credentials.Role
if roleTemp, err = credentials.GetLastUsedRole(); err != nil {
ExitWithError(1, "failed to get last used role", err)
}
if action == "back" {
goBack()
continue
role = &roleTemp
if role.Credentials == nil || role.Credentials.IsExpired() {
if sessions, err = credentials.GetSessions(); err != nil {
ExitWithError(2, "failed to parse sso sessions", err)
}
if session = sessions.FindByName(role.SessionName); session == nil {
ExitWithError(3, "failed to find sso session "+role.SessionName, err)
}
if session.ClientToken == nil || session.ClientToken.IsExpired() {
if err = tui.ClientLogin(session); err != nil {
ExitWithError(4, "failed to authorize device login", err)
}
}
if err = session.RefreshRoleCredentials(role); err != nil {
ExitWithError(5, "failed to get credentials", err)
}
if err = role.Credentials.Save(session.Name, role.CacheKey()); err != nil {
ExitWithError(6, "failed to save credentials", err)
}
}
if action == "delete" {
if role != nil && role.Credentials != nil {
role.Credentials.DeleteCache(role.SessionName, role.CacheKey())
}
for {
if role == nil {
if !selectCachedFirst {
action, role = SelectRoleCredentialsStartingFromSession()
} else {
action, role = SelectRoleCredentialsStartingFromCache()
}
if action == "toggle-view" {
toggleView()
continue
}
if action == "back" {
goBack()
continue
}
if action == "delete" {
if role != nil && role.Credentials != nil {
role.Credentials.DeleteCache(role.SessionName, role.CacheKey())
}
continue
}
continue
}
if instanceId == "" {
if instanceId, action, err = tui.SelectInstance(role); err != nil {
if instanceId, action, err = tui.SelectInstance(role, searchTerm); err != nil {
ExitWithError(19, "failed to pick an instance", err)
} else if action == "back" {
goBack()
Expand Down Expand Up @@ -84,5 +119,6 @@ func init() {
connectCmd.Flags().StringVarP(&roleName, "role-name", "r", roleName, "AWS role name")
connectCmd.Flags().StringVarP(&instanceId, "instance-id", "i", instanceId, "EC2 instance ID")
connectCmd.Flags().BoolVarP(&selectCachedFirst, "cached", "c", selectCachedFirst, "select from cached credentials")
connectCmd.Flags().BoolVarP(&lastUsed, "last-used", "l", lastUsed, "select last used credentials")
connectCmd.Flags().Uint32VarP(&connectUid, "uid", "u", connectUid, "UID on instance to 'su' to")
}
1 change: 1 addition & 0 deletions internal/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ var (
debug bool = false
selectCachedFirst bool = false
connectUid uint32 = 0
lastUsed bool = false
sessionName string
accountId string
roleName string
Expand Down
74 changes: 55 additions & 19 deletions internal/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,35 +162,70 @@ func rsyncPortForward(role *credentials.Role, instanceId string) {
}

var syncCmd = &cobra.Command{
Use: "sync",
Use: "sync <instance-search-term>",
Short: "start rsyncd and port forward to it",
Args: cobra.ExactArgs(0),
Args: cobra.MaximumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
searchTerm := ""
if len(args) > 0 {
searchTerm = args[0]
}
var err error
var role *credentials.Role
var action string
for {
if !selectCachedFirst {
action, role = SelectRoleCredentialsStartingFromSession()
} else {
action, role = SelectRoleCredentialsStartingFromCache()
}
if action == "toggle-view" {
toggleView()
continue
if lastUsed {
var err error
var sessions credentials.Sessions
var session *credentials.Session
var roleTemp credentials.Role
if roleTemp, err = credentials.GetLastUsedRole(); err != nil {
ExitWithError(1, "failed to get last used role", err)
}
if action == "back" {
goBack()
continue
role = &roleTemp
if role.Credentials == nil || role.Credentials.IsExpired() {
if sessions, err = credentials.GetSessions(); err != nil {
ExitWithError(2, "failed to parse sso sessions", err)
}
if session = sessions.FindByName(role.SessionName); session == nil {
ExitWithError(3, "failed to find sso session "+role.SessionName, err)
}
if session.ClientToken == nil || session.ClientToken.IsExpired() {
if err = tui.ClientLogin(session); err != nil {
ExitWithError(4, "failed to authorize device login", err)
}
}
if err = session.RefreshRoleCredentials(role); err != nil {
ExitWithError(5, "failed to get credentials", err)
}
if err = role.Credentials.Save(session.Name, role.CacheKey()); err != nil {
ExitWithError(6, "failed to save credentials", err)
}
}
if action == "delete" {
if role != nil && role.Credentials != nil {
role.Credentials.DeleteCache(role.SessionName, role.CacheKey())
}
for {
if role == nil {
if !selectCachedFirst {
action, role = SelectRoleCredentialsStartingFromSession()
} else {
action, role = SelectRoleCredentialsStartingFromCache()
}
if action == "toggle-view" {
toggleView()
continue
}
if action == "back" {
goBack()
continue
}
if action == "delete" {
if role != nil && role.Credentials != nil {
role.Credentials.DeleteCache(role.SessionName, role.CacheKey())
}
continue
}
continue
}
if instanceId == "" {
if instanceId, action, err = tui.SelectInstance(role); err != nil {
if instanceId, action, err = tui.SelectInstance(role, searchTerm); err != nil {
ExitWithError(19, "failed to pick an instance", err)
} else if action == "back" {
goBack()
Expand Down Expand Up @@ -224,5 +259,6 @@ func init() {
syncCmd.Flags().StringVarP(&instanceId, "instance-id", "i", instanceId, "EC2 instance ID")
syncCmd.Flags().Uint16VarP(&rsyncPort, "rsync-port", "P", rsyncPort, "rsync port")
syncCmd.Flags().Uint16VarP(&localPort, "local-port", "p", localPort, "local port")
syncCmd.Flags().BoolVarP(&lastUsed, "last-used", "l", lastUsed, "select last used credentials")
syncCmd.Flags().BoolVarP(&selectCachedFirst, "cached", "c", selectCachedFirst, "select from cached credentials")
}
4 changes: 3 additions & 1 deletion sdk/picker/picker.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ func (p *picker) render() {
ansi.MoveCursorUp(6 + lines)
}

func (p *picker) Pick() (*option, *keys.KeyCode) {
func (p *picker) Pick(initialFilter string) (*option, *keys.KeyCode) {
p.term = initialFilter
p.filter()
ansi.HideCursor()
defer ansi.ClearDown()
defer ansi.ShowCursor()
Expand Down
12 changes: 6 additions & 6 deletions sdk/tui/tui.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func SelectSession(sessions credentials.Sessions) (string, string, error) {
}
p.AddOption(session.Name, session.Name, session.Region, session.StartUrl, expires)
}
selection, firedKeyCode := p.Pick()
selection, firedKeyCode := p.Pick("")
if firedKeyCode != nil && *firedKeyCode == keys.Tab {
return "", "toggle-view", nil
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func SelectAccount(session *credentials.Session, accountAliases map[string]strin
}
p.AddOption(account.Id, account.Id, name, account.Email)
}
selection, firedKeyCode := p.Pick()
selection, firedKeyCode := p.Pick("")
if firedKeyCode != nil && *firedKeyCode == keys.Esc {
return "", "back", nil
}
Expand All @@ -139,7 +139,7 @@ func SelectRole(roles credentials.Roles) (string, string, error) {
}
p.AddOption(role.Name, role.Name, expires)
}
selection, firedKeyCode := p.Pick()
selection, firedKeyCode := p.Pick("")
if firedKeyCode != nil && *firedKeyCode == keys.Esc {
return "", "back", nil
}
Expand All @@ -149,7 +149,7 @@ func SelectRole(roles credentials.Roles) (string, string, error) {
return selection.Value.(string), "", nil
}

func SelectInstance(role *credentials.Role) (string, string, error) {
func SelectInstance(role *credentials.Role, initialFilter string) (string, string, error) {
instances, err := role.GetManagedInstances()
if err != nil {
return "", "", err
Expand All @@ -163,7 +163,7 @@ func SelectInstance(role *credentials.Role) (string, string, error) {
for _, instance := range instances {
p.AddOption(instance.Id, instance.Id, instance.InstanceType, instance.PrivateIpAddress, instance.PublicIpAddress, instance.Name)
}
selection, firedKeyCode := p.Pick()
selection, firedKeyCode := p.Pick(initialFilter)
if firedKeyCode != nil && *firedKeyCode == keys.Esc {
return "", "back", nil
}
Expand Down Expand Up @@ -199,7 +199,7 @@ func SelectRolesCredentials(accountAliases map[string]string) (*credentials.Role
}
p.AddOption(role, role.SessionName, role.Region, role.AccountId, alias, role.Name, expires)
}
selection, firedKeyCode := p.Pick()
selection, firedKeyCode := p.Pick("")
if firedKeyCode != nil && *firedKeyCode == keys.Tab {
return nil, "toggle-view", nil
}
Expand Down

0 comments on commit 019cf86

Please sign in to comment.