Skip to content

Commit

Permalink
refactor: Enhance SSH key handling and add integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yarlson committed Sep 15, 2024
1 parent c22af5d commit f9dc2ab
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 21 deletions.
47 changes: 47 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,62 @@ require (
github.com/fatih/color v1.17.0
github.com/pkg/sftp v1.13.6
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.33.0
golang.org/x/crypto v0.27.0
golang.org/x/term v0.24.0
)

require (
dario.cat/mergo v1.0.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/containerd/containerd v1.7.18 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v27.1.1+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/klauspost/compress v1.17.4 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
github.com/moby/sys/user v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/shirou/gopsutil/v3 v3.23.12 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
golang.org/x/sys v0.25.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
164 changes: 163 additions & 1 deletion go.sum

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion internal/setup/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func RunSetup(cmd *cobra.Command, args []string) {
defer client.Close()
success("SSH connection to the server established.")

userKey, e := utils.FindSSHKey(sshKeyPath, false)
userKey, e := utils.FindSSHKey(sshKeyPath)
if e != nil {
warning("Failed to find user SSH key, will use root key for new user on the server")
userKey = rootKey
Expand Down
2 changes: 1 addition & 1 deletion internal/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func ConnectWithUser(host, user string, key []byte) (*Client, error) {
Timeout: 10 * time.Second,
}

client, err := ssh.Dial("tcp", host+":22", config)
client, err := ssh.Dial("tcp", host, config)
if err != nil {
return nil, fmt.Errorf("failed to connect: %v", err)
}
Expand Down
114 changes: 114 additions & 0 deletions internal/ssh/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package ssh

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"golang.org/x/crypto/ssh"
)

func TestSSHClient(t *testing.T) {
ctx := context.Background()

// Generate SSH key pair
privateKey, publicKey, err := generateSSHKeyPair()
assert.NoError(t, err)

// Start an SSH server container
req := testcontainers.ContainerRequest{
Image: "linuxserver/openssh-server",
ExposedPorts: []string{"2222/tcp"},
Env: map[string]string{
"PUID": "1000",
"PGID": "1000",
"PUBLIC_KEY": publicKey,
"SUDO_ACCESS": "true",
"PASSWORD_ACCESS": "false",
"USER_NAME": "abc",
},
WaitingFor: wait.ForListeningPort("2222/tcp").WithStartupTimeout(2 * time.Minute),
}

container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
assert.NoError(t, err)
defer func() {
err := container.Terminate(ctx)
assert.NoError(t, err)
}()

// Get the host and port
host, err := container.Host(ctx)
assert.NoError(t, err)
mappedPort, err := container.MappedPort(ctx, "2222")
assert.NoError(t, err)
port := mappedPort.Port()

// Read the private key
key := []byte(privateKey)

// Wait for the SSH server to start
time.Sleep(5 * time.Second)

// Connect to the SSH server
client, err := ConnectWithUser(host+":"+port, "abc", key)
assert.NoError(t, err)
defer client.Close()

// Run a command
err = client.RunCommand("echo 'Hello, World!'")
assert.NoError(t, err)

// Upload a file
localFilePath := "test_upload.txt"
remoteFilePath := "/tmp/test_upload.txt"

err = os.WriteFile(localFilePath, []byte("test data"), 0644)
assert.NoError(t, err)
defer os.Remove(localFilePath)

err = client.UploadFile(localFilePath, remoteFilePath)
assert.NoError(t, err)

// Verify the file was uploaded
session, err := client.NewSession()
assert.NoError(t, err)
defer session.Close()

output, err := session.CombinedOutput("cat " + remoteFilePath)
assert.NoError(t, err)
assert.Equal(t, "test data", string(output))
}

func generateSSHKeyPair() (privateKey string, publicKey string, err error) {
// Generate the key pair
privateKeyObj, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", "", err
}

privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKeyObj)
privateKeyPem := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: privateKeyBytes,
})

