Merge branch 'chenhany/mamba_modelopt_support' into 'main'
Support MCore MambaModel quantization through TensorRT Model Optimizer

See merge request ADLR/megatron-lm!2527
ko3n1g committed Feb 6, 2025
2 parents 3b9035c + f575d3f commit 0dd78dd
Showing 12 changed files with 348 additions and 73 deletions.
8 changes: 0 additions & 8 deletions megatron/core/inference/ammo_support/__init__.py

This file was deleted.

2 changes: 0 additions & 2 deletions megatron/core/inference/ammo_support/gpt/model_specs.py

This file was deleted.

5 changes: 0 additions & 5 deletions megatron/core/inference/ammo_support/gpt/state_dict_hooks.py

This file was deleted.

10 changes: 6 additions & 4 deletions megatron/core/inference/modelopt_support/__init__.py
@@ -1,8 +1,10 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt).
ModelOpt is a library comprising state-of-the-art model optimization techniques including quantization and sparsity to
compress model for efficient inference on NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference. More details on ModelOpt including
installation and usage can be found at https://github.com/NVIDIA/TensorRT-Model-Optimizer.
ModelOpt is a library comprising state-of-the-art model optimization techniques
including quantization and sparsity to compress model for efficient inference on
NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference.
More details on ModelOpt including installation and usage can be found at
https://github.com/NVIDIA/TensorRT-Model-Optimizer.
"""
5 changes: 4 additions & 1 deletion megatron/core/inference/modelopt_support/gpt/model_specs.py
@@ -7,6 +7,7 @@
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
@@ -16,6 +17,7 @@
# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
num_experts: Optional[int] = None,
local_core_attention: bool = False,
moe_grouped_gemm: bool = False,
remap_te_layernorm: bool = False,
qk_layernorm: bool = False,
@@ -26,6 +28,7 @@ def get_gpt_layer_modelopt_spec(
is using TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex
has stopped supporting RMSNorm needed by llama.
"""
core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
mlp = get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
)
@@ -49,7 +52,7 @@
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=TEDotProductAttention,
core_attention=core_attention,
linear_proj=RowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
k_layernorm=TENorm if qk_layernorm else IdentityOp,
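
The new local_core_attention flag selects Megatron-Core's local DotProductAttention instead of the Transformer Engine fused kernel. A minimal sketch of requesting the spec with the arguments visible in this diff (the GPTModel construction that consumes the spec is omitted):

from megatron.core.inference.modelopt_support.gpt.model_specs import (
    get_gpt_layer_modelopt_spec,
)

# Default path: Transformer Engine fused core attention (TEDotProductAttention).
te_spec = get_gpt_layer_modelopt_spec(num_experts=None, qk_layernorm=False)

# New path: local DotProductAttention, with TE layernorm keys remapped so
# existing sharded checkpoints still load.
local_spec = get_gpt_layer_modelopt_spec(
    num_experts=None,
    local_core_attention=True,
    remap_te_layernorm=True,
    qk_layernorm=False,
)
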
1 change: 1 addition & 0 deletions megatron/core/inference/modelopt_support/mamba/__init__.py
@@ -0,0 +1 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
89 changes: 89 additions & 0 deletions megatron/core/inference/modelopt_support/mamba/model_specs.py
@@ -0,0 +1,89 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_mamba_stack_modelopt_spec(
local_core_attention: bool = False, remap_te_layernorm: bool = False
) -> ModuleSpec:
"""Mix the native spec with TENorm.
This is essentially the native local spec except for the layernorm implementation
is using TENorm from Transformer-Engine.
"""
mamba_state_dict_keys_map = {}
transformer_state_dict_keys_map = {}
if remap_te_layernorm:
mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}
transformer_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
}

mamba_layer = ModuleSpec(
module=MambaLayer,
submodules=MambaLayerSubmodules(
norm=TENorm,
mixer=ModuleSpec(
module=MambaMixer,
submodules=MambaMixerSubmodules(
in_proj=ColumnParallelLinear, out_proj=RowParallelLinear
),
),
mamba_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=mamba_state_dict_keys_map,
),
)

core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
attention_layer = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=core_attention,
linear_proj=RowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=transformer_state_dict_keys_map,
),
)

mlp_layer = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
pre_mlp_layernorm=TENorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=transformer_state_dict_keys_map,
),
)

return ModuleSpec(
module=MambaStack,
submodules=MambaStackSubmodules(
mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer
),
)
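
A short usage sketch for the new spec; the resulting ModuleSpec would be handed to MCore's MambaModel in place of the default stack spec (that construction is omitted here):

from megatron.core.inference.modelopt_support.mamba.model_specs import (
    get_mamba_stack_modelopt_spec,
)

# remap_te_layernorm=True installs the sharded-state-dict key remapping above,
# so checkpoints saved with TE's fused layernorms load cleanly. Setting
# local_core_attention=True would swap TEDotProductAttention for the local
# DotProductAttention in the hybrid stack's attention layers.
stack_spec = get_mamba_stack_modelopt_spec(
    local_core_attention=False,
    remap_te_layernorm=True,
)
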
33 changes: 31 additions & 2 deletions megatron/core/ssm/mamba_layer.py
@@ -5,12 +5,14 @@
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Union
from dataclasses import dataclass, field
from typing import Dict, Optional, Union

import torch
from torch import Tensor

from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
@@ -37,6 +39,9 @@ class MambaLayerSubmodules:
mixer: Union[ModuleSpec, type] = IdentityOp
mamba_bda: Union[ModuleSpec, type] = IdentityOp

# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)


class MambaLayer(MegatronModule):
"""
@@ -57,6 +62,7 @@ def __init__(
"""Initialize Mamba Layer."""
super().__init__(config)
self.config = config
self.submodules_config = submodules
self.layer_number = layer_number
self.residual_in_fp32 = residual_in_fp32
self.hidden_dropout = config.hidden_dropout
@@ -114,3 +120,26 @@ def forward(
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
"""Allocate the inference cache."""
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)

def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Generate a sharded state dictionary for the mamba layer.
Args:
prefix (str, optional): Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional): Tuple of sharding offsets.
metadata (Optional[dict], optional): Additional metadata for sharding.
Returns:
ShardedStateDict: A dictionary containing the sharded state of the mamba layer.
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
prefixed_map = {
f'{prefix}{k}': f'{prefix}{v}'
for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
}
if prefixed_map:
apply_prefix_mapping(sharded_state_dict, prefixed_map)
return sharded_state_dict
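
To make the effect of the new hook concrete, the sketch below shows the prefix rewrite that apply_prefix_mapping performs for the Mamba layer when remap_te_layernorm=True; the key names are illustrative simplifications, not the exact keys produced by Megatron-Core.

# Illustrative only: the real mapping is built per layer in sharded_state_dict.
keys_map = {
    'decoder.layers.0.norm.': 'decoder.layers.0.mixer.in_proj.layer_norm_',
}

original_key = 'decoder.layers.0.norm.weight'
remapped_key = original_key
for old_prefix, new_prefix in keys_map.items():
    if original_key.startswith(old_prefix):
        remapped_key = new_prefix + original_key[len(old_prefix):]

# remapped_key == 'decoder.layers.0.mixer.in_proj.layer_norm_weight'
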
2 changes: 1 addition & 1 deletion megatron/inference/text_generation/generation.py
@@ -87,7 +87,7 @@ def score_and_return_on_first_stage(model, tokens: torch.Tensor, lengths: torch.
if mpu.is_pipeline_last_stage():
# Always the last stage should have an output.
assert logits is not None
log_probs = F.log_softmax(logits, dim=2)
log_probs = F.log_softmax(logits, dim=2).to(dtype=output_topk_log_probs.dtype)

# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
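
The added .to(dtype=...) guards against a dtype mismatch between the fp32 output of log_softmax and the pre-allocated output_topk_log_probs buffer. A standalone illustration of the failure mode, with shapes and the token selection reduced to placeholders:

import torch
import torch.nn.functional as F

output_topk_log_probs = torch.zeros(2, 3, dtype=torch.float16)  # buffer allocated elsewhere
logits = torch.randn(2, 3, 8)                                   # [batch, seq, vocab], fp32
log_probs = F.log_softmax(logits, dim=2)                        # stays fp32

selected = log_probs[:, :, 0]          # stand-in for gathering per-token log-probs
rows = torch.tensor([0, 1])

# Indexed writes require matching dtypes ("Index put requires the source and
# destination dtypes match"), so cast to the buffer's dtype before assigning.
output_topk_log_probs[rows] = selected.to(dtype=output_topk_log_probs.dtype)[rows]
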
17 changes: 10 additions & 7 deletions megatron/training/checkpointing.py
@@ -1184,6 +1184,16 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
if ckpt_tp_pp != run_tp_pp:
print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg))

# [ModelOpt]: IMPORTANT! Restoring modelopt_state (sharded or not) must be performed
# after the model instance has been created and before _load_base_checkpoint is called.
if has_nvidia_modelopt:
if ckpt_type == CheckpointType.LOCAL:
raise NotImplementedError('Local checkpointing does not support model opt')
if not args.use_dist_ckpt:
restore_modelopt_state(model, state_dict)
else:
restore_sharded_modelopt_state(model, checkpoint_name)

# [ModelOpt]: Initial loading from non-resume sharded checkpoint to a Distillation Model
# will result in key mismatch with loss modules potentially containing parameters, since
# it requires generating a state_dict before loading. Here we hide those modules if present.
@@ -1244,13 +1254,6 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
else:
print_rank_0('could not find arguments in the checkpoint ...')

# [ModelOpt]: loading modelopt_state (sharded or not)
if has_nvidia_modelopt:
if ckpt_type == CheckpointType.GLOBAL:
restore_sharded_modelopt_state(model, checkpoint_name)
else:
restore_modelopt_state(model, state_dict)

# Model.
strict = False if args.retro_add_retriever else strict
if not skip_load_to_model_and_opt:
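
The comment in the diff captures the key constraint: ModelOpt state must be restored after the model object exists but before the base checkpoint is loaded, so the checkpoint keys match the ModelOpt-transformed module structure. A hedged, self-contained sketch of that ordering with stand-in callables (not the real Megatron functions):

from typing import Any, Callable

def load_with_modelopt(
    model: Any,
    restore_modelopt_state: Callable[[Any], None],
    load_base_checkpoint: Callable[[Any], None],
    has_nvidia_modelopt: bool,
) -> None:
    # 1) Re-insert ModelOpt quantizer/compression modules first.
    if has_nvidia_modelopt:
        restore_modelopt_state(model)
    # 2) Only then load the base weights into the transformed model.
    load_base_checkpoint(model)
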
43 changes: 0 additions & 43 deletions tests/unit_tests/inference/test_modelopt_gpt_model.py

This file was deleted.
