Formatting and Device-Setting
Sathwik Yanamaddi committed Mar 12, 2024
1 parent e732f07 · commit 5945e9e
Showing 1 changed file with 6 additions and 15 deletions.
axonn/communication.py (21 changes: 6 additions & 15 deletions)
@@ -67,15 +67,6 @@ def __init__(
         else:
             self.backend = "nccl"

-        # infer gpus per node if not provided
-        self.gpus_per_node = (
-            gpus_per_node if gpus_per_node is not None else torch.cuda.device_count()
-        )
-
-        if config.device == "cuda" and gpus_per_node:
-            self.local_rank = self.world_rank % gpus_per_node
-            torch.cuda.set_device(self.local_rank)
-
         if not torch.distributed.is_initialized():
             assert MPI4PY, "either install mpi4py and launch via mpirun/srun"
             "or initialize torch.distributed outside axonn"
@@ -100,13 +91,13 @@ def __init__(
         self.G_intra_d = G_intra_d

         # infer gpus per node if not provided
-        # self.gpus_per_node = (
-        #     gpus_per_node if gpus_per_node is not None else torch.cuda.device_count()
-        # )
+        self.gpus_per_node = (
+            gpus_per_node if gpus_per_node is not None else torch.cuda.device_count()
+        )

-        # if config.device == "cuda" and gpus_per_node:
-        #     self.local_rank = self.world_rank % gpus_per_node
-        #     torch.cuda.set_device(self.local_rank)
+        if config.device == "cuda" and gpus_per_node:
+            self.local_rank = self.world_rank % gpus_per_node
+            torch.cuda.set_device(self.local_rank)
         self.intra_layer_parallel_rank = self.world_rank % G_intra
         self.intra_layer_column_parallel_rank = (
             self.intra_layer_parallel_rank % G_intra_c
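For context (not part of the commit itself): the net effect of this change is that the per-node GPU count is inferred and the CUDA device is pinned after the parallelism configuration has been read, rather than before torch.distributed is initialized. Below is a minimal, self-contained sketch of the same device-selection pattern. It assumes a launcher that exports the process's global rank in the RANK environment variable (as torchrun does); the variable names here are illustrative and are not AxoNN's API.

import os
import torch

# Illustrative sketch of the device-setting logic that this commit relocates.
# Assumes the launcher exports the global rank as RANK (torchrun does this);
# AxoNN itself obtains the world rank via MPI or torch.distributed instead.
world_rank = int(os.environ.get("RANK", "0"))

# Infer GPUs per node when the caller does not supply it explicitly.
gpus_per_node = None  # stand-in for a user-provided value
if gpus_per_node is None:
    gpus_per_node = torch.cuda.device_count()

# Local rank = global rank modulo GPUs per node; pin this process to that GPU.
if torch.cuda.is_available() and gpus_per_node:
    local_rank = world_rank % gpus_per_node
    torch.cuda.set_device(local_rank)

In the commit, this block is simply moved later in __init__ (and uncommented at the new location) so that it runs after the intra-layer parallel group sizes are set.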
