Guard TE import & clean up unused imports
Signed-off-by: Valerie Sarge <[email protected]>
vysarge committed Feb 5, 2025
1 parent c4e1e32 commit 3b9525a
Showing 3 changed files with 6 additions and 7 deletions.
7 changes: 1 addition & 6 deletions nemo/collections/llm/gpt/data/mlperf_govreport.py
@@ -14,18 +14,13 @@
 
 import shutil
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from datasets import DatasetDict, load_dataset
 import numpy as np
 
-import torch
-from torch import nn
 
 from nemo.collections.llm.gpt.data.core import get_dataset_root
 from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
-from nemo.collections.llm.utils import Config
-from nemo.lightning import OptimizerModule
-from nemo.lightning.io.mixin import IOMixin
 from nemo.utils import logging
 
4 changes: 4 additions & 0 deletions nemo/collections/llm/gpt/model/llama.py
@@ -267,6 +267,10 @@ def __init__(
     ):
         super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform)
 
+        from nemo.utils.import_utils import safe_import
+        _, HAVE_TE = safe_import("transformer_engine")
+        assert HAVE_TE, "TransformerEngine is required for MLPerfLoRALlamaModel."
+
     def configure_model(self):
         # Apply context managers to reduce memory by (1) avoiding unnecessary gradients
         # and (2) requesting that TE initialize params as FP8.
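The guard relies on NeMo's safe_import helper; per the usage in the diff, it returns the imported module together with a boolean success flag, letting callers fail fast with a clear message when an optional dependency is missing. Below is a minimal sketch of the same pattern, where guarded_import is a hypothetical stand-in for nemo.utils.import_utils.safe_import:

import importlib
from types import ModuleType
from typing import Tuple

def guarded_import(name: str) -> Tuple[ModuleType, bool]:
    """Hypothetical stand-in for nemo.utils.import_utils.safe_import:
    returns (module, True) on success, (empty placeholder, False) on failure."""
    try:
        return importlib.import_module(name), True
    except ImportError:
        return ModuleType(name), False

# Mirrors the guard added above: raise a clear error at model construction
# time rather than an obscure AttributeError later in configure_model.
_, HAVE_TE = guarded_import("transformer_engine")
assert HAVE_TE, "TransformerEngine is required for MLPerfLoRALlamaModel."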
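The configure_model comment in the context lines above describes wrapping model construction in context managers. The method body is not part of this diff, but a sketch of what such wrapping can look like, assuming TransformerEngine's te.pytorch.fp8_model_init context manager:

import torch
import transformer_engine.pytorch as te

def build_model_under_context_managers(make_model):
    # (1) no_grad skips autograd bookkeeping while parameters are created;
    # (2) fp8_model_init requests that TE modules allocate params as FP8.
    with torch.no_grad(), te.fp8_model_init(enabled=True):
        return make_model()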
2 changes: 1 addition & 1 deletion scripts/llm/performance/mlperf_lora_llama2_70b.py
@@ -26,7 +26,7 @@
 
 from nemo import lightning as nl
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
-from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule
+from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
 from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
 
 from utils import (