Skip to content

Commit

Permalink
misc edits
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvkh committed Jul 12, 2024
1 parent b3a27fb commit b5241af
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion examples/submitit_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def main():
output_dir = "output",
do_train = True,
per_device_train_batch_size = 16,
max_steps = 100,
max_steps = 20,
)

trainer = Trainer(
Expand Down
31 changes: 19 additions & 12 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,14 @@ def launch(

# launch command

env_export_string = " ".join(
f'{k}="{v}"' for k, v in os.environ.items() if any(fnmatch.fnmatch(k, e) for e in env_vars)
)
if env_export_string != "":
env_export_string = f"export {env_export_string} && "
env_export_string = ""
env_exports = []
for k, v in os.environ.items():
for e in env_vars:
if any(fnmatch.fnmatch(k, e)):
env_exports.append(f"{k}={v}")
if len(env_exports) > 0:
env_export_string = f"export {' '.join(env_exports)} && "

env_file_string = f"source {env_file} && " if env_file is not None else ""

Expand All @@ -108,7 +111,7 @@ def launch(

log_dir = Path(log_dir)
log_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.datetime.now().strftime("%y-%m-%d-%H%M%S")
timestamp = datetime.datetime.now().isoformat(timespec="seconds")
agent_log_files = [log_dir / f"{timestamp}_{hostname}.log" for hostname in hostnames]

# start process to read from agent 0 log
Expand Down Expand Up @@ -136,11 +139,15 @@ def launch(

# build and sync payloads between launcher and agents

cumulative_workers = [0] + list(itertools.accumulate(workers_per_host))
worker_world_size = cumulative_workers[-1]
worker_global_ranks = [ # list of worker ranks per host
list(range(cumulative_workers[n], cumulative_workers[n + 1])) for n in range(num_hosts)
]
_cumulative_workers = [0] + list(itertools.accumulate(workers_per_host))

worker_world_size = _cumulative_workers[-1]

worker_global_ranks = [] # list of worker ranks per host
for n in range(num_hosts):
host_ranks = range(_cumulative_workers[n], _cumulative_workers[n + 1])
worker_global_ranks.append(list(host_ranks))

worker_log_files = [
[
log_dir / f"{timestamp}_{hostname}_{local_rank}.log"
Expand Down Expand Up @@ -183,7 +190,7 @@ def launch(
e += f"{v.message['extraInfo']['py_callstack']}\n\n"
raise RuntimeError(e)
except:
# kill all agents
# cleanup: SIGTERM all agents
for agent_pid, agent_hostname in zip(agent_pids, hostnames):
execute_command(
command=f"kill {agent_pid}",
Expand Down

0 comments on commit b5241af

Please sign in to comment.