From 1bf24fb9522d642d8620798d80e362001f35f111 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 01:30:26 -0400 Subject: [PATCH] added agentkillederror --- src/torchrunx/__init__.py | 3 +- src/torchrunx/launcher.py | 262 ++++++++++++++++++++------------------ 2 files changed, 137 insertions(+), 128 deletions(-) diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py index c5a7d6f..ca19796 100644 --- a/src/torchrunx/__init__.py +++ b/src/torchrunx/__init__.py @@ -1,7 +1,8 @@ -from .launcher import Launcher, LaunchResult, launch +from .launcher import AgentKilledError, Launcher, LaunchResult, launch from .logging_utils import add_filter_to_handler, file_handler, stream_handler __all__ = [ + "AgentKilledError", "Launcher", "launch", "LaunchResult", diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index c27cd67..ddd2dea 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -25,130 +25,8 @@ from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port -def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: - if hostnames == "auto": - return auto_hosts() - if hostnames == "slurm": - return slurm_hosts() - return hostnames - - -def resolve_workers_per_host( - workers_per_host: int | list[int] | Literal["auto", "slurm"], - num_hosts: int, -) -> list[int]: - if workers_per_host == "auto": - workers_per_host = auto_workers() - elif workers_per_host == "slurm": - workers_per_host = slurm_workers() - - if isinstance(workers_per_host, int): - workers_per_host = [workers_per_host] * num_hosts - elif len(workers_per_host) != num_hosts: - msg = "len(workers_per_host) != len(hostnames)" - raise ValueError(msg) - - return workers_per_host - - -def build_logging_server( - log_handlers: list[Handler] | Literal["auto"] | None, - launcher_hostname: str, - hostnames: list[str], - workers_per_host: list[int], - log_dir: str | os.PathLike, - log_level: int, -) -> LogRecordSocketReceiver: - if log_handlers is None: - log_handlers = [] - elif log_handlers == "auto": - log_handlers = default_handlers( - hostnames=hostnames, - workers_per_host=workers_per_host, - log_dir=log_dir, - log_level=log_level, - ) - - return LogRecordSocketReceiver( - host=launcher_hostname, - port=get_open_port(), - handlers=log_handlers, - ) - - -def build_launch_command( - launcher_hostname: str, - launcher_port: int, - logger_port: int, - world_size: int, - rank: int, - env_vars: tuple[str, ...], - env_file: str | os.PathLike | None, -) -> str: - # shlex.quote prevents shell injection here (resolves S602 in execute_command) - - commands = [] - - current_dir = shlex.quote(str(Path.cwd())) - commands.append("cd " + current_dir) - - env_exports = [] - for k, v in os.environ.items(): - if any(fnmatch.fnmatch(k, e) for e in env_vars): - env_exports.append(shlex.quote(f"{k}={v}")) - - if len(env_exports) > 0: - commands.append("export " + " ".join(env_exports)) - - if env_file is not None: - commands.append("source " + shlex.quote(str(env_file))) - - python = shlex.quote(sys.executable) - launcher_hostname = shlex.quote(launcher_hostname) - - commands.append( - f"{python} -u -m torchrunx " - f"--launcher-hostname {launcher_hostname} " - f"--launcher-port {launcher_port} " - f"--logger-port {logger_port} " - f"--world-size {world_size} " - f"--rank {rank}", - ) - - return " && ".join(commands) - - -def execute_command( - command: str, - hostname: str, - ssh_config_file: str | os.PathLike | None = None, -) -> None: - is_localhost = True - _hostname_or_ip = hostname - try: - _ip = ipaddress.ip_address(_hostname_or_ip) - except ValueError: - _ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip)) - if not _ip.is_loopback: - # compare local interface addresses between host and localhost - _host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)] - _localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] - is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0 - - if is_localhost: - # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations) - # Made sure to shlex.quote arguments in build_command to prevent shell injection - subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # noqa: S602 - else: - runtime_ssh_path = ssh_config_file - if isinstance(ssh_config_file, os.PathLike): - runtime_ssh_path = str(ssh_config_file) - - with fabric.Connection( - host=hostname, - config=fabric.Config(runtime_ssh_path=runtime_ssh_path), - ) as conn: - conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True) +class AgentKilledError(Exception): + pass @dataclass @@ -263,8 +141,11 @@ def run( # noqa: C901, PLR0912 # loop to monitor agent statuses (until failed or done) while True: - # raises RuntimeError if communication timeout due to death of any agent - agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) + try: + agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) + except RuntimeError as e: + # occurs if any agent dies and communication times out + raise AgentKilledError from e # raises specific exception if any agent fails for s in agent_statuses: @@ -334,7 +215,8 @@ def launch( :param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax. :param extra_env_vars: Additional, user-specified variables to copy. :param env_file: A file (like ``.env``) with additional environment variables to copy. - :raises RuntimeError: May fail if ``torch.distributed`` not available or communication timeout between nodes + :raises RuntimeError: If ``torch.distributed`` not available + :raises AgentKilledError: If any agent is killed :raises Exception: Propagates exceptions raised in worker processes """ # noqa: E501 return Launcher( @@ -409,3 +291,129 @@ def value(self, rank: int) -> Any: msg = f"Rank {rank} larger than world_size" raise ValueError(msg) + + +def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: + if hostnames == "auto": + return auto_hosts() + if hostnames == "slurm": + return slurm_hosts() + return hostnames + + +def resolve_workers_per_host( + workers_per_host: int | list[int] | Literal["auto", "slurm"], + num_hosts: int, +) -> list[int]: + if workers_per_host == "auto": + workers_per_host = auto_workers() + elif workers_per_host == "slurm": + workers_per_host = slurm_workers() + + if isinstance(workers_per_host, int): + workers_per_host = [workers_per_host] * num_hosts + elif len(workers_per_host) != num_hosts: + msg = "len(workers_per_host) != len(hostnames)" + raise ValueError(msg) + + return workers_per_host + + +def build_logging_server( + log_handlers: list[Handler] | Literal["auto"] | None, + launcher_hostname: str, + hostnames: list[str], + workers_per_host: list[int], + log_dir: str | os.PathLike, + log_level: int, +) -> LogRecordSocketReceiver: + if log_handlers is None: + log_handlers = [] + elif log_handlers == "auto": + log_handlers = default_handlers( + hostnames=hostnames, + workers_per_host=workers_per_host, + log_dir=log_dir, + log_level=log_level, + ) + + return LogRecordSocketReceiver( + host=launcher_hostname, + port=get_open_port(), + handlers=log_handlers, + ) + + +def build_launch_command( + launcher_hostname: str, + launcher_port: int, + logger_port: int, + world_size: int, + rank: int, + env_vars: tuple[str, ...], + env_file: str | os.PathLike | None, +) -> str: + # shlex.quote prevents shell injection here (resolves S602 in execute_command) + + commands = [] + + current_dir = shlex.quote(str(Path.cwd())) + commands.append("cd " + current_dir) + + env_exports = [] + for k, v in os.environ.items(): + if any(fnmatch.fnmatch(k, e) for e in env_vars): + env_exports.append(shlex.quote(f"{k}={v}")) + + if len(env_exports) > 0: + commands.append("export " + " ".join(env_exports)) + + if env_file is not None: + commands.append("source " + shlex.quote(str(env_file))) + + python = shlex.quote(sys.executable) + launcher_hostname = shlex.quote(launcher_hostname) + + commands.append( + f"{python} -u -m torchrunx " + f"--launcher-hostname {launcher_hostname} " + f"--launcher-port {launcher_port} " + f"--logger-port {logger_port} " + f"--world-size {world_size} " + f"--rank {rank}", + ) + + return " && ".join(commands) + + +def execute_command( + command: str, + hostname: str, + ssh_config_file: str | os.PathLike | None = None, +) -> None: + is_localhost = True + _hostname_or_ip = hostname + try: + _ip = ipaddress.ip_address(_hostname_or_ip) + except ValueError: + _ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip)) + if not _ip.is_loopback: + # compare local interface addresses between host and localhost + _host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)] + _localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] + is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0 + + if is_localhost: + # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations) + # Made sure to shlex.quote arguments in build_command to prevent shell injection + subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # noqa: S602 + else: + runtime_ssh_path = ssh_config_file + if isinstance(ssh_config_file, os.PathLike): + runtime_ssh_path = str(ssh_config_file) + + with fabric.Connection( + host=hostname, + config=fabric.Config(runtime_ssh_path=runtime_ssh_path), + ) as conn: + conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)