Skip to content

Commit

Permalink
added agentkillederror
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvkh committed Oct 20, 2024
1 parent 82450a2 commit 1bf24fb
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 128 deletions.
3 changes: 2 additions & 1 deletion src/torchrunx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .launcher import Launcher, LaunchResult, launch
from .launcher import AgentKilledError, Launcher, LaunchResult, launch
from .logging_utils import add_filter_to_handler, file_handler, stream_handler

__all__ = [
"AgentKilledError",
"Launcher",
"launch",
"LaunchResult",
Expand Down
262 changes: 135 additions & 127 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,130 +25,8 @@
from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port


def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
    """Turn a hostnames specification into a concrete list of hostnames.

    :param hostnames: Either an explicit list of hostnames, or one of the
        keywords ``"auto"`` / ``"slurm"``.
    :return: The result of ``auto_hosts()`` for ``"auto"``, the result of
        ``slurm_hosts()`` for ``"slurm"``, otherwise ``hostnames`` unchanged.
    """
    if hostnames == "slurm":
        return slurm_hosts()
    # Anything that is not a recognised keyword is handed back untouched.
    return auto_hosts() if hostnames == "auto" else hostnames


def resolve_workers_per_host(
    workers_per_host: int | list[int] | Literal["auto", "slurm"],
    num_hosts: int,
) -> list[int]:
    """Expand a workers-per-host specification into one entry per host.

    The keywords ``"auto"`` and ``"slurm"`` are first replaced by the results
    of ``auto_workers()`` and ``slurm_workers()`` respectively. A single
    integer is then repeated ``num_hosts`` times, while a list is returned
    as-is once its length has been checked.

    :param workers_per_host: Worker count, per-host worker counts, or keyword.
    :param num_hosts: Number of hosts the result must cover.
    :return: A list with ``num_hosts`` entries.
    :raises ValueError: If a list is supplied (or resolved) whose length
        differs from ``num_hosts``.
    """
    if workers_per_host == "slurm":
        workers_per_host = slurm_workers()
    elif workers_per_host == "auto":
        workers_per_host = auto_workers()

    # A scalar applies uniformly to every host.
    if isinstance(workers_per_host, int):
        return [workers_per_host] * num_hosts

    if len(workers_per_host) == num_hosts:
        return workers_per_host

    msg = "len(workers_per_host) != len(hostnames)"
    raise ValueError(msg)


def build_logging_server(
    log_handlers: list[Handler] | Literal["auto"] | None,
    launcher_hostname: str,
    hostnames: list[str],
    workers_per_host: list[int],
    log_dir: str | os.PathLike,
    log_level: int,
) -> LogRecordSocketReceiver:
    """Construct the ``LogRecordSocketReceiver`` used by the launcher.

    :param log_handlers: Explicit handlers, ``"auto"`` to build them with
        ``default_handlers(...)``, or ``None`` for no handlers at all.
    :param launcher_hostname: Host passed to the receiver.
    :param hostnames: Forwarded to ``default_handlers`` in ``"auto"`` mode.
    :param workers_per_host: Forwarded to ``default_handlers`` in ``"auto"`` mode.
    :param log_dir: Forwarded to ``default_handlers`` in ``"auto"`` mode.
    :param log_level: Forwarded to ``default_handlers`` in ``"auto"`` mode.
    :return: A receiver on ``launcher_hostname`` using a port obtained from
        ``get_open_port()``.
    """
    if log_handlers == "auto":
        handlers = default_handlers(
            hostnames=hostnames,
            workers_per_host=workers_per_host,
            log_dir=log_dir,
            log_level=log_level,
        )
    elif log_handlers is None:
        handlers = []
    else:
        handlers = log_handlers

    return LogRecordSocketReceiver(
        host=launcher_hostname,
        port=get_open_port(),
        handlers=handlers,
    )


def build_launch_command(
    launcher_hostname: str,
    launcher_port: int,
    logger_port: int,
    world_size: int,
    rank: int,
    env_vars: tuple[str, ...],
    env_file: str | os.PathLike | None,
) -> str:
    """Assemble the shell command line that starts a ``torchrunx`` agent.

    The command consists of several steps joined with ``&&``:

    1. ``cd`` into the current working directory of this process;
    2. ``export`` every variable of ``os.environ`` whose name matches one of
       the ``fnmatch`` patterns in ``env_vars`` (omitted if nothing matches);
    3. ``source`` ``env_file`` (omitted if ``env_file`` is ``None``);
    4. run ``python -u -m torchrunx`` with the current interpreter and the
       given launcher/logger/world-size/rank arguments.

    :return: The complete command as a single string.
    """
    # shlex.quote prevents shell injection here (resolves S602 in execute_command)

    steps = ["cd " + shlex.quote(str(Path.cwd()))]

    exported = [
        shlex.quote(f"{name}={value}")
        for name, value in os.environ.items()
        if any(fnmatch.fnmatch(name, pattern) for pattern in env_vars)
    ]
    if exported:
        steps.append("export " + " ".join(exported))

    if env_file is not None:
        steps.append("source " + shlex.quote(str(env_file)))

    interpreter = shlex.quote(sys.executable)
    quoted_host = shlex.quote(launcher_hostname)

    steps.append(
        f"{interpreter} -u -m torchrunx "
        f"--launcher-hostname {quoted_host} "
        f"--launcher-port {launcher_port} "
        f"--logger-port {logger_port} "
        f"--world-size {world_size} "
        f"--rank {rank}",
    )

    return " && ".join(steps)


def execute_command(
    command: str,
    hostname: str,
    ssh_config_file: str | os.PathLike | None = None,
) -> None:
    """Start ``command`` on ``hostname`` without waiting for it to finish.

    ``hostname`` is treated as the local machine when it is (or resolves to)
    a loopback address, or when its address list overlaps with the addresses
    of ``socket.gethostname()``. Local commands are spawned through a shell
    with output discarded; everything else is run over SSH via ``fabric``,
    backgrounded with output redirected to ``/dev/null``.

    :param command: Shell command line to execute.
    :param hostname: Hostname or IP address of the target machine.
    :param ssh_config_file: Optional SSH config path handed to ``fabric.Config``.
    """
    try:
        address = ipaddress.ip_address(hostname)
    except ValueError:
        # Not a literal IP address: resolve the hostname first.
        address = ipaddress.ip_address(socket.gethostbyname(hostname))

    if address.is_loopback:
        runs_locally = True
    else:
        # compare local interface addresses between host and localhost
        target_addrs = {info[4][0] for info in socket.getaddrinfo(str(address), None)}
        local_addrs = {info[4][0] for info in socket.getaddrinfo(socket.gethostname(), None)}
        runs_locally = not target_addrs.isdisjoint(local_addrs)

    if runs_locally:
        # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations)
        # Made sure to shlex.quote arguments in build_command to prevent shell injection
        subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)  # noqa: S602
        return

    ssh_path = str(ssh_config_file) if isinstance(ssh_config_file, os.PathLike) else ssh_config_file

    with fabric.Connection(
        host=hostname,
        config=fabric.Config(runtime_ssh_path=ssh_path),
    ) as conn:
        conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)
