From fd331722d24562957a89502c5f9dcb10060de06a Mon Sep 17 00:00:00 2001 From: WeakPixel <98713283+weakpixel@users.noreply.github.com> Date: Mon, 8 Aug 2022 08:58:04 +0200 Subject: [PATCH] Add password provider to SSH connection (#63) * Add password provider to SSH connection This allows to implement a custom password provider (e.g. user input) for provate keys which require a password * Fix ssh-agent usage * rearrange code * smaller fixes --- examples/password/password.go | 40 +++++++++++++++ ssh.go | 93 +++++++++++++++++++++++------------ 2 files changed, 101 insertions(+), 32 deletions(-) create mode 100644 examples/password/password.go diff --git a/examples/password/password.go b/examples/password/password.go new file mode 100644 index 00000000..e605b371 --- /dev/null +++ b/examples/password/password.go @@ -0,0 +1,40 @@ +package main + +import ( + "flag" + "fmt" + "syscall" + + "github.com/k0sproject/rig" + "golang.org/x/crypto/ssh/terminal" +) + +/* + This example shows how to use a key password provider +*/ + +func main() { + user := flag.String("user", "root", "SSH User") + host := flag.String("host", "localhost", "Host") + flag.Parse() + conn := rig.Connection{ + SSH: &rig.SSH{ + User: *user, + Address: *host, + PasswordCallback: func() (string, error) { + fmt.Println("Enter password:") + pass, err := terminal.ReadPassword(int(syscall.Stdin)) + return string(pass), err + }, + }, + } + if err := conn.Connect(); err != nil { + panic(err) + } + defer conn.Disconnect() + output, err := conn.ExecOutput("ls -al") + if err != nil { + panic(err) + } + println(output) +} diff --git a/ssh.go b/ssh.go index c3653478..c20bb32c 100644 --- a/ssh.go +++ b/ssh.go @@ -32,14 +32,14 @@ import ( // SSH describes an SSH connection type SSH struct { - Address string `yaml:"address" validate:"required,hostname|ip"` - User string `yaml:"user" validate:"required" default:"root"` - Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` - KeyPath string `yaml:"keyPath" validate:"omitempty"` - HostKey string `yaml:"hostKey,omitempty"` - Bastion *SSH `yaml:"bastion,omitempty"` - - name string + Address string `yaml:"address" validate:"required,hostname|ip"` + User string `yaml:"user" validate:"required" default:"root"` + Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` + KeyPath string `yaml:"keyPath" validate:"omitempty"` + HostKey string `yaml:"hostKey,omitempty"` + Bastion *SSH `yaml:"bastion,omitempty"` + PasswordCallback PasswordCallback + name string isWindows bool knowOs bool @@ -48,6 +48,8 @@ type SSH struct { client *ssh.Client } +type PasswordCallback func() (secret string, err error) + const DefaultKeypath = "~/.ssh/id_rsa" // SetDefaults sets various default values @@ -131,26 +133,6 @@ func (c *SSH) Connect() error { config.HostKeyCallback = trustedHostKeyCallback(c.HostKey) } - var pubkeySigners []ssh.Signer - - _, err := os.Stat(c.KeyPath) - if err != nil && !c.keypathDefault { - return err - } - if err == nil { - var key []byte - key, err = os.ReadFile(c.KeyPath) - if err != nil { - return err - } - signer, err := ssh.ParsePrivateKey(key) - if err != nil { - log.Infof("can't parse keyfile %s: %s", c.KeyPath, err.Error()) - } else { - pubkeySigners = append(pubkeySigners, signer) - } - } - sshAgentSock := os.Getenv("SSH_AUTH_SOCK") if sshAgentSock != "" { sshAgent, err := net.Dial("unix", sshAgentSock) @@ -158,14 +140,18 @@ func (c *SSH) Connect() error { log.Errorf("can't connect to SSH agent auth socket %s: %s", sshAgentSock, err) } else { signers, err := agent.NewClient(sshAgent).Signers() - if err == nil { - pubkeySigners = append(pubkeySigners, signers...) + if err == nil && len(signers) > 0 { + config.Auth = append(config.Auth, ssh.PublicKeys(signers...)) } } } - if len(pubkeySigners) > 0 { - config.Auth = append(config.Auth, ssh.PublicKeys(pubkeySigners...)) + privateKeyAuth, err := c.getPrivateKeys() + if err != nil { + return err + } + if len(privateKeyAuth) > 0 { + config.Auth = append(config.Auth, privateKeyAuth...) } dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) @@ -196,6 +182,49 @@ func (c *SSH) Connect() error { return nil } +func (c *SSH) getPrivateKeys() ([]ssh.AuthMethod, error) { + result := []ssh.AuthMethod{} + _, err := os.Stat(c.KeyPath) + if err != nil && !c.keypathDefault { + return result, err + } + if err == nil { + var key []byte + key, err = os.ReadFile(c.KeyPath) + if err != nil { + return result, err + } + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + if c.PasswordCallback != nil { + switch err.(type) { + case *ssh.PassphraseMissingError: + auth := ssh.PublicKeysCallback(func() ([]ssh.Signer, error) { + pass, err := c.PasswordCallback() + if err != nil { + return nil, fmt.Errorf("password provider failed: %s", err) + } + signer, err := ssh.ParsePrivateKeyWithPassphrase(key, []byte(pass)) + if err != nil { + + return nil, err + } + return []ssh.Signer{signer}, nil + }) + + result = append(result, auth) + default: + log.Infof("can't parse keyfile %s: %s", c.KeyPath, err.Error()) + } + } else { + log.Infof("can't parse keyfile %s: %s", c.KeyPath, err.Error()) + } + } + result = append(result, ssh.PublicKeys(signer)) + } + return result, nil +} + // Exec executes a command on the host func (c *SSH) Exec(cmd string, opts ...exec.Option) error { o := exec.Build(opts...)