diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a33e85d22..e75b01876 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -60,7 +60,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ macos-latest, ubuntu-latest] + os: [ macos-latest, ubuntu-latest, windows-latest ] python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] steps: - uses: actions/checkout@v4 @@ -75,7 +75,8 @@ jobs: with: name: frontend-build path: src/dstack/_internal/server/statics - - name: Run pytest + - name: Run pytest on POSIX + if: matrix.os != 'windows-latest' # Skip Postgres tests on macos since macos runner doesn't have Docker. # Skip Postgres tests for Python 3.8 since testcontainers<4 doesn't support asyncpg correctly. run: | @@ -84,6 +85,10 @@ jobs: RUNPOSTGRES="--runpostgres" fi pytest src/tests --runui $RUNPOSTGRES + - name: Run pytest on Windows + if: matrix.os == 'windows-latest' + run: | + pytest src/tests --runui --runpostgres update-get-dstack: if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 65de0fc76..a1c2c4d79 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -51,7 +51,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ macos-latest, ubuntu-latest] + os: [ macos-latest, ubuntu-latest, windows-latest ] python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] steps: - uses: actions/checkout@v4 @@ -66,7 +66,8 @@ jobs: with: name: frontend-build path: src/dstack/_internal/server/statics - - name: Run pytest + - name: Run pytest on POSIX + if: matrix.os != 'windows-latest' # Skip Postgres tests on macos since macos runner doesn't have Docker. # Skip Postgres tests for Python 3.8 since testcontainers<4 doesn't support asyncpg correctly. 
run: | @@ -75,6 +76,10 @@ jobs: RUNPOSTGRES="--runpostgres" fi pytest src/tests --runui $RUNPOSTGRES + - name: Run pytest on Windows + if: matrix.os == 'windows-latest' + run: | + pytest src/tests --runui --runpostgres runner-test: defaults: diff --git a/docs/docs/installation/index.md b/docs/docs/installation/index.md index 6cea105cc..90e2f329d 100644 --- a/docs/docs/installation/index.md +++ b/docs/docs/installation/index.md @@ -9,6 +9,18 @@ To use the open-source version of `dstack` with your own cloud accounts or on-pr > If you don't want to host the `dstack` server (or want to access GPU marketplace), > skip installation and proceed to [dstack Sky :material-arrow-top-right-thin:{ .external }](https://sky.dstack.ai){:target="_blank"}. +## Prerequisites + +`dstack` works on Linux, macOS, and Windows, with one exception — the `dstack server` functionality is not currently supported on Windows. + +`dstack` requires Git and an OpenSSH client to operate. + +On Windows, install [Git for Windows](https://git-scm.com/download/win), which contains both Git and OpenSSH. 
During the installation, +make sure the following options are checked: + +- _“Git from the command line and also from 3-rd party software”_ or _“Use Git and optional Unix tools from the Command Prompt”_ +- _“Use bundled OpenSSH”_ + ## Configure backends To use `dstack` with your own cloud accounts, or Kubernetes, diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index 72ffba656..4b492f660 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -18,7 +18,8 @@ from dstack._internal.cli.commands.volume import VolumeCommand from dstack._internal.cli.utils.common import _colors, console from dstack._internal.cli.utils.updates import check_for_updates -from dstack._internal.core.errors import ClientError, CLIError, ConfigurationError +from dstack._internal.core.errors import ClientError, CLIError, ConfigurationError, SSHError +from dstack._internal.core.services.ssh.client import get_ssh_client_info from dstack._internal.utils.logging import get_logger from dstack.version import __version__ as version @@ -72,8 +73,9 @@ def main(): args.unknown = unknown_args try: check_for_updates() + get_ssh_client_info() args.func(args) - except (ClientError, CLIError, ConfigurationError) as e: + except (ClientError, CLIError, ConfigurationError, SSHError) as e: console.print(f"[error]{escape(str(e))}[/]") logger.debug(e, exc_info=True) exit(1) diff --git a/src/dstack/_internal/compat.py b/src/dstack/_internal/compat.py new file mode 100644 index 000000000..4e17099eb --- /dev/null +++ b/src/dstack/_internal/compat.py @@ -0,0 +1,3 @@ +import os + +IS_WINDOWS = os.name == "nt" diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index 77cd83aeb..73f18cfb8 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -2,40 +2,64 @@ import re import subprocess import time -from typing import Optional, Tuple 
+from pathlib import Path +from typing import Optional +from dstack._internal.compat import IS_WINDOWS from dstack._internal.core.errors import SSHError from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.services.configs import ConfigManager +from dstack._internal.core.services.ssh.client import get_ssh_client_info from dstack._internal.core.services.ssh.ports import PortsLock -from dstack._internal.core.services.ssh.tunnel import ( - FilePath, - SSHTunnel, - ports_to_forwarded_sockets, +from dstack._internal.core.services.ssh.tunnel import SSHTunnel, ports_to_forwarded_sockets +from dstack._internal.utils.path import FilePath, PathLike +from dstack._internal.utils.ssh import ( + include_ssh_config, + normalize_path, + update_ssh_config, ) -from dstack._internal.utils.path import PathLike -from dstack._internal.utils.ssh import get_ssh_config, include_ssh_config, update_ssh_config class SSHAttach: - @staticmethod - def reuse_control_sock_path_and_port_locks(run_name: str) -> Optional[Tuple[str, PortsLock]]: - ssh_config_path = str(ConfigManager().dstack_ssh_config_path) - host_config = get_ssh_config(ssh_config_path, run_name) - if host_config and host_config.get("ControlPath"): + @classmethod + def get_control_sock_path(cls, run_name: str) -> Path: + return ConfigManager().dstack_ssh_dir / f"{run_name}.control.sock" + + @classmethod + def reuse_ports_lock(cls, run_name: str) -> Optional[PortsLock]: + if not get_ssh_client_info().supports_control_socket: + raise SSHError("Unsupported SSH client") + control_sock_path = normalize_path(cls.get_control_sock_path(run_name)) + filter_prefix: str + output: bytes + if IS_WINDOWS: + filter_prefix = "powershell" + output = subprocess.check_output( + [ + "powershell", + "-c", + f"""Get-CimInstance Win32_Process \ + -filter "commandline like '%-S {control_sock_path}%'" \ + | select -ExpandProperty CommandLine \ + """, + ] + ) + else: + filter_prefix = "grep" ps = 
subprocess.Popen(("ps", "-A", "-o", "command"), stdout=subprocess.PIPE) - control_sock_path = host_config.get("ControlPath") - output = subprocess.check_output(("grep", control_sock_path), stdin=ps.stdout) + output = subprocess.check_output( + ["grep", "--", f"-S {control_sock_path}"], stdin=ps.stdout + ) ps.wait() - commands = list( - filter(lambda s: not s.startswith("grep"), output.decode().strip().split("\n")) + commands = list( + filter(lambda s: not s.startswith(filter_prefix), output.decode().strip().split("\n")) + ) + if commands: + port_pattern = r"-L (?:[\w.-]+:)?(\d+):localhost:(\d+)" + matches = re.findall(port_pattern, commands[0]) + return PortsLock( + {int(target_port): int(local_port) for local_port, target_port in matches} ) - if commands: - port_pattern = r"-L (?:[\w.-]+:)?(\d+):localhost:(\d+)" - matches = re.findall(port_pattern, commands[0]) - return control_sock_path, PortsLock( - {int(target_port): int(local_port) for local_port, target_port in matches} - ) return None def __init__( @@ -48,17 +72,21 @@ def __init__( run_name: str, dockerized: bool, ssh_proxy: Optional[SSHConnectionParams] = None, - control_sock_path: Optional[str] = None, local_backend: bool = False, bind_address: Optional[str] = None, ): self._ports_lock = ports_lock self.ports = ports_lock.dict() self.run_name = run_name - self.ssh_config_path = str(ConfigManager().dstack_ssh_config_path) + self.ssh_config_path = ConfigManager().dstack_ssh_config_path + control_sock_path = self.get_control_sock_path(run_name) + # Cast all path-like values used in configs to FilePath instances for automatic + # path normalization in :func:`update_ssh_config`. 
+ self.control_sock_path = FilePath(control_sock_path) + self.identity_file = FilePath(id_rsa_path) self.tunnel = SSHTunnel( destination=run_name, - identity=FilePath(id_rsa_path), + identity=self.identity_file, forwarded_sockets=ports_to_forwarded_sockets( ports=self.ports, bind_local=bind_address or "localhost", @@ -72,7 +100,7 @@ def __init__( "HostName": hostname, "Port": ssh_port, "User": user, - "IdentityFile": id_rsa_path, + "IdentityFile": self.identity_file, "IdentitiesOnly": "yes", "StrictHostKeyChecking": "no", "UserKnownHostsFile": "/dev/null", @@ -82,7 +110,7 @@ def __init__( "HostName": ssh_proxy.hostname, "Port": ssh_proxy.port, "User": ssh_proxy.username, - "IdentityFile": id_rsa_path, + "IdentityFile": self.identity_file, "IdentitiesOnly": "yes", "StrictHostKeyChecking": "no", "UserKnownHostsFile": "/dev/null", @@ -92,13 +120,10 @@ def __init__( "HostName": "localhost", "Port": 10022, "User": "root", # TODO(#1535): support non-root images properly - "IdentityFile": id_rsa_path, + "IdentityFile": self.identity_file, "IdentitiesOnly": "yes", "StrictHostKeyChecking": "no", "UserKnownHostsFile": "/dev/null", - "ControlPath": self.tunnel.control_sock_path, - "ControlMaster": "auto", - "ControlPersist": "yes", "ProxyJump": f"{run_name}-host", } elif ssh_proxy is not None: @@ -106,17 +131,21 @@ def __init__( "HostName": hostname, "Port": ssh_port, "User": user, - "IdentityFile": id_rsa_path, + "IdentityFile": self.identity_file, "IdentitiesOnly": "yes", "StrictHostKeyChecking": "no", "UserKnownHostsFile": "/dev/null", - "ControlPath": self.tunnel.control_sock_path, - "ControlMaster": "auto", - "ControlPersist": "yes", "ProxyJump": f"{run_name}-jump-host", } else: self.container_config = None + if self.container_config is not None and get_ssh_client_info().supports_multiplexing: + self.container_config.update( + { + "ControlMaster": "auto", + "ControlPath": self.control_sock_path, + } + ) def attach(self): include_ssh_config(self.ssh_config_path) diff 
--git a/src/dstack/_internal/core/services/ssh/client.py b/src/dstack/_internal/core/services/ssh/client.py new file mode 100644 index 000000000..738201782 --- /dev/null +++ b/src/dstack/_internal/core/services/ssh/client.py @@ -0,0 +1,147 @@ +import os +import re +import shutil +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +from dstack._internal.compat import IS_WINDOWS +from dstack._internal.core.errors import SSHError +from dstack._internal.utils.path import PathLike + + +@dataclass +class SSHClientInfo: + # Path to `ssh` executable + path: Path + # Full version including portable suffix, e.g., "9.6p1" + version: str + # Base version not including portable suffix, e.g., (9, 6) + version_tuple: Tuple[int, ...] + # True if OpenSSH_for_Windows (Microsoft's OpenSSH Portable fork) + for_windows: bool + # Supports Control{Master,Path,Persist} directives, but only for control purposes + # (e.g., `ssh -O exit`), cannot be used for connection multiplexing + supports_control_socket: bool + # Supports Control{Master,Path,Persist} for connection multiplexing + supports_multiplexing: bool + # Supports ForkAfterAuthentication (`ssh -f`) + supports_background_mode: bool + + RAW_VERSION_REGEX = re.compile( + r"OpenSSH_(?P<for_windows>for_Windows_)?(?P<base_version>[\d.]+)(?P<portable_version>p\d+)?", + flags=re.I, + ) + + @classmethod + def from_raw_version(cls, raw_version: str, path: Path) -> "SSHClientInfo": + match = cls.RAW_VERSION_REGEX.match(raw_version) + if not match: + raise ValueError("no match") + for_windows, base_version, portable_version = match.group( + "for_windows", "base_version", "portable_version" + ) + if portable_version: + version = f"{base_version}{portable_version}" + else: + version = base_version + return cls( + path=path, + version=version, + version_tuple=tuple(map(int, base_version.split("."))), + for_windows=bool(for_windows), + supports_control_socket=(not for_windows), + supports_multiplexing=(not IS_WINDOWS), + 
supports_background_mode=(not for_windows), + ) + + +def inspect_ssh_client(path: PathLike) -> SSHClientInfo: + """ + Inspects various aspects of a given SSH client — version, "flavor", features — by executing + and parsing `ssh -V`. + + :param path: a path of the ssh executable. + :return: :class:`SSHClientInfo` named tuple. + :raise dstack._internal.core.errors.SSHError: if path does not exist, `ssh -V` returns + non-zero exit status, or `ssh -V` output does not match the pattern. + """ + path = Path(path).resolve() + try: + cp = subprocess.run( + [path, "-V"], + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + text=True, + timeout=5, + ) + except (OSError, subprocess.SubprocessError) as e: + raise SSHError(f"failed to execute `{path} -V`: {e}") from e + output = cp.stderr + if cp.returncode != 0: + raise SSHError(f"`{path} -V` returned non-zero exit status {cp.returncode}: {output}") + try: + return SSHClientInfo.from_raw_version(output, path) + except ValueError: + raise SSHError(f"failed to parse `{path} -V` output: {output}") + + +def find_ssh_client() -> Optional[Path]: + path_str = os.getenv("DSTACK_SSH_CLIENT") + if path_str: + return Path(path_str) + if not IS_WINDOWS: + path_str = shutil.which("ssh") + if path_str: + return Path(path_str) + return None + # First, we check for ssh bundled with Git for Windows (MSYS2/MinGW-w64-built OpenSSH Portable) + # as a preferred client. It supports ForkAfterAuthentication; ControlMaster is only partially + # supported, we don't use it. + git_path_str = shutil.which("git") + if git_path_str: + # C:\Program Files\Git\cmd\git.exe -> C:\Program Files\Git\usr\bin\ssh.exe + path = Path(git_path_str).parent.parent / "usr" / "bin" / "ssh.exe" + if path.exists(): + return path + # Then we check for OpenSSH for Windows (Microsoft's fork of OpenSSH Portable). + # It does not support some features, namely ControlMaster and ForkAfterAuthentication. 
+ windir_str = os.getenv("WINDIR") + if windir_str: + path = Path(windir_str) / "System32" / "OpenSSH" / "ssh.exe" + if path.exists(): + return path + # Finally, we check for any ssh client in PATH. It can be anything, it can be not compatible, + # so we use it only as a last resort. + path_str = shutil.which("ssh") + if path_str: + return Path(path_str) + return None + + +_ssh_client_info: Optional[SSHClientInfo] = None + + +def get_ssh_client_info() -> SSHClientInfo: + """ + Returns :class:`SSHClientInfo` for the default SSH client. The result is cached. + + :return: :class:`SSHClientInfo` named tuple. + :raise dstack._internal.core.errors.SSHError: if no ssh client found or the underlying + :func:`inspect_ssh_client` raises an error. + """ + global _ssh_client_info + if _ssh_client_info is not None: + return _ssh_client_info + path = find_ssh_client() + if path is None: + if IS_WINDOWS: + msg = "SSH client not found, install Git for Windows." + else: + msg = "SSH client not found." + raise SSHError(msg) + _ssh_client_info = inspect_ssh_client(path) + if _ssh_client_info.for_windows: + raise SSHError("OpenSSH for Windows is not supported, install Git for Windows.") + return _ssh_client_info diff --git a/src/dstack/_internal/core/services/ssh/ports.py b/src/dstack/_internal/core/services/ssh/ports.py index 3d81d0f11..6bdd901b5 100644 --- a/src/dstack/_internal/core/services/ssh/ports.py +++ b/src/dstack/_internal/core/services/ssh/ports.py @@ -2,6 +2,7 @@ import socket from typing import Dict, List, Optional +from dstack._internal.compat import IS_WINDOWS from dstack._internal.core.errors import DstackError from dstack._internal.core.models.configurations import PortMapping @@ -52,7 +53,6 @@ def acquire(self) -> "PortsLock": def release(self) -> Dict[int, int]: mapping = self.dict() for sock in self.sockets.values(): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.close() self.sockets = {} return mapping @@ -73,6 +73,10 @@ def __str__(self) -> 
str: def _listen(port: int) -> Optional[socket.socket]: try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if IS_WINDOWS: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + else: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(("", port)) return sock except socket.error as e: diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index 68897ebc7..addf1b31e 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -5,13 +5,15 @@ import subprocess import tempfile from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Literal, Optional, Union from dstack._internal.core.errors import SSHError from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.services.ssh import get_ssh_error +from dstack._internal.core.services.ssh.client import get_ssh_client_info from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import FilePath, FilePathOrContent, PathLike +from dstack._internal.utils.ssh import normalize_path logger = get_logger(__name__) SSH_TIMEOUT = 15 @@ -64,7 +66,7 @@ def __init__( reverse_forwarded_sockets: Iterable[SocketPair] = (), control_sock_path: Optional[PathLike] = None, options: Dict[str, str] = SSH_DEFAULT_OPTIONS, - ssh_config_path: str = "none", + ssh_config_path: Union[PathLike, Literal["none"]] = "none", port: Optional[int] = None, ssh_proxy: Optional[SSHConnectionParams] = None, ): @@ -79,23 +81,13 @@ def __init__( self.reverse_forwarded_sockets = list(reverse_forwarded_sockets) self.options = options self.port = port - self.ssh_config_path = ssh_config_path + self.ssh_config_path = normalize_path(ssh_config_path) self.ssh_proxy = ssh_proxy - - self.temp_dir, self.identity_path, self.control_sock_path = 
self._init_temp_dir_if_needed( - identity, control_sock_path - ) - - @staticmethod - def _init_temp_dir_if_needed( - identity: FilePathOrContent, control_sock_path: Optional[PathLike] - ) -> Tuple[Optional[tempfile.TemporaryDirectory], PathLike, PathLike]: - if control_sock_path is not None and isinstance(identity, FilePath): - return None, identity.path, control_sock_path - temp_dir = tempfile.TemporaryDirectory() + self.temp_dir = temp_dir if control_sock_path is None: control_sock_path = os.path.join(temp_dir.name, "control.sock") + self.control_sock_path = normalize_path(control_sock_path) if isinstance(identity, FilePath): identity_path = identity.path else: @@ -104,22 +96,47 @@ def _init_temp_dir_if_needed( identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w" ) as f: f.write(identity.content) - - return temp_dir, identity_path, control_sock_path + self.identity_path = normalize_path(identity_path) + self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log")) + self.ssh_client_info = get_ssh_client_info() + self.ssh_exec_path = str(self.ssh_client_info.path) def open_command(self) -> List[str]: + # Some information about how `ssh(1)` handles options: + # 1. Command-line options override config options regardless of the order of the arguments: + # `ssh -S sock2 -F config` with `ControlPath sock1` in the config -> the control socket + # path is `sock2`. + # 2. First argument wins: + # `ssh -S sock2 -S sock1` -> the control socket path is `sock2`. + # 3. `~` is not expanded in the arguments, but expanded in the config file. 
command = [ - "ssh", + self.ssh_exec_path, "-F", self.ssh_config_path, - "-f", # go to background after connecting - "-N", # do not run commands on remote - "-M", # use the control socket in master mode - "-S", - str(self.control_sock_path), "-i", - str(self.identity_path), + self.identity_path, + "-E", + self.log_path, + "-N", # do not run commands on remote ] + if self.ssh_client_info.supports_background_mode: + command += ["-f"] # go to background after successful authentication + else: + raise SSHError("Unsupported SSH client") + if self.ssh_client_info.supports_control_socket: + # It's safe to use ControlMaster even if the ssh client does not support multiplexing + # as long as we don't allow more than one tunnel to the specific host to be running. + # We use this feature for control only (see :meth:`close_command`). + command += [ + # Not `-M`, which means `ControlMaster=yes`, to avoid spawning uncontrollable + # ssh instances if more than one tunnel is started (precaution). + "-o", + "ControlMaster=auto", + "-S", + self.control_sock_path, + ] + else: + raise SSHError("Unsupported SSH client") if self.port is not None: command += ["-p", str(self.port)] for k, v in self.options.items(): @@ -134,21 +151,21 @@ def open_command(self) -> List[str]: return command def close_command(self) -> List[str]: - return ["ssh", "-S", str(self.control_sock_path), "-O", "exit", self.destination] + return [self.ssh_exec_path, "-S", self.control_sock_path, "-O", "exit", self.destination] def check_command(self) -> List[str]: - return ["ssh", "-S", str(self.control_sock_path), "-O", "check", self.destination] + return [self.ssh_exec_path, "-S", self.control_sock_path, "-O", "check", self.destination] def exec_command(self) -> List[str]: - return ["ssh", "-S", str(self.control_sock_path), self.destination] + return [self.ssh_exec_path, "-S", self.control_sock_path, self.destination] def proxy_command(self) -> Optional[List[str]]: if self.ssh_proxy is None: return None return [ - 
"ssh", + self.ssh_exec_path, "-i", - str(self.identity_path), + self.identity_path, "-W", "%h:%p", "-o", @@ -161,37 +178,43 @@ def proxy_command(self) -> Optional[List[str]]: ] def open(self) -> None: - # Using stderr=subprocess.PIPE may block subprocess.run. - # Redirect stderr to file to get ssh error message - with tempfile.NamedTemporaryFile(delete=False) as f: - try: - r = subprocess.run( - self.open_command(), stdout=subprocess.DEVNULL, stderr=f, timeout=SSH_TIMEOUT - ) - except subprocess.TimeoutExpired as e: - msg = f"SSH tunnel to {self.destination} did not open in {SSH_TIMEOUT} seconds" - logger.debug(msg) - raise SSHError(msg) from e - with open(f.name, "r+b") as f: - error = f.read() - os.remove(f.name) + # We cannot use `stderr=subprocess.PIPE` here since the forked process (daemon) does not + # close standard streams if ProxyJump is used, therefore we will wait EOF from the pipe + # as long as the daemon exists. + self._remove_log_file() + try: + r = subprocess.run( + self.open_command(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + timeout=SSH_TIMEOUT, + ) + except subprocess.TimeoutExpired as e: + msg = f"SSH tunnel to {self.destination} did not open in {SSH_TIMEOUT} seconds" + logger.debug(msg) + raise SSHError(msg) from e if r.returncode == 0: return - logger.debug("SSH tunnel failed: %s", error) - raise get_ssh_error(error) + stderr = self._read_log_file() + logger.debug("SSH tunnel failed: %s", stderr) + raise get_ssh_error(stderr) async def aopen(self) -> None: + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._remove_log_file) proc = await asyncio.create_subprocess_exec( - *self.open_command(), stdout=subprocess.DEVNULL, stderr=subprocess.PIPE + *self.open_command(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL ) try: - _, stderr = await asyncio.wait_for(proc.communicate(), SSH_TIMEOUT) + await asyncio.wait_for(proc.communicate(), SSH_TIMEOUT) except asyncio.TimeoutError as e: + proc.kill() msg = 
f"SSH tunnel to {self.destination} did not open in {SSH_TIMEOUT} seconds" logger.debug(msg) raise SSHError(msg) from e if proc.returncode == 0: return + stderr = await loop.run_in_executor(None, self._read_log_file) logger.debug("SSH tunnel failed: %s", stderr) raise get_ssh_error(stderr) @@ -228,6 +251,18 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() + def _read_log_file(self) -> bytes: + with open(self.log_path, "rb") as f: + return f.read() + + def _remove_log_file(self) -> None: + try: + os.remove(self.log_path) + except FileNotFoundError: + pass + except OSError as e: + logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e) + def ports_to_forwarded_sockets( ports: Dict[int, int], bind_local: str = "localhost" diff --git a/src/dstack/_internal/utils/ssh.py b/src/dstack/_internal/utils/ssh.py index 46b6b1181..0992b69a3 100644 --- a/src/dstack/_internal/utils/ssh.py +++ b/src/dstack/_internal/utils/ssh.py @@ -5,7 +5,7 @@ import sys import tempfile from pathlib import Path -from typing import Dict, Optional +from typing import Dict, Optional, Union import paramiko from filelock import FileLock @@ -13,8 +13,9 @@ from paramiko.pkey import PKey, PublicBlob from paramiko.ssh_exception import SSHException +from dstack._internal.compat import IS_WINDOWS from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.path import PathLike +from dstack._internal.utils.path import FilePath, PathLike logger = get_logger(__name__) @@ -51,6 +52,46 @@ def try_ssh_key_passphrase(identity_file: PathLike, passphrase: str = "") -> boo return r.returncode == 0 +def normalize_path(path: PathLike, *, collapse_user: bool = False) -> str: + """ + Converts a path to the most compatible format. + On Windows, replaces backslashes with slashes. + Additionally, if `collapse_user` is `True`, tries to replace the user home part of the path + with `~`. 
+ + :param path: Path object or string + :param collapse_user: try to replace user home prefix with `~`. `False` by default. + :return: Normalized path as string + """ + if collapse_user: + # The following "reverse" expanduser operation not only makes paths shorter and "nicer", + # but also fixes one specific issue with OpenSSH bundled with Git for Windows (MSYS2), + # see :func:`include_ssh_config` for details. + try: + path = Path(path).relative_to(Path.home()) + path = f"~/{path}" + except ValueError: + pass + if IS_WINDOWS: + # Git for Windows ssh (based on MSYS2, but there may be subtle differences between + # vanilla MSYS2 ssh and Git for Windows ssh) supports: + # * C:\\Users\\User + # * C:/Users/User + # * /c/Users/User + # does not support: + # * C:\Users\User (as pathlib.WindowsPath is rendered) + # OpenSSH_for_Windows supports: + # * C:\Users\User + # * C:\\Users\\User + # * C:/Users/User + # does not support: + # * /c/Users/User + # We use C:/Users/User format as the safest (supported by both ssh builds; + # no backslash-escaping pitfalls) + return str(path).replace("\\", "/") + return str(path) + + def include_ssh_config(path: PathLike, ssh_config_path: PathLike = default_ssh_config_path): """ Adds Include entry on top of the default ssh config file @@ -59,6 +100,10 @@ def include_ssh_config(path: PathLike, ssh_config_path: PathLike = default_ssh_c """ ssh_config_path = os.path.expanduser(ssh_config_path) Path(ssh_config_path).parent.mkdir(0o600, parents=True, exist_ok=True) + # MSYS2 OpenSSH accepts only /c/Users/User/... format in the Include directive (although + # it accepts C:/Users/User/... in other directives). We try to work around this issue + # converting the path to ~/.dstack/... format. 
+ path = normalize_path(path, collapse_user=True) include = f"Include {path}\n" content = "" with FileLock(str(ssh_config_path) + ".lock"): @@ -103,7 +148,7 @@ def get_ssh_config(path: PathLike, host: str) -> Optional[Dict[str, str]]: return None -def update_ssh_config(path: PathLike, host: str, options: Dict[str, str]): +def update_ssh_config(path: PathLike, host: str, options: Dict[str, Union[str, FilePath]]): Path(path).parent.mkdir(parents=True, exist_ok=True) with FileLock(str(path) + ".lock"): copy_mode = True @@ -121,6 +166,8 @@ def update_ssh_config(path: PathLike, host: str, options: Dict[str, str]): if options: f.write(f"Host {host}\n") for k, v in options.items(): + if isinstance(v, FilePath): + v = normalize_path(v.path, collapse_user=True) f.write(f" {k} {v}\n") f.flush() diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 2bda3a401..26cea1752 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -252,11 +252,9 @@ def attach( job = self._run.jobs[0] # TODO(egor-s): pull logs from all replicas? 
provisioning_data = job.job_submissions[-1].job_provisioning_data - control_sock_path_and_port_locks = SSHAttach.reuse_control_sock_path_and_port_locks( - run_name=self.name - ) + ports_lock = SSHAttach.reuse_ports_lock(run_name=self.name) - if control_sock_path_and_port_locks is None: + if ports_lock is None: if self._ports_lock is None: self._ports_lock = _reserve_ports(job.job_spec) logger.debug( @@ -266,7 +264,7 @@ def attach( self._ports_lock.dict(), ) else: - self._ports_lock = control_sock_path_and_port_locks[1] + self._ports_lock = ports_lock logger.debug( "Reusing the existing tunnel to %s (%s: %s)", self.name, @@ -283,13 +281,10 @@ def attach( run_name=self.name, dockerized=provisioning_data.dockerized, ssh_proxy=provisioning_data.ssh_proxy, - control_sock_path=control_sock_path_and_port_locks[0] - if control_sock_path_and_port_locks - else None, local_backend=provisioning_data.backend == BackendType.LOCAL, bind_address=bind_address, ) - if not control_sock_path_and_port_locks: + if not ports_lock: self._ssh_attach.attach() self._ports_lock = None diff --git a/src/tests/_internal/core/services/ssh/test_client.py b/src/tests/_internal/core/services/ssh/test_client.py new file mode 100644 index 000000000..a5ea64d4a --- /dev/null +++ b/src/tests/_internal/core/services/ssh/test_client.py @@ -0,0 +1,76 @@ +from pathlib import Path + +import pytest + +from dstack._internal.core.services.ssh.client import SSHClientInfo + + +class TestSSHClientInfo: + def test_openbsd(self): + path = Path("/usr/bin/ssh") + info = SSHClientInfo.from_raw_version("OpenSSH_9.7, LibreSSL 3.9.0", path) + assert info == SSHClientInfo( + path=path, + version="9.7", + version_tuple=(9, 7), + for_windows=False, + supports_control_socket=True, + supports_multiplexing=True, + supports_background_mode=True, + ) + + def test_linux(self): + path = Path("/usr/bin/ssh") + info = SSHClientInfo.from_raw_version( + "OpenSSH_9.2p1 Debian-2+deb12u3, OpenSSL 3.0.13 30 Jan 2024", path + ) + assert info 
== SSHClientInfo( + path=path, + version="9.2p1", + version_tuple=(9, 2), + for_windows=False, + supports_control_socket=True, + supports_multiplexing=True, + supports_background_mode=True, + ) + + def test_macos(self): + path = Path("/usr/bin/ssh") + info = SSHClientInfo.from_raw_version("OpenSSH_9.7p1, LibreSSL 3.3.6", path) + assert info == SSHClientInfo( + path=path, + version="9.7p1", + version_tuple=(9, 7), + for_windows=False, + supports_control_socket=True, + supports_multiplexing=True, + supports_background_mode=True, + ) + + @pytest.mark.windows_only + def test_windows_msys2(self): + path = Path("C:\\Program Files\\Git\\usr\\bin\\ssh.exe") + info = SSHClientInfo.from_raw_version("OpenSSH_9.8p1, OpenSSL 3.2.2 4 Jun 2024", path) + assert info == SSHClientInfo( + path=path, + version="9.8p1", + version_tuple=(9, 8), + for_windows=False, + supports_control_socket=True, + supports_multiplexing=False, + supports_background_mode=True, + ) + + @pytest.mark.windows_only + def test_windows_for_windows(self): + path = Path("C:\\Windows\\System32\\OpenSSH\\ssh.exe") + info = SSHClientInfo.from_raw_version("OpenSSH_for_Windows_8.6p1, LibreSSL 3.4.3", path) + assert info == SSHClientInfo( + path=path, + version="8.6p1", + version_tuple=(8, 6), + for_windows=True, + supports_control_socket=False, + supports_multiplexing=False, + supports_background_mode=False, + ) diff --git a/src/tests/_internal/core/services/ssh/test_tunnel.py b/src/tests/_internal/core/services/ssh/test_tunnel.py index 69d91b0c1..5bd5e1d8b 100644 --- a/src/tests/_internal/core/services/ssh/test_tunnel.py +++ b/src/tests/_internal/core/services/ssh/test_tunnel.py @@ -1,7 +1,9 @@ -import re from pathlib import Path +import pytest + from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.services.ssh.client import SSHClientInfo from dstack._internal.core.services.ssh.tunnel import ( IPSocket, SocketPair, @@ -11,20 +13,31 @@ ) from dstack._internal.utils.path 
import FileContent, FilePath -SAMPLE_TUNNEL_WITH_ALL_PARAMS = SSHTunnel( - destination="ubuntu@my-server", - identity=FilePath("/home/user/.ssh/id_rsa"), - control_sock_path="/tmp/control.sock", - options={"Opt1": "opt1"}, - ssh_config_path="/home/user/.ssh/config", - port=10022, - ssh_proxy=SSHConnectionParams(hostname="proxy", username="test", port=10022), - forwarded_sockets=[SocketPair(UnixSocket("/1"), UnixSocket("/2"))], - reverse_forwarded_sockets=[SocketPair(UnixSocket("/1"), UnixSocket("/2"))], -) - class TestSSHTunnel: + @pytest.fixture + def ssh_client_info(self, monkeypatch: pytest.MonkeyPatch) -> SSHClientInfo: + ssh_client_info = SSHClientInfo.from_raw_version("OpenSSH_9.7p1", Path("/usr/bin/ssh")) + monkeypatch.setattr( + "dstack._internal.core.services.ssh.client._ssh_client_info", ssh_client_info + ) + return ssh_client_info + + @pytest.fixture + def sample_tunnel_with_all_params(self, ssh_client_info: SSHClientInfo) -> SSHTunnel: + return SSHTunnel( + destination="ubuntu@my-server", + identity=FilePath("/home/user/.ssh/id_rsa"), + control_sock_path="/tmp/control.sock", + options={"Opt1": "opt1"}, + ssh_config_path="/home/user/.ssh/config", + port=10022, + ssh_proxy=SSHConnectionParams(hostname="proxy", username="test", port=10022), + forwarded_sockets=[SocketPair(UnixSocket("/1"), UnixSocket("/2"))], + reverse_forwarded_sockets=[SocketPair(UnixSocket("/1"), UnixSocket("/2"))], + ) + + @pytest.mark.usefixtures("ssh_client_info") def test_open_command_basic(self) -> None: tunnel = SSHTunnel( destination="ubuntu@my-server", @@ -38,17 +51,20 @@ def test_open_command_basic(self) -> None: port=10022, ) assert " ".join(tunnel.open_command()) == ( - "ssh" + "/usr/bin/ssh" " -F /home/user/.ssh/config" - " -f -N -M" - " -S /tmp/control.sock" " -i /home/user/.ssh/id_rsa" + f" -E {tunnel.temp_dir.name}/tunnel.log" + " -N -f" + " -o ControlMaster=auto" + " -S /tmp/control.sock" " -p 10022" " -o Opt1=opt1" " -o Opt2=opt2" " ubuntu@my-server" ) + 
@pytest.mark.usefixtures("ssh_client_info") def test_open_command_with_temp_identity_file(self) -> None: tunnel = SSHTunnel( destination="ubuntu@my-server", @@ -56,24 +72,39 @@ def test_open_command_with_temp_identity_file(self) -> None: control_sock_path="/tmp/control.sock", options={}, ) - command = " ".join(tunnel.open_command()) - match = re.fullmatch( - r"ssh -F none -f -N -M -S /tmp/control.sock -i (\S+) ubuntu@my-server", command + temp_dir = tunnel.temp_dir.name + assert " ".join(tunnel.open_command()) == ( + "/usr/bin/ssh" + " -F none" + f" -i {temp_dir}/identity" + f" -E {temp_dir}/tunnel.log" + " -N -f" + " -o ControlMaster=auto" + " -S /tmp/control.sock" + " ubuntu@my-server" ) - assert match - assert Path(match.group(1)).read_text() == "my private key" + assert (Path(temp_dir) / "identity").read_text() == "my private key" + @pytest.mark.usefixtures("ssh_client_info") def test_open_command_with_temp_control_socket(self) -> None: tunnel = SSHTunnel( destination="ubuntu@my-server", identity=FilePath("/home/user/.ssh/id_rsa"), options={}, ) - command = " ".join(tunnel.open_command()) - assert re.fullmatch( - r"ssh -F none -f -N -M -S \S+ -i /home/user/.ssh/id_rsa ubuntu@my-server", command + temp_dir = tunnel.temp_dir.name + assert " ".join(tunnel.open_command()) == ( + "/usr/bin/ssh" + " -F none" + " -i /home/user/.ssh/id_rsa" + f" -E {temp_dir}/tunnel.log" + " -N -f" + " -o ControlMaster=auto" + f" -S {temp_dir}/control.sock" + " ubuntu@my-server" ) + @pytest.mark.usefixtures("ssh_client_info") def test_open_command_with_proxy(self) -> None: tunnel = SSHTunnel( destination="ubuntu@my-server", @@ -83,25 +114,29 @@ def test_open_command_with_proxy(self) -> None: ssh_proxy=SSHConnectionParams(hostname="proxy", username="test", port=10022), ) assert tunnel.open_command() == [ - "ssh", + "/usr/bin/ssh", "-F", "none", - "-f", + "-i", + "/home/user/.ssh/id_rsa", + "-E", + f"{tunnel.temp_dir.name}/tunnel.log", "-N", - "-M", + "-f", + "-o", + 
"ControlMaster=auto", "-S", "/tmp/control.sock", - "-i", - "/home/user/.ssh/id_rsa", "-o", ( "ProxyCommand=" - "ssh -i /home/user/.ssh/id_rsa -W %h:%p -o StrictHostKeyChecking=no" + "/usr/bin/ssh -i /home/user/.ssh/id_rsa -W %h:%p -o StrictHostKeyChecking=no" " -o UserKnownHostsFile=/dev/null -p 10022 test@proxy" ), "ubuntu@my-server", ] + @pytest.mark.usefixtures("ssh_client_info") def test_open_command_with_forwarding(self) -> None: tunnel = SSHTunnel( destination="ubuntu@my-server", @@ -118,11 +153,13 @@ def test_open_command_with_forwarding(self) -> None: ], ) assert " ".join(tunnel.open_command()) == ( - "ssh" + "/usr/bin/ssh" " -F none" - " -f -N -M" - " -S /tmp/control.sock" " -i /home/user/.ssh/id_rsa" + f" -E {tunnel.temp_dir.name}/tunnel.log" + " -N -f" + " -o ControlMaster=auto" + " -S /tmp/control.sock" " -L /tmp/80:localhost:80" " -L 127.0.0.1:8000:[::1]:80" " -R /tmp/remote:/tmp/local" @@ -130,17 +167,31 @@ def test_open_command_with_forwarding(self) -> None: " ubuntu@my-server" ) - def test_check_command(self) -> None: - command = SAMPLE_TUNNEL_WITH_ALL_PARAMS.check_command() - assert command == ["ssh", "-S", "/tmp/control.sock", "-O", "check", "ubuntu@my-server"] + def test_check_command(self, sample_tunnel_with_all_params: SSHTunnel) -> None: + command = sample_tunnel_with_all_params.check_command() + assert command == [ + "/usr/bin/ssh", + "-S", + "/tmp/control.sock", + "-O", + "check", + "ubuntu@my-server", + ] - def test_close_command(self) -> None: - command = SAMPLE_TUNNEL_WITH_ALL_PARAMS.close_command() - assert command == ["ssh", "-S", "/tmp/control.sock", "-O", "exit", "ubuntu@my-server"] + def test_close_command(self, sample_tunnel_with_all_params: SSHTunnel) -> None: + command = sample_tunnel_with_all_params.close_command() + assert command == [ + "/usr/bin/ssh", + "-S", + "/tmp/control.sock", + "-O", + "exit", + "ubuntu@my-server", + ] - def test_exec_command(self) -> None: - command = SAMPLE_TUNNEL_WITH_ALL_PARAMS.exec_command() - 
assert command == ["ssh", "-S", "/tmp/control.sock", "ubuntu@my-server"] + def test_exec_command(self, sample_tunnel_with_all_params: SSHTunnel) -> None: + command = sample_tunnel_with_all_params.exec_command() + assert command == ["/usr/bin/ssh", "-S", "/tmp/control.sock", "ubuntu@my-server"] def test_ports_to_forwarded_sockets() -> None: diff --git a/src/tests/conftest.py b/src/tests/conftest.py index d537c1d66..9b0ef039a 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -1,3 +1,5 @@ +import os + import pytest @@ -6,6 +8,10 @@ def pytest_configure(config): config.addinivalue_line( "markers", "postgres: mark test as testing Postgres to run only with --runpostgres" ) + config.addinivalue_line( + "markers", "windows: mark test to be run on Windows in addition to POSIX" + ) + config.addinivalue_line("markers", "windows_only: mark test to be run on Windows only") def pytest_addoption(parser): @@ -18,8 +24,17 @@ def pytest_addoption(parser): def pytest_collection_modifyitems(config, items): skip_ui = pytest.mark.skip(reason="need --runui option to run") skip_postgres = pytest.mark.skip(reason="need --runpostgres option to run") + is_windows = os.name == "nt" + skip_posix = pytest.mark.skip(reason="requires POSIX") + skip_windows = pytest.mark.skip(reason="requires Windows") for item in items: if not config.getoption("--runui") and "ui" in item.keywords: item.add_marker(skip_ui) if not config.getoption("--runpostgres") and "postgres" in item.keywords: item.add_marker(skip_postgres) + for_windows_only = "windows_only" in item.keywords + for_windows = for_windows_only or "windows" in item.keywords + if for_windows_only and not is_windows: + item.add_marker(skip_windows) + if not for_windows and is_windows: + item.add_marker(skip_posix)