Skip to content

Commit

Permalink
moving unshared utils into other files
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvkh committed Jul 12, 2024
1 parent b5241af commit 09cf8ce
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 150 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = ["E", "F"]
select = ["E", "F", "I"]

[tool.pyright]
include = ["src", "tests"]
Expand Down
41 changes: 40 additions & 1 deletion src/torchrunx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,43 @@
from .launcher import launch
from .utils import slurm_hosts, slurm_workers


def slurm_hosts() -> list[str]:
"""Retrieves hostnames of Slurm-allocated nodes.
:return: Hostnames of nodes in current Slurm allocation
:rtype: list[str]
"""
import os
import subprocess

# TODO: sanity check SLURM variables, commands
assert "SLURM_JOB_ID" in os.environ
return (
subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
.decode()
.strip()
.split("\n")
)


def slurm_workers() -> int:
"""
| Determines number of workers per node in current Slurm allocation using
| the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables.
:return: The implied number of workers per node
:rtype: int
"""
import os

# TODO: sanity check SLURM variables, commands
assert "SLURM_JOB_ID" in os.environ
if "SLURM_JOB_GPUS" in os.environ:
# TODO: is it possible to allocate uneven GPUs across nodes?
return len(os.environ["SLURM_JOB_GPUS"].split(","))
else:
# TODO: should we assume that we plan to do one worker per CPU?
return int(os.environ["SLURM_CPUS_ON_NODE"])


__all__ = ["launch", "slurm_hosts", "slurm_workers"]
4 changes: 2 additions & 2 deletions src/torchrunx/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from argparse import ArgumentParser

from . import agent
from .agent import main
from .utils import LauncherAgentGroup

if __name__ == "__main__":
Expand All @@ -18,4 +18,4 @@
rank=args.rank,
)

agent.main(launcher_agent_group)
main(launcher_agent_group)
26 changes: 25 additions & 1 deletion src/torchrunx/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import socket
import sys
from dataclasses import dataclass
from typing import Callable, Literal

Expand All @@ -17,7 +18,6 @@
AgentStatus,
LauncherAgentGroup,
LauncherPayload,
WorkerTee,
get_open_port,
)

Expand All @@ -42,6 +42,30 @@ def from_bytes(cls, serialized: bytes) -> Self:
return cloudpickle.loads(serialized)


class WorkerTee(object):
def __init__(self, name: os.PathLike | str, mode: str):
self.file = open(name, mode)
self.stdout = sys.stdout
sys.stdout = self

def __enter__(self):
return self

def __exit__(self, exception_type, exception_value, exception_traceback):
self.__del__()

def __del__(self):
sys.stdout = self.stdout
self.file.close()

def write(self, data):
self.file.write(data)
self.stdout.write(data)

def flush(self):
self.file.flush()


def entrypoint(serialized_worker_args: bytes):
worker_args = WorkerArgs.from_bytes(serialized_worker_args)

Expand Down
54 changes: 52 additions & 2 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,79 @@

import datetime
import fnmatch
import io
import ipaddress
import itertools
import os
import socket
import subprocess
import sys
import time
from collections import ChainMap
from functools import partial
from multiprocessing import Process
from pathlib import Path
from typing import Any, Callable, Literal

import fabric
import torch.distributed as dist

from .utils import (
AgentPayload,
AgentStatus,
LauncherAgentGroup,
LauncherPayload,
execute_command,
get_open_port,
monitor_log,
)


def is_localhost(hostname_or_ip: str) -> bool:
# check if host is "loopback" address (i.e. designated to send to self)
try:
ip = ipaddress.ip_address(hostname_or_ip)
except ValueError:
ip = ipaddress.ip_address(socket.gethostbyname(hostname_or_ip))
if ip.is_loopback:
return True
# else compare local interface addresses between host and localhost
host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(ip), None)]
localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
return len(set(host_addrs) & set(localhost_addrs)) > 0


def execute_command(
command: str,
hostname: str,
ssh_config_file: str | os.PathLike | None = None,
outfile: str | os.PathLike | None = None,
) -> None:
# TODO: permit different stderr / stdout
if is_localhost(hostname):
_outfile = subprocess.DEVNULL
if outfile is not None:
_outfile = open(outfile, "w")
subprocess.Popen(command, shell=True, stdout=_outfile, stderr=_outfile)
else:
with fabric.Connection(
host=hostname, config=fabric.Config(runtime_ssh_path=ssh_config_file)
) as conn:
if outfile is None:
outfile = "/dev/null"
conn.run(f"{command} >> {outfile} 2>&1 &", asynchronous=True)


def monitor_log(log_file: Path):
log_file.touch()
f = open(log_file, "r")
print(f.read())
f.seek(0, io.SEEK_END)
while True:
new = f.read()
if len(new) != 0:
print(new)
time.sleep(0.1)


def launch(
func: Callable,
func_kwargs: dict[str, Any],
Expand Down
Loading

0 comments on commit 09cf8ce

Please sign in to comment.