Skip to content

Commit

Permalink
[bug] Fix the version check bug in colossalai run when generating the…
Browse files Browse the repository at this point in the history
… cmd. (hpcaitech#4713)

* Fix the version check bug in colossalai run when generating the cmd.

* polish code
  • Loading branch information
littsk authored Sep 22, 2023
1 parent 3e05c07 commit 1e0e080
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions colossalai/cli/launcher/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def _arg_dict_to_list(arg_dict):
torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1

if torch_version.minor < 9:
if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9
cmd = [
sys.executable,
"-m",
Expand All @@ -177,7 +178,8 @@ def _arg_dict_to_list(arg_dict):
value = extra_launch_args.pop(key)
default_torchrun_rdzv_args[key] = value

if torch_version.minor < 10:
if torch_version.major == 1 and torch_version.minor == 9:
# torch distributed launch cmd with torch == 1.9
cmd = [
sys.executable,
"-m",
Expand All @@ -187,6 +189,7 @@ def _arg_dict_to_list(arg_dict):
f"--node_rank={node_rank}",
]
else:
# torch distributed launch cmd with torch > 1.9
cmd = [
"torchrun",
f"--nproc_per_node={nproc_per_node}",
Expand Down

0 comments on commit 1e0e080

Please sign in to comment.