Skip to content

Commit

Permalink
Add password provider to SSH connection (#63)
Browse files Browse the repository at this point in the history
* Add password provider to SSH connection
This allows to implement a custom password provider (e.g. user input) for
provate keys which require a password

* Fix ssh-agent usage

* rearrange code

* smaller fixes
  • Loading branch information
weakpixel authored Aug 8, 2022
1 parent 8ac0c0f commit fd33172
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 32 deletions.
40 changes: 40 additions & 0 deletions examples/password/password.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package main

import (
"flag"
"fmt"
"syscall"

"github.com/k0sproject/rig"
"golang.org/x/crypto/ssh/terminal"
)

/*
This example shows how to use a key password provider
*/

func main() {
user := flag.String("user", "root", "SSH User")
host := flag.String("host", "localhost", "Host")
flag.Parse()
conn := rig.Connection{
SSH: &rig.SSH{
User: *user,
Address: *host,
PasswordCallback: func() (string, error) {
fmt.Println("Enter password:")
pass, err := terminal.ReadPassword(int(syscall.Stdin))
return string(pass), err
},
},
}
if err := conn.Connect(); err != nil {
panic(err)
}
defer conn.Disconnect()
output, err := conn.ExecOutput("ls -al")
if err != nil {
panic(err)
}
println(output)
}
93 changes: 61 additions & 32 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ import (

// SSH describes an SSH connection
type SSH struct {
Address string `yaml:"address" validate:"required,hostname|ip"`
User string `yaml:"user" validate:"required" default:"root"`
Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"`
KeyPath string `yaml:"keyPath" validate:"omitempty"`
HostKey string `yaml:"hostKey,omitempty"`
Bastion *SSH `yaml:"bastion,omitempty"`

name string
Address string `yaml:"address" validate:"required,hostname|ip"`
User string `yaml:"user" validate:"required" default:"root"`
Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"`
KeyPath string `yaml:"keyPath" validate:"omitempty"`
HostKey string `yaml:"hostKey,omitempty"`
Bastion *SSH `yaml:"bastion,omitempty"`
PasswordCallback PasswordCallback
name string

isWindows bool
knowOs bool
Expand All @@ -48,6 +48,8 @@ type SSH struct {
client *ssh.Client
}

type PasswordCallback func() (secret string, err error)

const DefaultKeypath = "~/.ssh/id_rsa"

// SetDefaults sets various default values
Expand Down Expand Up @@ -131,41 +133,25 @@ func (c *SSH) Connect() error {
config.HostKeyCallback = trustedHostKeyCallback(c.HostKey)
}

var pubkeySigners []ssh.Signer

_, err := os.Stat(c.KeyPath)
if err != nil && !c.keypathDefault {
return err
}
if err == nil {
var key []byte
key, err = os.ReadFile(c.KeyPath)
if err != nil {
return err
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
log.Infof("can't parse keyfile %s: %s", c.KeyPath, err.Error())
} else {
pubkeySigners = append(pubkeySigners, signer)
}
}

sshAgentSock := os.Getenv("SSH_AUTH_SOCK")
if sshAgentSock != "" {
sshAgent, err := net.Dial("unix", sshAgentSock)
if err != nil {
log.Errorf("can't connect to SSH agent auth socket %s: %s", sshAgentSock, err)
} else {
signers, err := agent.NewClient(sshAgent).Signers()
if err == nil {
pubkeySigners = append(pubkeySigners, signers...)
if err == nil && len(signers) > 0 {
config.Auth = append(config.Auth, ssh.PublicKeys(signers...))
}
}
}

if len(pubkeySigners) > 0 {
config.Auth = append(config.Auth, ssh.PublicKeys(pubkeySigners...))
privateKeyAuth, err := c.getPrivateKeys()
if err != nil {
return err
}
if len(privateKeyAuth) > 0 {
config.Auth = append(config.Auth, privateKeyAuth...)
}

dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port))
Expand Down Expand Up @@ -196,6 +182,49 @@ func (c *SSH) Connect() error {
return nil
}

func (c *SSH) getPrivateKeys() ([]ssh.AuthMethod, error) {
result := []ssh.AuthMethod{}
_, err := os.Stat(c.KeyPath)
if err != nil && !c.keypathDefault {
return result, err
}
if err == nil {
var key []byte
key, err = os.ReadFile(c.KeyPath)
if err != nil {
return result, err
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
if c.PasswordCallback != nil {
switch err.(type) {
case *ssh.PassphraseMissingError:
auth := ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
pass, err := c.PasswordCallback()
if err != nil {
return nil, fmt.Errorf("password provider failed: %s", err)
}
signer, err := ssh.ParsePrivateKeyWithPassphrase(key, []byte(pass))
if err != nil {

return nil, err
}
return []ssh.Signer{signer}, nil
})

result = append(result, auth)
default:
log.Infof("can't parse keyfile %s: %s", c.KeyPath, err.Error())
}
} else {
log.Infof("can't parse keyfile %s: %s", c.KeyPath, err.Error())
}
}
result = append(result, ssh.PublicKeys(signer))
}
return result, nil
}

// Exec executes a command on the host
func (c *SSH) Exec(cmd string, opts ...exec.Option) error {
o := exec.Build(opts...)
Expand Down

0 comments on commit fd33172

Please sign in to comment.