publicKeyObj, err := ssh.NewPublicKey(&privateKeyObj.PublicKey)
if err != nil {
return "", "", err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKeyObj)

return string(privateKeyPem), string(publicKeyBytes), nil
}
47 changes: 29 additions & 18 deletions pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,57 @@ import (
"github.com/enclave-ci/aerie/internal/ssh"
)

func FindSSHKey(keyPath string, isRoot bool) ([]byte, error) {
// sshKeyPath is used only for testing purposes
var sshKeyPath string

// FindSSHKey looks for an SSH key in the given path or in default locations
func FindSSHKey(keyPath string) ([]byte, error) {
if keyPath != "" {
return os.ReadFile(keyPath)
}

// Try to find key in .ssh directory
home, err := os.UserHomeDir()
sshDir, err := getSSHDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %v", err)
return nil, err
}

sshDir := filepath.Join(home, ".ssh")
keyNames := []string{"id_rsa", "id_ecdsa", "id_ed25519"}

for _, name := range keyNames {
path := filepath.Join(sshDir, name)
if _, err := os.Stat(path); err == nil {
key, err := os.ReadFile(path)
if err == nil {
return key, nil
}
key, err := os.ReadFile(path)
if err == nil {
return key, nil
}
}

if isRoot {
return nil, fmt.Errorf("no suitable SSH key found in .ssh directory")
}
return nil, nil
return nil, fmt.Errorf("no suitable SSH key found in %s", sshDir)
}

// FindKeyAndConnectWithUser finds an SSH key and establishes a connection
func FindKeyAndConnectWithUser(host, user, keyPath string) (*ssh.Client, []byte, error) {
key, err := FindSSHKey(keyPath, true)
key, err := FindSSHKey(keyPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to find SSH key: %v", err)
return nil, nil, fmt.Errorf("failed to find SSH key: %w", err)
}

client, err := ssh.ConnectWithUser(host, user, key)
if err != nil {
return nil, nil, fmt.Errorf("failed to establish SSH connection: %v", err)
return nil, nil, fmt.Errorf("failed to establish SSH connection: %w", err)
}

return client, key, nil
}

// getSSHDir returns the SSH directory path
func getSSHDir() (string, error) {
if sshKeyPath != "" {
return sshKeyPath, nil
}

home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("failed to get home directory: %w", err)
}

return filepath.Join(home, ".ssh"), nil
}
51 changes: 51 additions & 0 deletions pkg/utils/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package utils

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func TestFindSSHKey(t *testing.T) {
// Create temporary SSH keys in a temp directory
tempDir := t.TempDir()
sshKeyPath = tempDir

keyContent := []byte("test-key")
keyNames := []string{"id_rsa", "id_ecdsa", "id_ed25519"}

// Write test keys
for _, name := range keyNames {
keyPath := filepath.Join(tempDir, name)
err := os.WriteFile(keyPath, keyContent, 0600)
assert.NoError(t, err)
}

// Override the home directory to point to tempDir
originalHome := os.Getenv("HOME")
defer os.Setenv("HOME", originalHome)

_ = os.Setenv("HOME", tempDir)

// Test with no keyPath, should find id_rsa first
key, err := FindSSHKey("")
assert.NoError(t, err)
assert.Equal(t, keyContent, key)

// Test with specified keyPath
specifiedKeyPath := filepath.Join(tempDir, "custom_key")
err = os.WriteFile(specifiedKeyPath, keyContent, 0600)
assert.NoError(t, err)

key, err = FindSSHKey(specifiedKeyPath)
assert.NoError(t, err)
assert.Equal(t, keyContent, key)

// Test when no keys are found
_ = os.RemoveAll(tempDir)
key, err = FindSSHKey("")
assert.Error(t, err)
assert.Nil(t, key)
}

0 comments on commit f9dc2ab

Please sign in to comment.