From 19956d782f4833d1dc3485ebc7aab455713511f1 Mon Sep 17 00:00:00 2001 From: Yar Kravtsov Date: Wed, 4 Dec 2024 08:00:03 +0200 Subject: [PATCH] feat: Add SSH tunneling functionality for remote dependencies --- README.md | 52 +++++++++++++ cmd/tunnels.go | 139 +++++++++++++++++++++++++++++++++++ pkg/config/config.go | 2 + pkg/deployment/deployment.go | 13 +++- pkg/ssh/ssh.go | 115 ++++++++++++++++++++++------- 5 files changed, 289 insertions(+), 32 deletions(-) create mode 100644 cmd/tunnels.go diff --git a/README.md b/README.md index 32cceb7..a3af553 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ FTL is a deployment tool that reduces complexity for projects that don't require - Integrated Nginx reverse proxy - Multi-provider support (Hetzner, DigitalOcean, Linode, custom servers) - Fetch and stream logs from deployed services +- Establish SSH tunnels to remote dependencies ## Installation @@ -156,6 +157,48 @@ ftl logs [service] [flags] ftl logs -n 50 ``` +### 5. Create SSH Tunnels + +Establish SSH tunnels for your dependencies, allowing local access to services running on your server: + +```bash +ftl tunnels [flags] +``` + +This command will: + +- Connect to your server via SSH +- Forward local ports to remote ports for all dependencies defined in your configuration +- Allow you to interact with your dependencies locally as if they were running on your machine + +#### Flags + +- `-s`, `--server`: (Optional) Specify the server name or index to connect to, if multiple servers are defined. + +#### Examples + +- Establish tunnels to all dependency ports: + + ```bash + ftl tunnels + ``` + +- Specify a server to connect to (if multiple servers are configured): + + ```bash + ftl tunnels --server my-project.example.com + ``` + +Press `Ctrl+C` to terminate the tunnels when you're done. + +#### Purpose + +The `ftl tunnels` command is useful for: + +- Accessing dependency services (e.g., databases) running on your server from your local machine +- Simplifying local development by connecting to remote services without modifying your code +- Testing and debugging your application against live dependencies + ## How It Works FTL manages deployments and log retrieval through these main components: @@ -183,6 +226,13 @@ FTL manages deployments and log retrieval through these main components: - Supports real-time streaming with the `-f` flag - Allows limiting the number of log lines with the `-n` flag +### SSH Tunnels (`ftl tunnels`) + +- Connects to your server via SSH +- Establishes port forwarding from local ports to remote ports for all defined dependencies +- Maintains active tunnels with keep-alive packets +- Allows for graceful shutdown upon user interruption (Ctrl+C) + ## Use Cases ### Suitable For @@ -241,6 +291,7 @@ dependencies: volumes: [string] # Volume mappings (format: "volume:path") env: # Environment variables - KEY=value + ports: [int] # Ports to expose for SSH tunneling volumes: [string] # Named volumes list ``` @@ -266,6 +317,7 @@ FTL supports two forms of environment variable substitution in the configuration - **Environment Variables**: Set environment variables for services and dependencies, with support for environment variable substitution. - **Service Dependencies**: Specify dependent services and their configurations. - **Routing Rules**: Define custom routing paths and whether to strip prefixes. +- **SSH Tunnels**: Specify ports in dependencies to enable SSH tunneling for local access. ## Example Projects diff --git a/cmd/tunnels.go b/cmd/tunnels.go new file mode 100644 index 0000000..db6e891 --- /dev/null +++ b/cmd/tunnels.go @@ -0,0 +1,139 @@ +package cmd + +import ( + "context" + "fmt" + "github.com/yarlson/ftl/pkg/config" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/spf13/cobra" + + "github.com/yarlson/ftl/pkg/console" + "github.com/yarlson/ftl/pkg/ssh" +) + +var tunnelsCmd = &cobra.Command{ + Use: "tunnels", + Short: "Create SSH tunnels for dependencies", + Long: `Create SSH tunnels for all dependencies defined in ftl.yaml, +forwarding local ports to remote ports.`, + Run: runTunnels, +} + +func init() { + rootCmd.AddCommand(tunnelsCmd) + + tunnelsCmd.Flags().StringP("server", "s", "", "Server name or index to connect to") +} + +func runTunnels(cmd *cobra.Command, args []string) { + sm := console.NewSpinnerManager() + sm.Start() + defer sm.Stop() + + spinner := sm.AddSpinner("tunnels", "Establishing SSH tunnels") + + cfg, err := parseConfig("ftl.yaml") + if err != nil { + spinner.ErrorWithMessagef("Failed to parse config file: %v", err) + return + } + + serverName, _ := cmd.Flags().GetString("server") + serverConfig, err := selectServer(cfg, serverName) + if err != nil { + spinner.ErrorWithMessagef("Server selection failed: %v", err) + return + } + + user := serverConfig.User + + tunnels, err := collectDependencyTunnels(cfg) + if err != nil { + spinner.ErrorWithMessagef("Failed to collect dependencies: %v", err) + return + } + if len(tunnels) == 0 { + spinner.ErrorWithMessage("No dependencies with ports found in the configuration.") + return + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + errorChan := make(chan error, len(tunnels)) + + for _, tunnel := range tunnels { + wg.Add(1) + go func(t TunnelConfig) { + defer wg.Done() + err := ssh.CreateSSHTunnel(ctx, serverConfig.Host, serverConfig.Port, user, serverConfig.SSHKey, t.LocalPort, t.RemoteAddr) + if err != nil { + errorChan <- fmt.Errorf("Tunnel %s -> %s failed: %v", t.LocalPort, t.RemoteAddr, err) + } + }(tunnel) + } + + go func() { + wg.Wait() + close(errorChan) + }() + + select { + case err := <-errorChan: + spinner.ErrorWithMessagef("Failed to establish tunnels: %v", err) + return + case <-time.After(2 * time.Second): + spinner.Complete() + } + + sm.Stop() + + console.Success("SSH tunnels established. Press Ctrl+C to exit.") + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + <-sigs + + console.Info("Shutting down tunnels...") + cancel() + time.Sleep(1 * time.Second) +} + +func selectServer(cfg *config.Config, serverName string) (config.Server, error) { + if serverName != "" { + for _, srv := range cfg.Servers { + if srv.Host == serverName || srv.User == serverName { + return srv, nil + } + } + return config.Server{}, fmt.Errorf("server not found in configuration: %s", serverName) + } else if len(cfg.Servers) == 1 { + return cfg.Servers[0], nil + } else { + return config.Server{}, fmt.Errorf("multiple servers defined. Please specify a server using the --server flag") + } +} + +type TunnelConfig struct { + LocalPort string + RemoteAddr string +} + +func collectDependencyTunnels(cfg *config.Config) ([]TunnelConfig, error) { + var tunnels []TunnelConfig + for _, dep := range cfg.Dependencies { + for _, port := range dep.Ports { + tunnels = append(tunnels, TunnelConfig{ + LocalPort: fmt.Sprintf("%d", port), + RemoteAddr: fmt.Sprintf("localhost:%d", port), + }) + } + } + return tunnels, nil +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 9fe67cc..e6327a8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -54,6 +54,7 @@ type Service struct { Env []string `yaml:"env"` Forwards []string `yaml:"forwards"` Recreate bool `yaml:"recreate"` + LocalPorts []int } type HealthCheck struct { @@ -73,6 +74,7 @@ type Dependency struct { Image string `yaml:"image" validate:"required"` Volumes []string `yaml:"volumes" validate:"dive,volume_reference"` Env []string `yaml:"env" validate:"dive"` + Ports []int `yaml:"ports" validate:"dive,min=1,max=65535"` } type Volume struct { diff --git a/pkg/deployment/deployment.go b/pkg/deployment/deployment.go index 7d206e2..5cace5b 100644 --- a/pkg/deployment/deployment.go +++ b/pkg/deployment/deployment.go @@ -240,10 +240,11 @@ func (d *Deployment) startProxy(ctx context.Context, project string, cfg *config func (d *Deployment) startDependency(project string, dependency *config.Dependency) error { service := &config.Service{ - Name: dependency.Name, - Image: dependency.Image, - Volumes: dependency.Volumes, - Env: dependency.Env, + Name: dependency.Name, + Image: dependency.Image, + Volumes: dependency.Volumes, + Env: dependency.Env, + LocalPorts: dependency.Ports, } if err := d.deployService(project, service); err != nil { return fmt.Errorf("failed to start container for %s: %v", dependency.Image, err) @@ -408,6 +409,10 @@ func (d *Deployment) createContainer(project string, service *config.Service, su args = append(args, "--health-timeout", fmt.Sprintf("%ds", int(service.HealthCheck.Timeout.Seconds()))) } + for _, port := range service.LocalPorts { + args = append(args, "-p", fmt.Sprintf("127.0.0.1:%d:%d", port, port)) + } + if len(service.Forwards) > 0 { for _, forward := range service.Forwards { args = append(args, "-p", forward) diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index f2f4380..7f60303 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -1,12 +1,14 @@ package ssh import ( + "context" "fmt" "io" "net" "os" "path/filepath" "strings" + "sync" "time" "golang.org/x/crypto/ssh" @@ -131,49 +133,106 @@ func getSSHDir() (string, error) { // CreateSSHTunnel establishes an SSH tunnel from a local port to a remote address through an SSH server. // It listens on localPort and forwards connections to remoteAddr via the SSH server at host:port. -// Authentication is done using the provided user and key (private key bytes). -func CreateSSHTunnel(host string, port int, user string, key []byte, localPort string, remoteAddr string) error { - client, err := NewSSHClientWithKey(host, port, user, key) +// Authentication is done using the provided user and keyPath (path to the private key file). +func CreateSSHTunnel(ctx context.Context, host string, port int, user, keyPath, localPort string, remoteAddr string) error { + client, _, err := FindKeyAndConnectWithUser(host, port, user, keyPath) if err != nil { return fmt.Errorf("failed to establish SSH connection: %v", err) } + defer client.Close() + + // Start keep-alive routine + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) + if err != nil { + fmt.Printf("Failed to send keep-alive packet: %v\n", err) + return + } + case <-ctx.Done(): + return + } + } + }() localListener, err := net.Listen("tcp", "localhost:"+localPort) if err != nil { return fmt.Errorf("failed to listen on local port %s: %v", localPort, err) } defer localListener.Close() - fmt.Printf("Listening on localhost:%s, forwarding to %s via %s:%d\n", localPort, remoteAddr, host, port) for { - localConn, err := localListener.Accept() - if err != nil { - fmt.Printf("Failed to accept local connection: %v\n", err) - continue - } + select { + case <-ctx.Done(): + return nil + default: + localConn, err := localListener.Accept() + if err != nil { + fmt.Printf("Failed to accept local connection: %v\n", err) + continue + } + + remoteConn, err := client.Dial("tcp", remoteAddr) + if err != nil { + fmt.Printf("Failed to dial remote address %s: %v\n", remoteAddr, err) + localConn.Close() + continue + } - remoteConn, err := client.Dial("tcp", remoteAddr) - if err != nil { - fmt.Printf("Failed to dial remote address %s: %v\n", remoteAddr, err) - localConn.Close() - continue + // Handle the connection in a separate goroutine + go handleConnection(localConn, remoteConn) } + } +} - go func() { - defer localConn.Close() - defer remoteConn.Close() +// handleConnection copies data between local and remote connections +func handleConnection(localConn, remoteConn net.Conn) { + defer localConn.Close() + defer remoteConn.Close() + + // Use WaitGroup to wait for both directions to finish + var wg sync.WaitGroup + wg.Add(2) + + // Copy from local to remote + go func() { + defer wg.Done() + _, err := io.Copy(remoteConn, localConn) + if err != nil && !isClosedNetworkError(err) { + fmt.Printf("Error copying from local to remote: %v\n", err) + } + }() + + // Copy from remote to local + go func() { + defer wg.Done() + _, err := io.Copy(localConn, remoteConn) + if err != nil && !isClosedNetworkError(err) { + fmt.Printf("Error copying from remote to local: %v\n", err) + } + }() - go func() { - _, err := io.Copy(remoteConn, localConn) - if err != nil { - fmt.Printf("Error copying from local to remote: %v\n", err) - } - }() + // Wait for both copying goroutines to finish + wg.Wait() +} - _, err := io.Copy(localConn, remoteConn) - if err != nil { - fmt.Printf("Error copying from remote to local: %v\n", err) - } - }() +// isClosedNetworkError checks if the error is due to closed network connection +func isClosedNetworkError(err error) bool { + if err == nil { + return false + } + if err == io.EOF { + return true + } + if netErr, ok := err.(*net.OpError); ok && netErr.Err.Error() == "use of closed network connection" { + return true + } + if strings.Contains(err.Error(), "use of closed network connection") { + return true } + return false }