[Hardware] Add support for Huawei Ascend NPU #198

Draft
wants to merge 7 commits into base: main
31 changes: 31 additions & 0 deletions examples/ppo_trainer/run_llama_3.2_1b_megatron.sh
@@ -0,0 +1,31 @@
set -x

python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer_llama_3.2_1b'\
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.val_batch_size=256 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=$HOME/models/Llama-3.2-1B-Instruct \
actor_rollout_ref.actor.optim.lr=2e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size=8 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.ref.log_prob_micro_batch_size=8 \
critic.optim.lr=2e-5 \
critic.model.path=$HOME/models/Llama-3.2-1B-Instruct \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=8 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_megatron_gsm8k_examples' \
trainer.experiment_name='llama_3.2_1b_function_rm' \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_epochs=15 $@
29 changes: 29 additions & 0 deletions npu-turtorial.md
@@ -0,0 +1,29 @@
Environment setup:

1. Start from an NPU base image and install the latest CANN-Toolkit, Kernel, and NNAL.
2. Install veRL and its related dependencies (a quick NPU sanity check is shown after this list):
   1. veRL
      1. `git clone https://github.com/chendong-1998/verl.git`
      2. `cd verl && git checkout support-ascend-npu`
      3. `pip install -r requirements.txt`
   2. vLLM on NPU
      1. See this PR: https://github.com/vllm-project/vllm/pull/8054
      2. `git clone -b 1130/npu_support https://github.com/Chendong98/vllm.git`
      3. `cd vllm && VLLM_TARGET_DEVICE=npu pip install -e .`
   3. Megatron
      1. `git clone https://github.com/NVIDIA/Megatron-LM.git`
      2. `git checkout core_r0.6.0`
      3. `git apply verl/patches/megatron_v0.6_npu.patch`
      4. `pip install -e .`
   4. MindSpeed
      1. `git clone https://gitee.com/ascend/MindSpeed.git`
      2. `git checkout core_r0.6.0` (or commit id `e7ea32a1e054`)
      3. `pip install -e .`
   5. Ascend-Apex
      1. `git clone https://gitee.com/ascend/apex.git ascend-apex`
      2. `cd ascend-apex && bash scripts/build.sh --python=3.10`
      3. Install the resulting `.whl` package
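
Once the dependencies above are installed, a minimal sanity check can confirm that the NPU devices are visible to PyTorch. This is only a suggested check, assuming the `torch` and `torch_npu` wheels matching the installed CANN toolchain are present:

```python
# Minimal NPU sanity check (assumes torch and torch_npu are installed).
import torch
import torch_npu  # noqa: F401  # importing registers the "npu" device with PyTorch

print("NPU available:", torch.npu.is_available())
print("NPU device count:", torch.npu.device_count())
```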

In addition, you may need to check the model and dataset paths, and modify `verl/third_party/vllm/__init__.py` so that the installed vLLM development build (whose version string may look like `0.1.dev3628+g06f1b1d.d20250116.npu`) is dispatched to the spmd vLLM integration under the `0.6.4` directory.
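
As an illustration only (this is not the actual content of `verl/third_party/vllm/__init__.py`; the string checks below are assumptions), the version normalization could look roughly like this:

```python
# Illustrative sketch: map a locally built NPU dev version string such as
# "0.1.dev3628+g06f1b1d.d20250116.npu" onto the 0.6.4 spmd integration.
from importlib.metadata import version

raw_version = version("vllm")
package_version = "0.6.4" if "dev" in raw_version or raw_version.endswith("npu") else raw_version
print(f"vllm reports {raw_version}; using verl's {package_version} integration")
```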

Finally, from the veRL root directory, run `bash examples/ppo_trainer/run_llama_3.2_1b_megatron.sh` to launch the PPO example on NPU.
281 changes: 281 additions & 0 deletions patches/megatron_v0.6_npu.patch
@@ -0,0 +1,281 @@
diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
index 13e321f5..ddfbdc25 100644
--- a/megatron/core/distributed/distributed_data_parallel.py
+++ b/megatron/core/distributed/distributed_data_parallel.py
@@ -52,6 +52,7 @@ class DistributedDataParallel(MegatronModule):
):
super().__init__(config=config)
self.module = module
+ self.data_parallel_group = data_parallel_group

# Set bucket_size to infinity if overlap_grad_reduce is False.
self.overlap_grad_reduce = overlap_grad_reduce
@@ -268,7 +269,8 @@ class DistributedDataParallel(MegatronModule):
else:
torch.distributed.broadcast(
param.data,
- src=torch.distributed.get_process_group_ranks(self.data_parallel_group),
+ # refer to https://github.com/NVIDIA/Megatron-LM/pull/796
+ src=torch.distributed.get_process_group_ranks(self.data_parallel_group)[0],
group=self.data_parallel_group,
)

diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py
index eb251761..54a80f44 100644
--- a/megatron/core/pipeline_parallel/schedules.py
+++ b/megatron/core/pipeline_parallel/schedules.py
@@ -315,6 +315,7 @@ def forward_backward_no_pipelining(
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int, # unused
+ hidden_size: int, # unused
micro_batch_size: int, # unused
decoder_seq_length: int = None, # unused
forward_only: bool = False,
@@ -403,8 +404,10 @@ def forward_backward_pipelining_with_interleaving(
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
- seq_length: int,
- micro_batch_size: int,
+ seq_length: int = None,
+ hidden_size: int = None,
+ micro_batch_size: int = None,
+ input_shapes: list = None,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
@@ -491,7 +494,7 @@ def forward_backward_pipelining_with_interleaving(
"Interleaving is not supported with a different decoder sequence length."
)

- tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
+ tensor_shape = [seq_length, micro_batch_size, hidden_size]
tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
@@ -983,6 +986,7 @@ def get_tensor_shapes(
rank: int,
model_type: ModelType,
seq_length: int,
+ hidden_size: int,
micro_batch_size: int,
decoder_seq_length: int,
config,
@@ -1010,12 +1014,12 @@ def get_tensor_shapes(

if model_type == ModelType.encoder_and_decoder:
if parallel_state.is_pipeline_stage_before_split(rank):
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+ tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else:
- tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+ tensor_shapes.append((decoder_seq_length, micro_batch_size, hidden_size))
+ tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else:
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+ tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
return tensor_shapes


@@ -1093,8 +1097,10 @@ def forward_backward_pipelining_without_interleaving(
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
- seq_length: int,
- micro_batch_size: int,
+ seq_length: int = None,
+ hidden_size: int = None,
+ micro_batch_size: int = None,
+ input_shapes: list = None,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
@@ -1171,22 +1177,34 @@ def forward_backward_pipelining_without_interleaving(
model_type = get_model_type(model)

rank = parallel_state.get_pipeline_model_parallel_rank()
- recv_tensor_shapes = get_tensor_shapes(
- rank=rank - 1,
- model_type=model_type,
- seq_length=seq_length,
- micro_batch_size=micro_batch_size,
- decoder_seq_length=decoder_seq_length,
- config=config,
- )
- send_tensor_shapes = get_tensor_shapes(
- rank=rank,
- model_type=model_type,
- seq_length=seq_length,
- micro_batch_size=micro_batch_size,
- decoder_seq_length=decoder_seq_length,
- config=config,
- )
+
+ def get_recv_tensor_shapes(microbatch_id):
+ if input_shapes:
+ return [input_shapes[microbatch_id]]
+ recv_tensor_shapes = get_tensor_shapes(
+ rank=rank - 1,
+ model_type=model_type,
+ seq_length=seq_length,
+ hidden_size=hidden_size,
+ micro_batch_size=micro_batch_size,
+ decoder_seq_length=decoder_seq_length,
+ config=config,
+ )
+ return recv_tensor_shapes
+
+ def get_send_tensor_shapes(microbatch_id):
+ if input_shapes:
+ return [input_shapes[microbatch_id]]
+ send_tensor_shapes = get_tensor_shapes(
+ rank=rank,
+ model_type=model_type,
+ seq_length=seq_length,
+ hidden_size=hidden_size,
+ micro_batch_size=micro_batch_size,
+ decoder_seq_length=decoder_seq_length,
+ config=config,
+ )
+ return send_tensor_shapes

# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
@@ -1207,6 +1225,7 @@ def forward_backward_pipelining_without_interleaving(
else:
checkpoint_activations_microbatch = None

+ recv_tensor_shapes = get_recv_tensor_shapes(i) # fwd recv shape
input_tensor = recv_forward(recv_tensor_shapes, config)
output_tensor = forward_step(
forward_step_func,
@@ -1220,6 +1239,7 @@ def forward_backward_pipelining_without_interleaving(
checkpoint_activations_microbatch,
check_first_val_step(first_val_step, forward_only, i == 0),
)
+ send_tensor_shapes = get_send_tensor_shapes(i) # fwd send shape
send_forward(output_tensor, send_tensor_shapes, config)

if not forward_only:
@@ -1231,11 +1251,14 @@ def forward_backward_pipelining_without_interleaving(
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
+ recv_tensor_shapes = get_recv_tensor_shapes(num_warmup_microbatches) # fwd recv shape
input_tensor = recv_forward(recv_tensor_shapes, config)

# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = i == (num_microbatches_remaining - 1)
+ next_forward_k = num_warmup_microbatches + i + 1
+ backward_k = i

# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
@@ -1260,13 +1283,18 @@ def forward_backward_pipelining_without_interleaving(
),
)

+ send_tensor_shapes = get_send_tensor_shapes(i) # fwd send shape
+
if forward_only:
+ send_tensor_shapes = get_send_tensor_shapes(next_forward_k - 1) # fwd send shape
send_forward(output_tensor, send_tensor_shapes, config)

if not last_iteration:
+ recv_tensor_shapes = get_recv_tensor_shapes(next_forward_k) # fwd recv shape
input_tensor = recv_forward(recv_tensor_shapes, config)

else:
+ send_tensor_shapes = get_send_tensor_shapes(backward_k) # bwd recv shape
output_tensor_grad = send_forward_recv_backward(
output_tensor, send_tensor_shapes, config
)
@@ -1293,8 +1321,10 @@ def forward_backward_pipelining_without_interleaving(

if last_iteration:
input_tensor = None
+ recv_tensor_shapes = get_recv_tensor_shapes(backward_k) # bwd send shape
send_backward(input_tensor_grad, recv_tensor_shapes, config)
else:
+ recv_tensor_shapes = get_recv_tensor_shapes(next_forward_k) # fwd recv shape
input_tensor = send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, config
)
@@ -1302,6 +1332,7 @@ def forward_backward_pipelining_without_interleaving(
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
+ backward_k = num_microbatches_remaining + i

# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
@@ -1315,12 +1346,14 @@ def forward_backward_pipelining_without_interleaving(
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)

+ send_tensor_shapes = get_send_tensor_shapes(backward_k) # bwd recv shape
output_tensor_grad = recv_backward(send_tensor_shapes, config)

input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)

+ recv_tensor_shapes = get_recv_tensor_shapes(backward_k) # bwd send shape
send_backward(input_tensor_grad, recv_tensor_shapes, config)

# Launch any remaining grad reductions.
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index 44abd182..969f0112 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -56,7 +56,7 @@ def get_model_type(model):


def get_model_config(model):
- return get_attr_wrapped_model(model, 'config', allow_none=False)
+ return get_attr_wrapped_model(model, 'megatron_config', allow_none=False)


class GlobalMemoryBuffer:
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 6e3ff990..e190ee1a 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -49,7 +49,19 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):

# Parse.
if ignore_unknown_args:
- args, _ = parser.parse_known_args()
+ # MindSpeed still needs some of these args and global vars; a more in-depth analysis is needed before this block can be removed.
+ arg_list = [
+ "--tensor-model-parallel-size", "2",
+ "--pipeline-model-parallel-size", "1",
+ "--num-layers", "16",
+ "--hidden-size", "2048",
+ "--num-attention-heads", "32",
+ "--seq-length", "512",
+ "--max-position-embeddings", "131072",
+ "--micro-batch-size", "8",
+ ]
+ args, _ = parser.parse_known_args(arg_list)
+ # args, _ = parser.parse_known_args()
else:
args = parser.parse_args()

diff --git a/setup.py b/setup.py
index 2071a62c..3c6cde5a 100644
--- a/setup.py
+++ b/setup.py
@@ -113,7 +113,8 @@ setuptools.setup(
'Natural Language :: English',
'Operating System :: OS Independent',
],
- packages=setuptools.find_namespace_packages(include=["megatron.core", "megatron.core.*"]),
+ # make megatron.training and the other megatron modules available when using `pip install -e .`
+ packages=setuptools.find_namespace_packages(include=["megatron.core", "megatron.core.*", "megatron.training", "megatron.legacy", "megatron.inference"]),
ext_modules=[
Extension(
"megatron.core.datasets.helpers",
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,6 +12,6 @@ pybind11
ray>=2.38
tensordict<0.6
transformers<4.48
-vllm<=0.6.3
+# vllm<=0.6.3
wandb
liger-kernel
7 changes: 6 additions & 1 deletion verl/models/llama/megatron/checkpoint_utils/llama_loader.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import pkg_resources
import torch
import time
from typing import Dict, Any, Callable, Optional
import torch.distributed as dist

+megatron_version = pkg_resources.get_distribution('megatron_core').version

def _megatron_calc_layer_map(config):
"""Calculate the mapping of global layer_idx to local layer_idx
@@ -53,7 +55,10 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params
"""
import megatron
from megatron.core import mpu
-from megatron.utils import print_rank_0, unwrap_model
+if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'):
+    from megatron.utils import print_rank_0, unwrap_model
+else:
+    from megatron.training.utils import print_rank_0, unwrap_model
from megatron.core.transformer.module import Float16Module
from megatron.core import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP