Skip to content

Commit

Permalink
Merge pull request #39 from apoorvkh/moving-functions
Browse files Browse the repository at this point in the history
moving unshared utils into other files
  • Loading branch information
apoorvkh authored Jul 12, 2024
2 parents b5241af + 59b2b4a commit cfe9f42
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 155 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = ["E", "F"]
select = ["E", "F", "I"]

[tool.pyright]
include = ["src", "tests"]
Expand Down
43 changes: 42 additions & 1 deletion src/torchrunx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,45 @@
from __future__ import annotations

from .launcher import launch
from .utils import slurm_hosts, slurm_workers


def slurm_hosts() -> list[str]:
    """Retrieves hostnames of Slurm-allocated nodes.

    Expands the ``SLURM_JOB_NODELIST`` environment variable into individual
    hostnames by running ``scontrol show hostnames``.

    :return: Hostnames of nodes in current Slurm allocation
    :rtype: list[str]
    :raises AssertionError: If ``SLURM_JOB_ID`` is not set in the environment
    :raises KeyError: If ``SLURM_JOB_NODELIST`` is not set in the environment
    :raises subprocess.CalledProcessError: If ``scontrol`` exits with a non-zero status
    """
    import os
    import subprocess

    # TODO: sanity check SLURM variables, commands
    # Raised explicitly (rather than via `assert`) so the check is not stripped
    # under `python -O`; the exception type is kept for backward compatibility.
    if "SLURM_JOB_ID" not in os.environ:
        raise AssertionError("SLURM_JOB_ID is not set in the environment")

    output = subprocess.check_output(
        ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
    )
    # scontrol prints one hostname per line
    return output.decode().strip().split("\n")


def slurm_workers() -> int:
    """
    | Determines number of workers per node in current Slurm allocation using
    | the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environment variables.

    ``SLURM_JOB_GPUS`` takes precedence when present: the result is the number of
    comma-separated entries it contains. Otherwise ``SLURM_CPUS_ON_NODE`` is
    parsed as an integer.

    :return: The implied number of workers per node
    :rtype: int
    :raises AssertionError: If ``SLURM_JOB_ID`` is not set in the environment
    :raises KeyError: If neither ``SLURM_JOB_GPUS`` nor ``SLURM_CPUS_ON_NODE`` is set
    :raises ValueError: If ``SLURM_CPUS_ON_NODE`` is not an integer
    """
    import os

    # TODO: sanity check SLURM variables, commands
    # Raised explicitly (rather than via `assert`) so the check is not stripped
    # under `python -O`; the exception type is kept for backward compatibility.
    if "SLURM_JOB_ID" not in os.environ:
        raise AssertionError("SLURM_JOB_ID is not set in the environment")

    if "SLURM_JOB_GPUS" in os.environ:
        # TODO: is it possible to allocate uneven GPUs across nodes?
        return len(os.environ["SLURM_JOB_GPUS"].split(","))

    # TODO: should we assume that we plan to do one worker per CPU?
    return int(os.environ["SLURM_CPUS_ON_NODE"])


__all__ = ["launch", "slurm_hosts", "slurm_workers"]
4 changes: 2 additions & 2 deletions src/torchrunx/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from argparse import ArgumentParser

from . import agent
from .agent import main
from .utils import LauncherAgentGroup

if __name__ == "__main__":
Expand All @@ -18,4 +18,4 @@
rank=args.rank,
)

agent.main(launcher_agent_group)
main(launcher_agent_group)
26 changes: 25 additions & 1 deletion src/torchrunx/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import socket
import sys
from dataclasses import dataclass
from typing import Callable, Literal

Expand All @@ -17,7 +18,6 @@
AgentStatus,
LauncherAgentGroup,
LauncherPayload,
WorkerTee,
get_open_port,
)

Expand All @@ -42,6 +42,30 @@ def from_bytes(cls, serialized: bytes) -> Self:
return cloudpickle.loads(serialized)


class WorkerTee(object):
    """Duplicates everything written to ``sys.stdout`` into a file.

    On construction, opens ``name`` with ``mode`` and installs itself as
    ``sys.stdout``. Each write goes to both the file and the previous stdout.
    The previous stdout is restored (and the file closed) when the context
    manager exits or the object is garbage-collected, whichever comes first.
    """

    def __init__(self, name: os.PathLike | str, mode: str):
        # Mark as torn down first so `__del__` is a no-op if `open` raises below.
        self._torn_down = True
        self.file = open(name, mode)
        self.stdout = sys.stdout
        sys.stdout = self
        self._torn_down = False

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self._teardown()

    def __del__(self):
        self._teardown()

    def _teardown(self):
        # Restore the previous stdout and close the file, at most once.
        # Without this guard, `__del__` running after `__exit__` would reset
        # `sys.stdout` a second time and clobber any redirect installed since.
        if getattr(self, "_torn_down", True):
            return
        self._torn_down = True
        sys.stdout = self.stdout
        self.file.close()

    def write(self, data):
        self.file.write(data)
        self.stdout.write(data)

    def flush(self):
        # Flush both sinks so the file and the wrapped stdout stay in step.
        self.file.flush()
        self.stdout.flush()


def entrypoint(serialized_worker_args: bytes):
worker_args = WorkerArgs.from_bytes(serialized_worker_args)

Expand Down
66 changes: 59 additions & 7 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,79 @@

import datetime
import fnmatch
import io
import ipaddress
import itertools
import os
import socket
import subprocess
import sys
import time
from collections import ChainMap
from functools import partial
from multiprocessing import Process
from pathlib import Path
from typing import Any, Callable, Literal

import fabric
import torch.distributed as dist

from .utils import (
AgentPayload,
AgentStatus,
LauncherAgentGroup,
LauncherPayload,
execute_command,
get_open_port,
monitor_log,
)


def is_localhost(hostname_or_ip: str) -> bool:
    """Reports whether ``hostname_or_ip`` refers to the machine running this code.

    The argument is parsed as an IP address; if that fails it is resolved with
    ``socket.gethostbyname`` first. A loopback address is immediately local.
    Otherwise the addresses resolved for the target are compared against those
    resolved for this machine's own hostname, and any overlap counts as local.
    """
    try:
        address = ipaddress.ip_address(hostname_or_ip)
    except ValueError:
        # not an IP literal: resolve the hostname, then parse the result
        address = ipaddress.ip_address(socket.gethostbyname(hostname_or_ip))

    # loopback addresses are designated to send to self
    if address.is_loopback:
        return True

    def resolved_addresses(host: str) -> set[str]:
        # element [4][0] of each getaddrinfo entry is the address string
        return {info[4][0] for info in socket.getaddrinfo(host, None)}

    target_addresses = resolved_addresses(str(address))
    local_addresses = resolved_addresses(socket.gethostname())
    return not target_addresses.isdisjoint(local_addresses)


def execute_command(
    command: str,
    hostname: str,
    ssh_config_file: str | os.PathLike | None = None,
    outfile: str | os.PathLike | None = None,
) -> None:
    """Starts ``command`` on ``hostname`` without waiting for it to finish.

    If ``hostname`` is the local machine, the command is spawned through the
    local shell. Otherwise it is run over a ``fabric`` SSH connection as a
    backgrounded shell command.

    :param command: Shell command line to execute
    :param hostname: Host on which to run the command
    :param ssh_config_file: SSH config path passed to ``fabric.Config``
        (only used for remote hosts)
    :param outfile: File receiving both stdout and stderr; output is discarded
        when ``None``. Truncated for local runs, appended to for remote runs.
    """
    # TODO: permit different stderr / stdout
    if is_localhost(hostname):
        if outfile is None:
            devnull = subprocess.DEVNULL
            subprocess.Popen(command, shell=True, stdout=devnull, stderr=devnull)
        else:
            # The child gets its own copy of the descriptor, so the parent's
            # handle is closed right after spawning instead of being leaked.
            with open(outfile, "w") as log:
                subprocess.Popen(command, shell=True, stdout=log, stderr=log)
    else:
        remote_outfile = "/dev/null" if outfile is None else outfile
        with fabric.Connection(
            host=hostname, config=fabric.Config(runtime_ssh_path=ssh_config_file)
        ) as conn:
            conn.run(f"{command} >> {remote_outfile} 2>&1 &", asynchronous=True)


def monitor_log(log_file: Path):
    """Echoes the contents of ``log_file`` to stdout as the file grows.

    Creates the file if it does not exist, prints whatever it already contains,
    then polls every 0.1 seconds for newly appended text. This function never
    returns on its own; presumably the caller runs it in a separate process and
    terminates it (confirm against the call site).

    :param log_file: Path of the log file to follow
    """
    log_file.touch()
    with open(log_file, "r") as f:
        while True:
            # Each read continues from where the previous one stopped, so no
            # explicit seek is needed (seeking to the end could skip text
            # appended between the read and the seek).
            chunk = f.read()
            if chunk:
                # Chunks are echoed verbatim: no extra newline is added.
                print(chunk, end="", flush=True)
            time.sleep(0.1)


def launch(
func: Callable,
func_kwargs: dict[str, Any],
Expand Down Expand Up @@ -83,16 +133,18 @@ def launch(

# launch command

env_export_string = ""
env_exports = []
for k, v in os.environ.items():
for e in env_vars:
if any(fnmatch.fnmatch(k, e)):
env_exports.append(f"{k}={v}")
if any(fnmatch.fnmatch(k, e) for e in env_vars):
env_exports.append(f"{k}={v}")

env_export_string = ""
if len(env_exports) > 0:
env_export_string = f"export {' '.join(env_exports)} && "

env_file_string = f"source {env_file} && " if env_file is not None else ""
env_file_string = ""
if env_file is not None:
env_file_string = f"source {env_file} && "

launcher_hostname = socket.getfqdn()
launcher_port = get_open_port()
Expand Down
Loading

0 comments on commit cfe9f42

Please sign in to comment.