From 305a4d9c2767c32db8b7a737456db85a891d8618 Mon Sep 17 00:00:00 2001 From: Yar Kravtsov Date: Wed, 4 Dec 2024 06:57:26 +0200 Subject: [PATCH] refactor: Integrate SSH tunnel functionality into SSH package --- pkg/ssh/ssh.go | 66 ++++++++- pkg/tunnel/tunnel.go | 305 -------------------------------------- pkg/tunnel/tunnel_test.go | 246 ------------------------------ 3 files changed, 64 insertions(+), 553 deletions(-) delete mode 100644 pkg/tunnel/tunnel.go delete mode 100644 pkg/tunnel/tunnel_test.go diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index ddd5976..f2f4380 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -2,6 +2,8 @@ package ssh import ( "fmt" + "io" + "net" "os" "path/filepath" "strings" @@ -26,11 +28,22 @@ func NewSSHClientWithKey(host string, port int, user string, key []byte) (*ssh.C addr := fmt.Sprintf("%s:%d", host, port) - client, err := ssh.Dial("tcp", addr, config) + conn, err := net.DialTimeout("tcp", addr, config.Timeout) if err != nil { - return nil, fmt.Errorf("failed to connect: %v", err) + return nil, fmt.Errorf("failed to dial TCP connection: %v", err) + } + + if tcpConn, ok := conn.(*net.TCPConn); ok { + _ = tcpConn.SetKeepAlive(true) + _ = tcpConn.SetKeepAlivePeriod(30 * time.Second) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + return nil, fmt.Errorf("failed to establish SSH connection: %v", err) } + client := ssh.NewClient(sshConn, chans, reqs) return client, nil } @@ -115,3 +128,52 @@ func getSSHDir() (string, error) { return filepath.Join(home, ".ssh"), nil } + +// CreateSSHTunnel establishes an SSH tunnel from a local port to a remote address through an SSH server. +// It listens on localPort and forwards connections to remoteAddr via the SSH server at host:port. +// Authentication is done using the provided user and key (private key bytes). +func CreateSSHTunnel(host string, port int, user string, key []byte, localPort string, remoteAddr string) error { + client, err := NewSSHClientWithKey(host, port, user, key) + if err != nil { + return fmt.Errorf("failed to establish SSH connection: %v", err) + } + + localListener, err := net.Listen("tcp", "localhost:"+localPort) + if err != nil { + return fmt.Errorf("failed to listen on local port %s: %v", localPort, err) + } + defer localListener.Close() + fmt.Printf("Listening on localhost:%s, forwarding to %s via %s:%d\n", localPort, remoteAddr, host, port) + + for { + localConn, err := localListener.Accept() + if err != nil { + fmt.Printf("Failed to accept local connection: %v\n", err) + continue + } + + remoteConn, err := client.Dial("tcp", remoteAddr) + if err != nil { + fmt.Printf("Failed to dial remote address %s: %v\n", remoteAddr, err) + localConn.Close() + continue + } + + go func() { + defer localConn.Close() + defer remoteConn.Close() + + go func() { + _, err := io.Copy(remoteConn, localConn) + if err != nil { + fmt.Printf("Error copying from local to remote: %v\n", err) + } + }() + + _, err := io.Copy(localConn, remoteConn) + if err != nil { + fmt.Printf("Error copying from remote to local: %v\n", err) + } + }() + } +} diff --git a/pkg/tunnel/tunnel.go b/pkg/tunnel/tunnel.go deleted file mode 100644 index b41a60c..0000000 --- a/pkg/tunnel/tunnel.go +++ /dev/null @@ -1,305 +0,0 @@ -package tunnel - -import ( - "bytes" - "context" - "fmt" - "io" - "net" - "os" - "path/filepath" - "strings" - - "github.com/bramvdbogaerde/go-scp" - "golang.org/x/crypto/ssh" - - ftlssh "github.com/yarlson/ftl/pkg/ssh" -) - -type Tunnel struct { - sshClient *ssh.Client -} - -func NewTunnel(sshClient *ssh.Client) *Tunnel { - return &Tunnel{ - sshClient: sshClient, - } -} - -func (c *Tunnel) Close() error { - if c.sshClient == nil { - return nil - } - - err := c.sshClient.Close() - c.sshClient = nil - return err -} - -func (c *Tunnel) RunCommand(ctx context.Context, command string, args ...string) (io.Reader, error) { - session, err := c.sshClient.NewSession() - if err != nil { - return nil, fmt.Errorf("unable to create session: %v", err) - } - defer session.Close() - - fullCommand := command - if len(args) > 0 { - var quotedArgs []string - for _, arg := range args { - quotedArgs = append(quotedArgs, fmt.Sprintf("%q", arg)) - } - fullCommand += " " + strings.Join(quotedArgs, " ") - fullCommand = strings.Replace(fullCommand, "\\n", "\n", -1) - } - - pr, pw := io.Pipe() - - session.Stdout = pw - session.Stderr = pw - - if err := session.Start(fullCommand); err != nil { - _ = pw.Close() - return nil, fmt.Errorf("failed to start command: %w", err) - } - - done := make(chan error, 1) - go func() { - done <- session.Wait() - _ = pw.Close() - }() - - var output bytes.Buffer - outputChan := make(chan struct{}) - - go func() { - _, _ = io.Copy(&output, pr) - close(outputChan) - }() - - var commandErr error - select { - case <-ctx.Done(): - _ = session.Signal(ssh.SIGTERM) - commandErr = ctx.Err() - case err := <-done: - if err != nil { - commandErr = fmt.Errorf("command failed: %w", err) - } - } - - <-outputChan - - if commandErr != nil { - return bytes.NewReader(output.Bytes()), commandErr - } - - return bytes.NewReader(output.Bytes()), nil -} - -func (c *Tunnel) CopyFile(ctx context.Context, src, dst string) error { - client, err := scp.NewClientBySSH(c.sshClient) - if err != nil { - return fmt.Errorf("failed to create SCP client: %w", err) - } - - file, err := os.Open(src) - if err != nil { - return fmt.Errorf("failed to open file: %w", err) - } - - return client.CopyFile(ctx, file, dst, "0644") -} - -func (c *Tunnel) RunCommands(ctx context.Context, commands []string) error { - for _, command := range commands { - if err := c.runSingleCommand(ctx, command); err != nil { - return err - } - } - return nil -} - -func (c *Tunnel) runSingleCommand(ctx context.Context, command string) error { - session, err := c.sshClient.NewSession() - if err != nil { - return fmt.Errorf("unable to create session: %w", err) - } - defer session.Close() - - pr, pw := io.Pipe() - defer pr.Close() - - session.Stdout = pw - session.Stderr = pw - - if err := session.Start(command); err != nil { - return fmt.Errorf("failed to start command: %w", err) - } - - done := make(chan error, 1) - go func() { - done <- session.Wait() - _ = pw.Close() - }() - - var output strings.Builder - - go func() { - _, _ = io.Copy(&output, pr) - }() - - select { - case <-ctx.Done(): - _ = session.Signal(ssh.SIGTERM) - return ctx.Err() - case err := <-done: - if err != nil { - return fmt.Errorf("%w\nOutput: %s", err, output.String()) - } - return nil - } -} - -func (c *Tunnel) RunCommandWithOutput(command string) (string, error) { - session, err := c.sshClient.NewSession() - if err != nil { - return "", err - } - defer session.Close() - - output, err := session.CombinedOutput(command) - if err != nil { - return "", fmt.Errorf("command failed: %v\nOutput: %s", err, string(output)) - } - - return string(output), nil -} - -// sshKeyPath is used only for testing purposes -var sshKeyPath string - -// FindSSHKey looks for an SSH key in the given path or in default locations -func FindSSHKey(keyPath string) ([]byte, error) { - if keyPath != "" { - if strings.HasPrefix(keyPath, "~") { - home, err := os.UserHomeDir() - if err != nil { - return nil, fmt.Errorf("failed to get home directory: %w", err) - } - keyPath = filepath.Join(home, keyPath[1:]) - } - - return os.ReadFile(keyPath) - } - - sshDir, err := getSSHDir() - if err != nil { - return nil, err - } - - keyNames := []string{"id_rsa", "id_ecdsa", "id_ed25519"} - for _, name := range keyNames { - path := filepath.Join(sshDir, name) - key, err := os.ReadFile(path) - if err == nil { - return key, nil - } - } - - return nil, fmt.Errorf("no suitable SSH key found in %s", sshDir) -} - -// FindKeyAndConnectWithUser finds an SSH key and establishes a connection -func FindKeyAndConnectWithUser(host string, port int, user, keyPath string) (*ssh.Client, []byte, error) { - key, err := FindSSHKey(keyPath) - if err != nil { - return nil, nil, fmt.Errorf("failed to find SSH key: %w", err) - } - - client, err := ftlssh.NewSSHClientWithKey(host, port, user, key) - if err != nil { - return nil, nil, fmt.Errorf("failed to establish SSH connection: %w", err) - } - - return client, key, nil -} - -// getSSHDir returns the SSH directory path -func getSSHDir() (string, error) { - if sshKeyPath != "" { - return sshKeyPath, nil - } - - home, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - - return filepath.Join(home, ".ssh"), nil -} - -func (c *Tunnel) CreateTunnel(ctx context.Context, localPort, remotePort int) error { - listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", localPort)) - if err != nil { - return fmt.Errorf("failed to start local listener: %w", err) - } - defer listener.Close() - - errChan := make(chan error, 1) - - go func() { - for { - local, err := listener.Accept() - if err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - errChan <- fmt.Errorf("failed to accept connection: %w", err) - } - return - } - - go func(localConn net.Conn) { - defer localConn.Close() - - remoteConn, err := c.sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort)) - if err != nil { - fmt.Printf("Failed to connect to remote port: %v\n", err) - return - } - defer remoteConn.Close() - - copyErrChan := make(chan error, 2) - doneChan := make(chan bool, 2) - - go func() { - _, err := io.Copy(localConn, remoteConn) - copyErrChan <- err - doneChan <- true - }() - - go func() { - _, err := io.Copy(remoteConn, localConn) - copyErrChan <- err - doneChan <- true - }() - - select { - case err := <-copyErrChan: - if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - fmt.Printf("Copy operation failed: %v\n", err) - } - <-doneChan - case <-ctx.Done(): - return - } - }(local) - } - }() - - select { - case err := <-errChan: - return err - case <-ctx.Done(): - return ctx.Err() - } -} diff --git a/pkg/tunnel/tunnel_test.go b/pkg/tunnel/tunnel_test.go deleted file mode 100644 index af1f171..0000000 --- a/pkg/tunnel/tunnel_test.go +++ /dev/null @@ -1,246 +0,0 @@ -package tunnel - -import ( - "context" - "crypto/rand" - "encoding/pem" - "fmt" - "io" - "net" - "net/http" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/docker/docker/api/types" - "github.com/docker/docker/pkg/archive" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/wait" - "golang.org/x/crypto/ed25519" - "golang.org/x/crypto/ssh" - - ftlssh "github.com/yarlson/ftl/pkg/ssh" -) - -func TestFindSSHKey(t *testing.T) { - // Create temporary SSH keys in a temp directory - tempDir := t.TempDir() - sshKeyPath = tempDir - - keyContent := []byte("test-key") - keyNames := []string{"id_rsa", "id_ecdsa", "id_ed25519"} - - // Write test keys - for _, name := range keyNames { - keyPath := filepath.Join(tempDir, name) - err := os.WriteFile(keyPath, keyContent, 0600) - assert.NoError(t, err) - } - - // Override the home directory to point to tempDir - originalHome := os.Getenv("HOME") - defer os.Setenv("HOME", originalHome) - - _ = os.Setenv("HOME", tempDir) - - // Test with no keyPath, should find id_rsa first - key, err := FindSSHKey("") - assert.NoError(t, err) - assert.Equal(t, keyContent, key) - - // Test with specified keyPath - specifiedKeyPath := filepath.Join(tempDir, "custom_key") - err = os.WriteFile(specifiedKeyPath, keyContent, 0600) - assert.NoError(t, err) - - key, err = FindSSHKey(specifiedKeyPath) - assert.NoError(t, err) - assert.Equal(t, keyContent, key) - - // Test when no keys are found - _ = os.RemoveAll(tempDir) - key, err = FindSSHKey("") - assert.Error(t, err) - assert.Nil(t, key) -} - -func TestRunner_CreateTunnel(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - - // Setup - tempDir := t.TempDir() - privateKeyPath, publicKeyPath := filepath.Join(tempDir, "id_rsa"), filepath.Join(tempDir, "id_rsa.pub") - require.NoError(t, generateSSHKeyPair(privateKeyPath, publicKeyPath)) - - publicKeyBytes, err := os.ReadFile(publicKeyPath) - require.NoError(t, err) - - imageTag := "ssh-nginx-test:latest" - dockerfilePath := createDockerfile(t, strings.TrimSpace(string(publicKeyBytes))) - buildImage(t, ctx, dockerfilePath, imageTag) - - // Start container - container := startContainer(t, ctx, imageTag) - defer func() { require.NoError(t, container.Terminate(ctx)) }() - - sshPort, err := container.MappedPort(ctx, "22") - require.NoError(t, err) - - // Create SSH client - key, err := os.ReadFile(privateKeyPath) - require.NoError(t, err) - - client, err := ftlssh.NewSSHClientWithKey("localhost", sshPort.Int(), "root", key) - require.NoError(t, err) - defer client.Close() - - tunnel := NewTunnel(client) - - // Test - localPort, remotePort := 23451, 80 - tunnelCtx, tunnelCancel := context.WithCancel(ctx) - defer tunnelCancel() - - errCh := make(chan error, 1) - go func() { - errCh <- tunnel.CreateTunnel(tunnelCtx, localPort, remotePort) - }() - - require.NoError(t, waitForPort(t, localPort, 5*time.Second)) - - resp, err := makeHTTPRequest(localPort) - require.NoError(t, err) - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Contains(t, string(body), "Welcome to nginx!") - - // Cleanup - tunnelCancel() - select { - case err := <-errCh: - assert.ErrorIs(t, err, context.Canceled) - case <-time.After(5 * time.Second): - t.Error("Tunnel didn't close within timeout") - } -} - -func waitForPort(t *testing.T, port int, timeout time.Duration) error { - t.Helper() - deadline := time.Now().Add(timeout) - for time.Now().Before(deadline) { - conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second) - if err == nil { - conn.Close() - return nil - } - time.Sleep(100 * time.Millisecond) - } - return fmt.Errorf("port %d not available after %s", port, timeout) -} - -func startContainer(t *testing.T, ctx context.Context, imageTag string) testcontainers.Container { - t.Helper() - req := testcontainers.ContainerRequest{ - Image: imageTag, - ExposedPorts: []string{"22/tcp", "80/tcp"}, - WaitingFor: wait.ForListeningPort("22/tcp").WithStartupTimeout(time.Minute), - } - - container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ - ContainerRequest: req, - Started: true, - }) - require.NoError(t, err) - return container -} - -func makeHTTPRequest(port int) (*http.Response, error) { - httpClient := &http.Client{Timeout: 5 * time.Second} - return httpClient.Get(fmt.Sprintf("http://localhost:%d", port)) -} - -func generateSSHKeyPair(privateKeyPath, publicKeyPath string) error { - publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return err - } - - privateKeyBytes, err := ssh.MarshalPrivateKey(privateKey, "") - if err != nil { - return err - } - - privateKeyPEM := &pem.Block{ - Type: "OPENSSH PRIVATE KEY", - Bytes: privateKeyBytes.Bytes, - } - - privateKeyPEMBytes := pem.EncodeToMemory(privateKeyPEM) - - err = os.WriteFile(privateKeyPath, privateKeyPEMBytes, 0600) - if err != nil { - return err - } - - publicKeySSH, err := ssh.NewPublicKey(publicKey) - if err != nil { - return err - } - - return os.WriteFile(publicKeyPath, ssh.MarshalAuthorizedKey(publicKeySSH), 0644) -} - -func createDockerfile(t *testing.T, publicKey string) string { - content := `FROM ubuntu:20.04 - -RUN apt-get update && apt-get install -y openssh-server nginx - -RUN mkdir /root/.ssh && \ - echo "%s" > /root/.ssh/authorized_keys && \ - chmod 700 /root/.ssh && \ - chmod 600 /root/.ssh/authorized_keys - -RUN mkdir /run/sshd - -EXPOSE 22 80 - -CMD service ssh start && nginx -g 'daemon off;' -` - - dockerfilePath := filepath.Join(t.TempDir(), "Dockerfile") - err := os.WriteFile(dockerfilePath, []byte(fmt.Sprintf(content, publicKey)), 0644) - require.NoError(t, err) - return dockerfilePath -} - -func buildImage(t *testing.T, ctx context.Context, dockerfilePath, imageTag string) { - cli, err := testcontainers.NewDockerClientWithOpts(ctx) - require.NoError(t, err) - - tar, err := archive.TarWithOptions(filepath.Dir(dockerfilePath), &archive.TarOptions{}) - require.NoError(t, err) - - opts := types.ImageBuildOptions{ - Dockerfile: filepath.Base(dockerfilePath), - Tags: []string{imageTag}, - Remove: true, - } - - resp, err := cli.ImageBuild(ctx, tar, opts) - require.NoError(t, err) - defer resp.Body.Close() - - _, err = io.Copy(io.Discard, resp.Body) - require.NoError(t, err) -}