
Avoid init_ddp for inference #12011

Open · wants to merge 6 commits into main
Changes from 2 commits
nemo/lightning/megatron_parallel.py (5 changes: 4 additions & 1 deletion)
@@ -42,6 +42,7 @@
 
 import torch
 import torch.distributed
+from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities import move_data_to_device
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel as McoreDDP
@@ -564,7 +565,9 @@ def init_model_parallel(self):
         if self.convert_module_fn:
             self.apply_convert_module_fn()
 
-        self.init_ddp()
+        # Skip init_ddp for inference (i.e. testing), as it can lead to OOM.
+        if not self.trainer.state.fn == TrainerFn.TESTING:
+            self.init_ddp()
 
     def apply_convert_module_fn(self):
         for i in range(len(self)):
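For context, a minimal runnable sketch of how the new guard behaves (not part of the PR; ParallelSketch and the SimpleNamespace trainers are hypothetical stand-ins for MegatronParallel and a real Lightning Trainer). During trainer.test() Lightning sets trainer.state.fn to TrainerFn.TESTING, so DDP wrapping is skipped; during trainer.fit() it is TrainerFn.FITTING, so init_ddp still runs.

from types import SimpleNamespace

from lightning.pytorch.trainer.states import TrainerFn


class ParallelSketch:
    """Hypothetical stand-in for MegatronParallel, illustrating only the guard."""

    def __init__(self, trainer):
        self.trainer = trainer
        self.ddp_initialized = False

    def init_ddp(self):
        # The real method wraps modules in DistributedDataParallel, which
        # allocates gradient state that is wasted (and can OOM) at inference.
        self.ddp_initialized = True

    def init_model_parallel(self):
        # Same guard as in the diff: skip DDP wrapping in the TESTING stage.
        if not self.trainer.state.fn == TrainerFn.TESTING:
            self.init_ddp()


# Fake trainers exposing only state.fn, mimicking what Lightning sets.
testing = ParallelSketch(SimpleNamespace(state=SimpleNamespace(fn=TrainerFn.TESTING)))
fitting = ParallelSketch(SimpleNamespace(state=SimpleNamespace(fn=TrainerFn.FITTING)))

testing.init_model_parallel()
fitting.init_model_parallel()
assert testing.ddp_initialized is False  # inference path: no DDP buffers
assert fitting.ddp_initialized is True   # training path: DDP as before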
nemo/lightning/pytorch/strategies/megatron_strategy.py (6 changes: 4 additions & 2 deletions)
@@ -477,11 +477,13 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             convert_module_fn=convert_module_fn,
         )
 
+        # Assign trainer to megatron_parallel before init_model_parallel, as it is required to check the
+        # stage of the trainer (TESTING or not) in init_model_parallel.
+        self.megatron_parallel.trainer = trainer
+
         if self._init_model_parallel:
             self.init_model_parallel()
 
-        self.megatron_parallel.trainer = trainer
-
         # check signature-def of self.model.configure_optimizers to check if there's an optional arg: megatron_parallel
         sig = inspect.signature(self.model.configure_optimizers)
         if "megatron_parallel" in sig.parameters:
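Continuing the hypothetical sketch above, a short illustration of why the reordering matters: init_model_parallel() now reads trainer.state.fn inside megatron_parallel, so the trainer has to be attached before that call; the previous ordering (assignment after init_model_parallel) would leave the guard with no trainer to inspect. StrategySketch is a hypothetical stand-in for the relevant part of MegatronStrategy and reuses ParallelSketch, SimpleNamespace, and TrainerFn from the previous sketch.

class StrategySketch:
    """Hypothetical stand-in for MegatronStrategy.setup_megatron_parallel."""

    def __init__(self, megatron_parallel, init_model_parallel=True):
        self.megatron_parallel = megatron_parallel
        self._init_model_parallel = init_model_parallel

    def setup_megatron_parallel(self, trainer):
        # Attach the trainer before init_model_parallel(): the guard added in
        # megatron_parallel.py reads trainer.state.fn during that call.
        self.megatron_parallel.trainer = trainer

        if self._init_model_parallel:
            self.megatron_parallel.init_model_parallel()


# The parallel object starts without a usable trainer; the strategy attaches
# one just in time, so the TESTING check works and DDP wrapping is skipped.
strategy = StrategySketch(ParallelSketch(trainer=None))
strategy.setup_megatron_parallel(SimpleNamespace(state=SimpleNamespace(fn=TrainerFn.TESTING)))
assert strategy.megatron_parallel.ddp_initialized is False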