Skip to content

Commit

Permalink
refactor: Integrate SSH tunnel functionality into SSH package
Browse files Browse the repository at this point in the history
  • Loading branch information
yarlson committed Dec 4, 2024
1 parent b7b9073 commit 305a4d9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 553 deletions.
66 changes: 64 additions & 2 deletions pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package ssh

import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
Expand All @@ -26,11 +28,22 @@ func NewSSHClientWithKey(host string, port int, user string, key []byte) (*ssh.C

addr := fmt.Sprintf("%s:%d", host, port)

client, err := ssh.Dial("tcp", addr, config)
conn, err := net.DialTimeout("tcp", addr, config.Timeout)
if err != nil {
return nil, fmt.Errorf("failed to connect: %v", err)
return nil, fmt.Errorf("failed to dial TCP connection: %v", err)
}

if tcpConn, ok := conn.(*net.TCPConn); ok {
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
}

sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
return nil, fmt.Errorf("failed to establish SSH connection: %v", err)
}

client := ssh.NewClient(sshConn, chans, reqs)
return client, nil
}

Expand Down Expand Up @@ -115,3 +128,52 @@ func getSSHDir() (string, error) {

return filepath.Join(home, ".ssh"), nil
}

// CreateSSHTunnel establishes an SSH tunnel from a local port to a remote address through an SSH server.
// It listens on localPort and forwards connections to remoteAddr via the SSH server at host:port.
// Authentication is done using the provided user and key (private key bytes).
func CreateSSHTunnel(host string, port int, user string, key []byte, localPort string, remoteAddr string) error {
client, err := NewSSHClientWithKey(host, port, user, key)
if err != nil {
return fmt.Errorf("failed to establish SSH connection: %v", err)
}

localListener, err := net.Listen("tcp", "localhost:"+localPort)
if err != nil {
return fmt.Errorf("failed to listen on local port %s: %v", localPort, err)
}
defer localListener.Close()
fmt.Printf("Listening on localhost:%s, forwarding to %s via %s:%d\n", localPort, remoteAddr, host, port)

for {
localConn, err := localListener.Accept()
if err != nil {
fmt.Printf("Failed to accept local connection: %v\n", err)
continue
}

remoteConn, err := client.Dial("tcp", remoteAddr)
if err != nil {
fmt.Printf("Failed to dial remote address %s: %v\n", remoteAddr, err)
localConn.Close()
continue
}

go func() {
defer localConn.Close()
defer remoteConn.Close()

go func() {
_, err := io.Copy(remoteConn, localConn)
if err != nil {
fmt.Printf("Error copying from local to remote: %v\n", err)
}
}()

_, err := io.Copy(localConn, remoteConn)
if err != nil {
fmt.Printf("Error copying from remote to local: %v\n", err)
}
}()
}
}
305 changes: 0 additions & 305 deletions pkg/tunnel/tunnel.go

This file was deleted.

Loading

0 comments on commit 305a4d9

Please sign in to comment.