From 55d09a7cfe5ddff7f95ddd5f02afbb92662bd36b Mon Sep 17 00:00:00 2001 From: jvstme <36324149+jvstme@users.noreply.github.com> Date: Tue, 1 Oct 2024 07:39:16 +0000 Subject: [PATCH] Fix restarting gateway connections (#1746) by deleting the remote unix socket before reopening the connection --- .../server/services/gateways/connection.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/dstack/_internal/server/services/gateways/connection.py b/src/dstack/_internal/server/services/gateways/connection.py index be2a39111..9f480add4 100644 --- a/src/dstack/_internal/server/services/gateways/connection.py +++ b/src/dstack/_internal/server/services/gateways/connection.py @@ -66,6 +66,7 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int): ], # reverse_forwarded_sockets are added later in .open() ) + self.tunnel_id = uuid.uuid4() self._client = GatewayClient(uds=self.gateway_socket_path) @staticmethod @@ -79,7 +80,7 @@ async def check_or_restart(self) -> bool: async with self._lock.writer_lock: if not await self.tunnel.acheck(): logger.info("Connection to gateway %s is down, restarting", self.ip_address) - await self.tunnel.aopen() + await self._open_tunnel() return True return False @@ -89,18 +90,26 @@ async def open(self, close_existing_tunnel: bool = False) -> None: # Close remaining tunnel if previous server process died w/o graceful shutdown if await self.tunnel.acheck(): await self.tunnel.aclose() - - self.connection_dir.mkdir(parents=True, exist_ok=True) - await self.tunnel.aopen() - await self.tunnel.aexec(f"mkdir -p {CONNECTIONS_DIR_ON_GATEWAY}") - - self.tunnel.reverse_forwarded_sockets = [ - SocketPair( - local=IPSocket(host="localhost", port=self.server_port), - remote=UnixSocket(path=f"{CONNECTIONS_DIR_ON_GATEWAY}/{uuid.uuid4()}.sock"), - ), - ] - await self.tunnel.aopen() # apply reverse forwarding + await self._open_tunnel() + + async def _open_tunnel(self) -> None: + self.connection_dir.mkdir(parents=True, exist_ok=True) + remote_socket_path = f"{CONNECTIONS_DIR_ON_GATEWAY}/{self.tunnel_id}.sock" + + # open w/o reverse forwarding and make sure reverse forwarding will be possible + self.tunnel.reverse_forwarded_sockets = [] + await self.tunnel.aopen() + await self.tunnel.aexec(f"mkdir -p {CONNECTIONS_DIR_ON_GATEWAY}") + await self.tunnel.aexec(f"rm -f {remote_socket_path}") + + # add reverse forwarding + self.tunnel.reverse_forwarded_sockets = [ + SocketPair( + local=IPSocket(host="localhost", port=self.server_port), + remote=UnixSocket(path=remote_socket_path), + ), + ] + await self.tunnel.aopen() async def close(self) -> None: async with self._lock.writer_lock: