Skip to content

Commit

Permalink
mockssh: expose default command handler for reuse, remove RemoteDir a…
Browse files Browse the repository at this point in the history
…nd RemoteEnv
  • Loading branch information
pjcdawkins committed Jan 15, 2025
1 parent f58a78d commit 60dcb79
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 23 deletions.
51 changes: 28 additions & 23 deletions pkg/mockssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"io"
"net"
"net/http"
"os"
"os/exec"
"sync"
"testing"
Expand All @@ -31,9 +30,9 @@ type Server struct {
CertAuthorityKeys []ssh.PublicKey
CertChecker ssh.CertChecker

// RemoteEnv, RemoteDir and CommandHandler are optional configuration.
RemoteEnv []string
RemoteDir string
// An optional CommandHandler, which responds to commands sent over SSH.
// NewServer will give this a default using ExecHandler, which can also
// be reused from custom handlers.
CommandHandler CommandHandler

// listener and port are set after Start.
Expand All @@ -47,7 +46,7 @@ type CommandIO struct {
StdErr io.Writer
}

type CommandHandler func(conn ssh.ConnMetadata, command string, io CommandIO) int
type CommandHandler func(conn ssh.ConnMetadata, command string, commandIO CommandIO) int

// NewServer creates and starts a local SSH server for a test.
// It must be stopped with the Server.Stop method.
Expand All @@ -65,9 +64,8 @@ func NewServer(t *testing.T, authorityEndpoint string) (*Server, error) {
}

s := &Server{t: t, hostKey: hk}
s.CommandHandler = s.defaultCommandHandler
s.CommandHandler = ExecHandler("", nil)
s.CertChecker = s.defaultCertChecker()
s.RemoteDir = t.TempDir()
s.CertAuthorityKeys = keys

if err := s.start(); err != nil {
Expand All @@ -89,6 +87,10 @@ func (s *Server) HostKeyConfig() string {
)
}

func (s *Server) HostKey() ssh.PublicKey {
return s.hostKey.PublicKey()
}

func (s *Server) start() error {
t := s.t

Expand Down Expand Up @@ -148,22 +150,25 @@ func (s *Server) Stop() error {
return nil
}

func (s *Server) defaultCommandHandler(_ ssh.ConnMetadata, command string, commandIO CommandIO) int {
c := exec.Command("bash", "-c", command)
c.Stdout = commandIO.StdOut
c.Stderr = commandIO.StdErr
c.Stdin = commandIO.StdIn
c.Dir = s.RemoteDir
c.Env = append(os.Environ(), s.RemoteEnv...)
if err := c.Run(); err != nil {
exitErr := &exec.ExitError{}
if errors.As(err, &exitErr) {
return exitErr.ExitCode()
// ExecHandler returns a CommandHandler to execute a command in the given environment.
func ExecHandler(workingDir string, env []string) CommandHandler {
return func(_ ssh.ConnMetadata, command string, commandIO CommandIO) int {
c := exec.Command("bash", "-c", command)
c.Stdout = commandIO.StdOut
c.Stderr = commandIO.StdErr
c.Stdin = commandIO.StdIn
c.Dir = workingDir
c.Env = env
if err := c.Run(); err != nil {
exitErr := &exec.ExitError{}
if errors.As(err, &exitErr) {
return exitErr.ExitCode()
}
_, _ = fmt.Fprintf(commandIO.StdErr, "Failed to execute command: %v", err)
return 1
}
_, _ = fmt.Fprintf(commandIO.StdErr, "Failed to execute command: %v", err)
return 1
return 0
}
return 0
}

func (s *Server) defaultCertChecker() ssh.CertChecker {
Expand Down Expand Up @@ -253,9 +258,9 @@ func (s *Server) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChann
for {
select {
case s := <-exitWithStatus:
_, err = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status int }{s}))
_, err = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status uint32 }{uint32(s)})) //nolint: gosec
if err != nil {
t.Errorf("Failed to send exit status: %v", err)
t.Fatalf("Failed to send exit status: %v", err)
}
goto closeChannel
case <-timer.C:
Expand Down
105 changes: 105 additions & 0 deletions pkg/mockssh/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package mockssh_test

import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"fmt"
"net"
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/platformsh/cli/pkg/mockapi"
"github.com/platformsh/cli/pkg/mockssh"
)

func TestServer(t *testing.T) {
authServer := mockapi.NewAuthServer(t)
defer authServer.Close()

sshServer, err := mockssh.NewServer(t, authServer.URL+"/ssh/authority")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = sshServer.Stop()
})

tempDir := t.TempDir()
sshServer.CommandHandler = mockssh.ExecHandler(tempDir, []string{})

cert := getTestSSHAuth(t, authServer.URL)

// Create the SSH client configuration
address := fmt.Sprintf("127.0.0.1:%d", sshServer.Port())
config := &ssh.ClientConfig{
User: "test",
Auth: []ssh.AuthMethod{ssh.PublicKeys(cert)},
HostKeyCallback: func(_ string, remote net.Addr, key ssh.PublicKey) error {
if remote.String() != address {
return fmt.Errorf("unexpected address: %s", remote.String())
}
if bytes.Equal(sshServer.HostKey().Marshal(), key.Marshal()) {
return nil
}
return fmt.Errorf("host key mismatch")
},
}

client, err := ssh.Dial("tcp", address, config)
require.NoError(t, err)
defer client.Close()

session, err := client.NewSession()
require.NoError(t, err)
defer session.Close()

stdOutBuffer := &bytes.Buffer{}
session.Stdout = stdOutBuffer

require.NoError(t, session.Run("pwd"))
assert.Equal(t, tempDir, strings.TrimRight(stdOutBuffer.String(), "\n"))

session2, err := client.NewSession()
require.NoError(t, err)
defer session2.Close()
err = session2.Run("false")
assert.Error(t, err)
var exitErr *ssh.ExitError
assert.ErrorAs(t, err, &exitErr)
assert.Equal(t, 1, exitErr.ExitStatus())
}

func getTestSSHAuth(t *testing.T, authServerURL string) ssh.Signer {
t.Helper()

// Generate a keypair
_, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
s, err := ssh.NewSignerFromKey(priv)
require.NoError(t, err)

b, err := json.Marshal(struct{ Key string }{string(ssh.MarshalAuthorizedKey(s.PublicKey()))})
require.NoError(t, err)
resp, err := http.DefaultClient.Post(authServerURL+"/ssh", "application/json", bytes.NewReader(b))
require.NoError(t, err)
defer resp.Body.Close()

var rs struct{ Certificate string }
require.NoError(t, json.NewDecoder(resp.Body).Decode(&rs))

parsed, _, _, _, err := ssh.ParseAuthorizedKey([]byte(rs.Certificate)) //nolint: dogsled
require.NoError(t, err)

cert, _ := parsed.(*ssh.Certificate)
certSigner, err := ssh.NewCertSigner(cert, s)
require.NoError(t, err)

return certSigner
}

0 comments on commit 60dcb79

Please sign in to comment.