Add performance-optimized example for llama2 70b LoRA #12055

Open · wants to merge 16 commits into base: main
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/__init__.py
@@ -17,6 +17,7 @@
from nemo.collections.llm.gpt.data.dolly import DollyDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule
from nemo.collections.llm.gpt.data.mlperf_govreport import MLPerfGovReportDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule
from nemo.collections.llm.gpt.data.retrieval import CustomRetrievalDataModule
@@ -28,6 +29,7 @@
"DollyDataModule",
"FineTuningDataModule",
"HFDatasetDataModule",
"MLPerfGovReportDataModule",
"MockDataModule",
"PreTrainingDataModule",
"build_pretraining_datamodule",
189 changes: 189 additions & 0 deletions nemo/collections/llm/gpt/data/mlperf_govreport.py
@@ -0,0 +1,189 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import numpy as np
from datasets import DatasetDict, load_dataset

from nemo.collections.llm.gpt.data.core import get_dataset_root
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


class MLPerfGovReportDataModule(FineTuningDataModule, IOMixin):
"""
A data module for fine-tuning on the govreport dataset as preprocessed for MLPerf;
see https://huggingface.co/datasets/regisss/scrolls_gov_report_preprocessed_mlperf_2

Inherits from `FineTuningDataModule` and handles data download, splitting, and
saving in a format ready for training.

Args:
force_redownload (bool, optional): Whether to force re-download the dataset even
if it exists locally. Defaults to False.
delete_raw (bool, optional): Whether to delete the raw downloaded dataset after
preprocessing. Defaults to True.
See `FineTuningDataModule` for the remaining arguments.
"""

def __init__(
self,
seq_length: int = 2048,
tokenizer: Optional["TokenizerSpec"] = None,
micro_batch_size: int = 4,
global_batch_size: int = 8,
rampup_batch_size: Optional[List[int]] = None,
force_redownload: bool = False,
delete_raw: bool = True,
seed: int = 1234,
memmap_workers: int = 1,
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
dataset_kwargs: Optional[Dict[str, Any]] = None,
):
self.force_redownload = force_redownload
self.delete_raw = delete_raw

super().__init__(
dataset_root=get_dataset_root("govreport"),
seq_length=seq_length,
tokenizer=tokenizer,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
rampup_batch_size=rampup_batch_size,
seed=seed,
memmap_workers=memmap_workers,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
packed_sequence_specs=packed_sequence_specs,
dataset_kwargs=dataset_kwargs,
)

if self.packed_sequence_size != self.seq_length:
raise ValueError(
f"{self.__class__.__name__} requires `packed_sequence_specs.packed_sequence_size` to be nonzero "
f"and equal to `seq_length`. Instead got packed_sequence_size = {self.packed_sequence_size} "
f"and seq_length = {self.seq_length}"
)

def prepare_data(self) -> None:
# skip download and preprocessing if the processed train file already exists (unless force_redownload is set)
if not self.train_path.exists() or self.force_redownload:
dset = self._download_data()
self._preprocess_and_split_data(dset)
super().prepare_data()

def _download_data(self):
logging.info(f"Downloading {self.__class__.__name__}...")
return load_dataset(
"regisss/scrolls_gov_report_preprocessed_mlperf_2",
cache_dir=str(self.dataset_root),
download_mode="force_redownload" if self.force_redownload else None,
)

def _preprocess_and_split_data(
self, dset: DatasetDict, split_val_from_train: bool = True, val_proportion: float = 0.05
):
"""Preprocesses and splits the downloaded dataset into training, validation, and test sets.

Args:
dset (DatasetDict): The downloaded dataset object.
split_val_from_train (bool, optional): Whether to split the validation set from the training set.
If False, the validation set is split from the test set. Defaults to True.
val_proportion (float, optional): The proportion of the training or test set to be used for
the validation split.
Defaults to 0.05.
"""
logging.info(f"Preprocessing {self.__class__.__name__} to npy format and splitting...")
save_splits = {}
train_set = dset.get('train')
val_set = dset.get('validation')

if split_val_from_train:
split_dataset = train_set.train_test_split(test_size=val_proportion, seed=self.seed)
save_splits['training'] = split_dataset['train']
save_splits['validation'] = split_dataset['test']
save_splits['test'] = val_set
else:
split_dataset = val_set.train_test_split(test_size=val_proportion, seed=self.seed)
save_splits['training'] = train_set
save_splits['validation'] = split_dataset['test']
save_splits['test'] = split_dataset['train']

for split_name, dataset in save_splits.items():
output_file = self.dataset_root / f"{split_name}.npy"
processed_data = [
{
"input_ids": example["input_ids"],
"loss_mask": [int(x != -100) for x in example["labels"]],
"seq_start_id": [0],
}
for example in dataset
]
np.save(output_file, processed_data)

logging.info(f"{split_name} split saved to {output_file}")

if self.delete_raw:
for p in self.dataset_root.iterdir():
if p.is_dir():
shutil.rmtree(p)
elif '.npy' not in str(p.name):
p.unlink()

@property
def train_path(self) -> Path:
"""Path to training dataset file"""
return self.dataset_root / "training.npy"

@property
def validation_path(self) -> Path:
"""Path to validation dataset file"""
return self.dataset_root / "validation.npy"

@property
def test_path(self) -> Path:
"""Path to test dataset file"""
return self.dataset_root / "test.npy"

@property
def default_pack_path(self) -> Optional[Path]:
"""Not used: the preprocessed dataset is already stored in packed format, so no separate pack path is created."""
return None

@property
def pack_metadata(self) -> Optional[Path]:
"""Not used: no packing metadata is generated for this dataset."""
return None

@property
def train_path_packed(self) -> Path:
"""Path to training dataset file for packed sequence. The file path contains a reference to the
tokenizer/model name since packed sequence dataset consists of tokenized indices."""
return self.train_path

@property
def validation_path_packed(self) -> Path:
"""Path to validation dataset file for packed sequence. The file path contains a reference to the
tokenizer/model name since packed sequence dataset consists of tokenized indices."""
return self.validation_path
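
A minimal usage sketch for the new data module (illustrative, not part of the diff; the batch sizes and sequence length are placeholders, and `PackedSequenceSpecs(packed_sequence_size=...)` is assumed from its use in `__init__` above):

```python
from nemo.collections.llm.gpt.data import MLPerfGovReportDataModule
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs

seq_length = 8192  # placeholder; must equal packed_sequence_size (enforced in __init__)

data = MLPerfGovReportDataModule(
    seq_length=seq_length,
    micro_batch_size=1,
    global_batch_size=8,
    # Each saved record already carries input_ids, loss_mask, and seq_start_id,
    # i.e. the packed layout, so packed_sequence_size must match seq_length.
    packed_sequence_specs=PackedSequenceSpecs(packed_sequence_size=seq_length),
)
# prepare_data() downloads the Hugging Face dataset on first use and writes
# training.npy / validation.npy / test.npy under the "govreport" dataset root.
```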
9 changes: 9 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -63,6 +63,7 @@
Llama32Config3B,
LlamaConfig,
LlamaModel,
MLPerfLoRALlamaModel,
)
from nemo.collections.llm.gpt.model.llama_embedding import Llama32EmbeddingConfig1B, LlamaEmbeddingModel
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B
@@ -112,9 +113,16 @@

__all__ = [
"GPTConfig",
"GPTConfig5B",
"GPTConfig7B",
"GPTConfig20B",
"GPTConfig40B",
"GPTConfig126M",
"GPTConfig175B",
"GPTModel",
"MistralConfig7B",
"MistralModel",
"MistralNeMoConfig12B",
"MixtralConfig8x3B",
"MixtralConfig8x7B",
"MixtralConfig8x22B",
@@ -167,6 +175,7 @@
"Gemma2Config9B",
"Gemma2Model",
"LlamaModel",
"MLPerfLoRALlamaModel",
"Baichuan2Config",
"Baichuan2Config7B",
"Baichuan2Model",
73 changes: 57 additions & 16 deletions nemo/collections/llm/gpt/model/llama.py
@@ -33,7 +33,7 @@

if TYPE_CHECKING:
from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
from peft import PeftConfig
from peft import AutoPeftModelForCausalLM, PeftConfig
from transformers import LlamaConfig as HFLlamaConfig
from transformers import LlamaForCausalLM

@@ -252,6 +252,37 @@ def __init__(
super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform)


class MLPerfLoRALlamaModel(LlamaModel):
"""
This class wraps LlamaModel and adds context managers around configure_model to reduce memory consumption.

Changes made here are experimental; proceed with caution.
"""

def __init__(
self,
config: Annotated[Optional[LlamaConfig], Config[LlamaConfig]] = None,
optim: Optional[OptimizerModule] = None,
tokenizer: Optional["TokenizerSpec"] = None,
model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
):
super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform)

from nemo.utils.import_utils import safe_import

_, HAVE_TE = safe_import("transformer_engine")
assert HAVE_TE, "TransformerEngine is required for MLPerfLoRALlamaModel."

def configure_model(self):
# Apply context managers to reduce memory by (1) avoiding unnecessary gradients
# and (2) requesting that TE initialize params as FP8. See:
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.fp8_model_init
import transformer_engine.pytorch as te

with torch.no_grad(), te.fp8_model_init():
super().configure_model()


@io.model_importer(LlamaModel, "hf")
class HFLlamaImporter(io.ModelConnector["LlamaForCausalLM", LlamaModel]):
def init(self) -> LlamaModel:
@@ -434,31 +465,39 @@ def convert_state(self, source, target):
pn = "decoder.layers."
ph = "base_model.model.model.layers."

# linear_proj and linear_fc2 prefixes
p_proj = "self_attention.linear_proj.adapter"
p_fc2 = "mlp.linear_fc2.adapter"

# linear_qkv and linear_fc1 prefixes
p_qkv = "self_attention.linear_qkv.adapter"
p_fc1 = "mlp.linear_fc1.adapter"

mapping = {
# linear_proj for both canonical and performant lora
f"{pn}*.self_attention.linear_proj.adapter.linear_in.weight": f"{ph}*.self_attn.o_proj.lora_A.default.weight",
f"{pn}*.self_attention.linear_proj.adapter.linear_out.weight": f"{ph}*.self_attn.o_proj.lora_B.default.weight",
f"{pn}*.{p_proj}.linear_in.weight": f"{ph}*.self_attn.o_proj.lora_A.default.weight",
f"{pn}*.{p_proj}.linear_out.weight": f"{ph}*.self_attn.o_proj.lora_B.default.weight",
# linear_fc2 for both canonical and performant lora
f"{pn}*.mlp.linear_fc2.adapter.linear_in.weight": f"{ph}*.mlp.down_proj.lora_A.default.weight",
f"{pn}*.mlp.linear_fc2.adapter.linear_out.weight": f"{ph}*.mlp.down_proj.lora_B.default.weight",
f"{pn}*.{p_fc2}.linear_in.weight": f"{ph}*.mlp.down_proj.lora_A.default.weight",
f"{pn}*.{p_fc2}.linear_out.weight": f"{ph}*.mlp.down_proj.lora_B.default.weight",
}
transforms = []

if isinstance(self.peft_obj, CanonicalLoRA):
mapping.update(
{
# linear_qkv for canonical lora
f"{pn}*.self_attention.linear_qkv.adapter.adapter_q.linear_in.weight": f"{ph}*.self_attn.q_proj.lora_A.default.weight",
f"{pn}*.self_attention.linear_qkv.adapter.adapter_q.linear_out.weight": f"{ph}*.self_attn.q_proj.lora_B.default.weight",
f"{pn}*.self_attention.linear_qkv.adapter.adapter_k.linear_in.weight": f"{ph}*.self_attn.k_proj.lora_A.default.weight",
f"{pn}*.self_attention.linear_qkv.adapter.adapter_k.linear_out.weight": f"{ph}*.self_attn.k_proj.lora_B.default.weight",
f"{pn}*.self_attention.linear_qkv.adapter.adapter_v.linear_in.weight": f"{ph}*.self_attn.v_proj.lora_A.default.weight",
f"{pn}*.self_attention.linear_qkv.adapter.adapter_v.linear_out.weight": f"{ph}*.self_attn.v_proj.lora_B.default.weight",
f"{pn}*.{p_qkv}.adapter_q.linear_in.weight": f"{ph}*.self_attn.q_proj.lora_A.default.weight",
f"{pn}*.{p_qkv}.adapter_q.linear_out.weight": f"{ph}*.self_attn.q_proj.lora_B.default.weight",
f"{pn}*.{p_qkv}.adapter_k.linear_in.weight": f"{ph}*.self_attn.k_proj.lora_A.default.weight",
f"{pn}*.{p_qkv}.adapter_k.linear_out.weight": f"{ph}*.self_attn.k_proj.lora_B.default.weight",
f"{pn}*.{p_qkv}.adapter_v.linear_in.weight": f"{ph}*.self_attn.v_proj.lora_A.default.weight",
f"{pn}*.{p_qkv}.adapter_v.linear_out.weight": f"{ph}*.self_attn.v_proj.lora_B.default.weight",
# linear_fc1 for canonical lora
f"{pn}*.mlp.linear_fc1.adapter.adapter_up.linear_in.weight": f"{ph}*.mlp.up_proj.lora_A.default.weight",
f"{pn}*.mlp.linear_fc1.adapter.adapter_up.linear_out.weight": f"{ph}*.mlp.up_proj.lora_B.default.weight",
f"{pn}*.mlp.linear_fc1.adapter.adapter_gate.linear_in.weight": f"{ph}*.mlp.gate_proj.lora_A.default.weight",
f"{pn}*.mlp.linear_fc1.adapter.adapter_gate.linear_out.weight": f"{ph}*.mlp.gate_proj.lora_B.default.weight",
f"{pn}*.{p_fc1}.adapter_up.linear_in.weight": f"{ph}*.mlp.up_proj.lora_A.default.weight",
f"{pn}*.{p_fc1}.adapter_up.linear_out.weight": f"{ph}*.mlp.up_proj.lora_B.default.weight",
f"{pn}*.{p_fc1}.adapter_gate.linear_in.weight": f"{ph}*.mlp.gate_proj.lora_A.default.weight",
f"{pn}*.{p_fc1}.adapter_gate.linear_out.weight": f"{ph}*.mlp.gate_proj.lora_B.default.weight",
}
)
else:
@@ -669,7 +708,8 @@ def apply_rope_scaling(
old_context_len: int = 8192,
):
logging.info(
f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, high_freq_factor={high_freq_factor}, old_context_len={old_context_len}."
f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, "
f"high_freq_factor={high_freq_factor}, old_context_len={old_context_len}."
)

low_freq_wavelen = old_context_len / low_freq_factor
@@ -705,4 +745,5 @@ def apply_rope_scaling(
"CodeLlamaConfig34B",
"CodeLlamaConfig70B",
"LlamaModel",
"MLPerfLoRALlamaModel",
]
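
For context, a standalone sketch of the memory-saving pattern that `MLPerfLoRALlamaModel.configure_model` applies above. This is illustrative only: it assumes TransformerEngine is installed and FP8-capable hardware is available, and uses a single `te.Linear` layer as a stand-in for the full model.

```python
import torch
import transformer_engine.pytorch as te


def build_lowmem(build_fn):
    # Skip autograd bookkeeping during construction and ask TE to allocate
    # parameters of TE modules directly in FP8, mirroring configure_model above.
    with torch.no_grad(), te.fp8_model_init():
        return build_fn()


# Illustration only: one TE linear layer whose weights are created in FP8.
layer = build_lowmem(lambda: te.Linear(1024, 1024, bias=False))
```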
@@ -31,6 +31,7 @@ class PipelineOverlapCfg(TPOverlapCfg):
num_splits: int
set_sm_margin: bool
fp8_buf: bool = False
atomic_gemm: bool = False
method: str = 'pipeline'


@@ -41,7 +42,10 @@ class RingExchangeOverlapCfg(TPOverlapCfg):
aggregate: bool = False
method: str = 'ring_exchange'
num_sm: int = 1
cga_size: int = 1
set_sm_margin: bool = False
fp8_buf: bool = False
atomic_gemm: bool = False


@dataclass
4 changes: 4 additions & 0 deletions nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py
@@ -296,6 +296,10 @@ def _init_te_userbuffers(self, model_parallel_cfg: ModelParallelConfig):
else:
# ub_cfgs is a dataclass, however TE needs a dict, so convert here
self.tp_comm_overlap_cfg = asdict(self.tp_comm_overlap_cfg)
# remove keys with None values from dictionary to match TE's expectations
self.tp_comm_overlap_cfg = {
key: value for key, value in self.tp_comm_overlap_cfg.items() if value is not None
}

micro_batch_size = get_micro_batch_size()
hidden_size = model_parallel_cfg.hidden_size
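
A self-contained sketch of the conversion above, with an illustrative stand-in for a `TPOverlapCfg` subclass: the dataclass is flattened with `asdict`, and None-valued fields are dropped so that only explicitly set options reach TE.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ExampleOverlapCfg:  # hypothetical stand-in for a TPOverlapCfg subclass
    method: str = "ring_exchange"
    num_sm: int = 1
    cga_size: int = 1
    atomic_gemm: bool = False
    num_splits: Optional[int] = None  # left unset; should not be forwarded


cfg = {k: v for k, v in asdict(ExampleOverlapCfg()).items() if v is not None}
# cfg == {"method": "ring_exchange", "num_sm": 1, "cga_size": 1, "atomic_gemm": False}
```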