class AgentKilledError(Exception):
    """Raised by the launcher when an agent dies and status communication times out."""


@dataclass
Expand Down Expand Up @@ -263,8 +141,11 @@ def run( # noqa: C901, PLR0912
# loop to monitor agent statuses (until failed or done)

while True:
# raises RuntimeError if communication timeout due to death of any agent
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
try:
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
except RuntimeError as e:
# occurs if any agent dies and communication times out
raise AgentKilledError from e

# raises specific exception if any agent fails
for s in agent_statuses:
Expand Down Expand Up @@ -334,7 +215,8 @@ def launch(
:param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax.
:param extra_env_vars: Additional, user-specified variables to copy.
:param env_file: A file (like ``.env``) with additional environment variables to copy.
:raises RuntimeError: May fail if ``torch.distributed`` not available or communication timeout between nodes
:raises RuntimeError: If ``torch.distributed`` not available
:raises AgentKilledError: If any agent is killed
:raises Exception: Propagates exceptions raised in worker processes
""" # noqa: E501
return Launcher(
Expand Down Expand Up @@ -409,3 +291,129 @@ def value(self, rank: int) -> Any:

msg = f"Rank {rank} larger than world_size"
raise ValueError(msg)


def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
    """Turn a hostnames specification into a concrete list of hostnames.

    :param hostnames: Either an explicit list of hostnames, or one of the
        keywords ``"auto"`` / ``"slurm"``.
    :return: The result of ``auto_hosts()`` for ``"auto"``, the result of
        ``slurm_hosts()`` for ``"slurm"``, otherwise ``hostnames`` unchanged.
    """
    if hostnames == "slurm":
        return slurm_hosts()
    # Anything that is not a recognised keyword is handed back untouched.
    return auto_hosts() if hostnames == "auto" else hostnames


def resolve_workers_per_host(
    workers_per_host: int | list[int] | Literal["auto", "slurm"],
    num_hosts: int,
) -> list[int]:
    """Expand a workers-per-host specification into one entry per host.

    The keywords ``"auto"`` and ``"slurm"`` are first replaced by the results
    of ``auto_workers()`` and ``slurm_workers()`` respectively. A single
    integer is then repeated ``num_hosts`` times, while a list is returned
    as-is once its length has been checked.

    :param workers_per_host: Worker count, per-host worker counts, or keyword.
    :param num_hosts: Number of hosts the result must cover.
    :return: A list with ``num_hosts`` entries.
    :raises ValueError: If a list is supplied (or resolved) whose length
        differs from ``num_hosts``.
    """
    if workers_per_host == "slurm":
        workers_per_host = slurm_workers()
    elif workers_per_host == "auto":
        workers_per_host = auto_workers()

    # A scalar applies uniformly to every host.
    if isinstance(workers_per_host, int):
        return [workers_per_host] * num_hosts

    if len(workers_per_host) == num_hosts:
        return workers_per_host

    msg = "len(workers_per_host) != len(hostnames)"
    raise ValueError(msg)


def build_logging_server(
    log_handlers: list[Handler] | Literal["auto"] | None,
    launcher_hostname: str,
    hostnames: list[str],
    workers_per_host: list[int],
    log_dir: str | os.PathLike,
    log_level: int,
) -> LogRecordSocketReceiver:
    """Construct the ``LogRecordSocketReceiver`` used by the launcher.

    :param log_handlers: Explicit handlers, ``"auto"`` to build them with
        ``default_handlers(...)``, or ``None`` for no handlers at all.
    :param launcher_hostname: Host passed to the receiver.
    :param hostnames: Forwarded to ``default_handlers`` in ``"auto"`` mode.
    :param workers_per_host: Forwarded to ``default_handlers`` in ``"auto"`` mode.
    :param log_dir: Forwarded to ``default_handlers`` in ``"auto"`` mode.
    :param log_level: Forwarded to ``default_handlers`` in ``"auto"`` mode.
    :return: A receiver on ``launcher_hostname`` using a port obtained from
        ``get_open_port()``.
    """
    if log_handlers == "auto":
        handlers = default_handlers(
            hostnames=hostnames,
            workers_per_host=workers_per_host,
            log_dir=log_dir,
            log_level=log_level,
        )
    elif log_handlers is None:
        handlers = []
    else:
        handlers = log_handlers

    return LogRecordSocketReceiver(
        host=launcher_hostname,
        port=get_open_port(),
        handlers=handlers,
    )


def build_launch_command(
    launcher_hostname: str,
    launcher_port: int,
    logger_port: int,
    world_size: int,
    rank: int,
    env_vars: tuple[str, ...],
    env_file: str | os.PathLike | None,
) -> str:
    """Assemble the shell command line that starts a ``torchrunx`` agent.

    The command consists of several steps joined with ``&&``:

    1. ``cd`` into the current working directory of this process;
    2. ``export`` every variable of ``os.environ`` whose name matches one of
       the ``fnmatch`` patterns in ``env_vars`` (omitted if nothing matches);
    3. ``source`` ``env_file`` (omitted if ``env_file`` is ``None``);
    4. run ``python -u -m torchrunx`` with the current interpreter and the
       given launcher/logger/world-size/rank arguments.

    :return: The complete command as a single string.
    """
    # shlex.quote prevents shell injection here (resolves S602 in execute_command)

    steps = ["cd " + shlex.quote(str(Path.cwd()))]

    exported = [
        shlex.quote(f"{name}={value}")
        for name, value in os.environ.items()
        if any(fnmatch.fnmatch(name, pattern) for pattern in env_vars)
    ]
    if exported:
        steps.append("export " + " ".join(exported))

    if env_file is not None:
        steps.append("source " + shlex.quote(str(env_file)))

    interpreter = shlex.quote(sys.executable)
    quoted_host = shlex.quote(launcher_hostname)

    steps.append(
        f"{interpreter} -u -m torchrunx "
        f"--launcher-hostname {quoted_host} "
        f"--launcher-port {launcher_port} "
        f"--logger-port {logger_port} "
        f"--world-size {world_size} "
        f"--rank {rank}",
    )

    return " && ".join(steps)


def execute_command(
    command: str,
    hostname: str,
    ssh_config_file: str | os.PathLike | None = None,
) -> None:
    """Start ``command`` on ``hostname`` without waiting for it to finish.

    ``hostname`` is treated as the local machine when it is (or resolves to)
    a loopback address, or when its address list overlaps with the addresses
    of ``socket.gethostname()``. Local commands are spawned through a shell
    with output discarded; everything else is run over SSH via ``fabric``,
    backgrounded with output redirected to ``/dev/null``.

    :param command: Shell command line to execute.
    :param hostname: Hostname or IP address of the target machine.
    :param ssh_config_file: Optional SSH config path handed to ``fabric.Config``.
    """
    try:
        address = ipaddress.ip_address(hostname)
    except ValueError:
        # Not a literal IP address: resolve the hostname first.
        address = ipaddress.ip_address(socket.gethostbyname(hostname))

    if address.is_loopback:
        runs_locally = True
    else:
        # compare local interface addresses between host and localhost
        target_addrs = {info[4][0] for info in socket.getaddrinfo(str(address), None)}
        local_addrs = {info[4][0] for info in socket.getaddrinfo(socket.gethostname(), None)}
        runs_locally = not target_addrs.isdisjoint(local_addrs)

    if runs_locally:
        # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations)
        # Made sure to shlex.quote arguments in build_command to prevent shell injection
        subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)  # noqa: S602
        return

    ssh_path = str(ssh_config_file) if isinstance(ssh_config_file, os.PathLike) else ssh_config_file

    with fabric.Connection(
        host=hostname,
        config=fabric.Config(runtime_ssh_path=ssh_path),
    ) as conn:
        conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)

0 comments on commit 1bf24fb

Please sign in to comment.