diff --git a/examples/ppo_trainer/run_llama_3.2_1b_megatron.sh b/examples/ppo_trainer/run_llama_3.2_1b_megatron.sh new file mode 100644 index 00000000..966390d9 --- /dev/null +++ b/examples/ppo_trainer/run_llama_3.2_1b_megatron.sh @@ -0,0 +1,31 @@ +set -x + +python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer_llama_3.2_1b'\ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=256 \ + data.val_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + actor_rollout_ref.model.path=$HOME/models/Llama-3.2-1B-Instruct \ + actor_rollout_ref.actor.optim.lr=2e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size=8 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=8 \ + critic.optim.lr=2e-5 \ + critic.model.path=$HOME/models/Llama-3.2-1B-Instruct \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size=8 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_megatron_gsm8k_examples' \ + trainer.experiment_name='llama_3.2_1b_function_rm' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.total_epochs=15 $@ diff --git a/npu-turtorial.md b/npu-turtorial.md new file mode 100644 index 00000000..59aa76fe --- /dev/null +++ b/npu-turtorial.md @@ -0,0 +1,29 @@ +Environment setup: + +1. NPU base image with the latest CANN-Toolkit, Kernel, and NNAL installed +2. Install veRL and its dependencies + 1. veRL + 1. `git clone https://github.com/chendong-1998/verl.git` + 2. `cd verl && git checkout support-ascend-npu` + 3. `pip install -r requirements.txt` + 2. vLLM on NPU + 1. See this PR: https://github.com/vllm-project/vllm/pull/8054 + 2. `git clone -b 1130/npu_support https://github.com/Chendong98/vllm.git` + 3. `cd vllm && VLLM_TARGET_DEVICE=npu pip install -e .` + 3. Megatron + 1. `git clone https://github.com/NVIDIA/Megatron-LM.git` + 2. `git checkout core_r0.6.0` + 3. `git apply verl/patches/megatron_v0.6_npu.patch` + 4. `pip install -e .` + 4. Mindspeed + 1. `git clone https://gitee.com/ascend/MindSpeed.git` + 2. `git checkout core_r0.6.0` (or commit id `e7ea32a1e054`) + 3. `pip install -e .` + 5. Ascend-Apex + 1. `git clone https://gitee.com/ascend/apex.git ascend-apex` + 2. `cd ascend-apex && bash scripts/build.sh --python=3.10` + 3. Install the built whl package (see the sanity check below)
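+
+After completing the steps above, it can help to confirm that all of the pieces are importable before launching training. The snippet below is a minimal sanity-check sketch; the exact vllm version string depends on your local build.
+
+```python
+# Minimal environment check (illustrative; version strings depend on your builds).
+from importlib.metadata import version, PackageNotFoundError
+from importlib.util import find_spec
+
+try:
+    # Expect the dev build installed from the NPU branch,
+    # e.g. something like 0.1.dev3628+g06f1b1d.d20250116.npu
+    print("vllm:", version("vllm"))
+except PackageNotFoundError:
+    print("vllm: not installed")
+
+for mod in ("megatron", "mindspeed", "apex", "torch_npu"):
+    print(f"{mod}: {'found' if find_spec(mod) is not None else 'missing'}")
+```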
+ +In addition, you may need to check the model and dataset paths, and modify `verl/third_party/vllm/__init__.py` so that the installed vllm development version (which may look like `0.1.dev3628+g06f1b1d.d20250116.npu`) uses the spmd vllm in the `0.6.4` directory. + +Finally, from the veRL directory, run `bash examples/ppo_trainer/run_llama_3.2_1b_megatron.sh` to start training on NPU. diff --git a/patches/megatron_v0.6_npu.patch b/patches/megatron_v0.6_npu.patch new file mode 100644 index 00000000..6c76465b --- /dev/null +++ b/patches/megatron_v0.6_npu.patch @@ -0,0 +1,281 @@ +diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py +index 13e321f5..ddfbdc25 100644 +--- a/megatron/core/distributed/distributed_data_parallel.py ++++ b/megatron/core/distributed/distributed_data_parallel.py +@@ -52,6 +52,7 @@ class DistributedDataParallel(MegatronModule): + ): + super().__init__(config=config) + self.module = module ++ self.data_parallel_group = data_parallel_group + + # Set bucket_size to infinity if overlap_grad_reduce is False. + self.overlap_grad_reduce = overlap_grad_reduce +@@ -268,7 +269,8 @@ class DistributedDataParallel(MegatronModule): + else: + torch.distributed.broadcast( + param.data, +- src=torch.distributed.get_process_group_ranks(self.data_parallel_group), ++ # refer to https://github.com/NVIDIA/Megatron-LM/pull/796 ++ src=torch.distributed.get_process_group_ranks(self.data_parallel_group)[0], + group=self.data_parallel_group, + ) + +diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py +index eb251761..54a80f44 100644 +--- a/megatron/core/pipeline_parallel/schedules.py ++++ b/megatron/core/pipeline_parallel/schedules.py +@@ -315,6 +315,7 @@ def forward_backward_no_pipelining( + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, # unused ++ hidden_size: int, # unused + micro_batch_size: int, # unused + decoder_seq_length: int = None, # unused + forward_only: bool = False, +@@ -403,8 +404,10 @@ def forward_backward_pipelining_with_interleaving( + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, +- seq_length: int, +- micro_batch_size: int, ++ seq_length: int = None, ++ hidden_size: int = None, ++ micro_batch_size: int = None, ++ input_shapes: list = None, + decoder_seq_length: int = None, + forward_only: bool = False, + collect_non_loss_data: bool = False, +@@ -491,7 +494,7 @@ def forward_backward_pipelining_with_interleaving( + "Interleaving is not supported with a different decoder sequence length."
+ ) + +- tensor_shape = [seq_length, micro_batch_size, config.hidden_size] ++ tensor_shape = [seq_length, micro_batch_size, hidden_size] + tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size() + if config.sequence_parallel: + tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size() +@@ -983,6 +986,7 @@ def get_tensor_shapes( + rank: int, + model_type: ModelType, + seq_length: int, ++ hidden_size: int, + micro_batch_size: int, + decoder_seq_length: int, + config, +@@ -1010,12 +1014,12 @@ def get_tensor_shapes( + + if model_type == ModelType.encoder_and_decoder: + if parallel_state.is_pipeline_stage_before_split(rank): +- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) ++ tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) + else: +- tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size)) +- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) ++ tensor_shapes.append((decoder_seq_length, micro_batch_size, hidden_size)) ++ tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) + else: +- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) ++ tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) + return tensor_shapes + + +@@ -1093,8 +1097,10 @@ def forward_backward_pipelining_without_interleaving( + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, +- seq_length: int, +- micro_batch_size: int, ++ seq_length: int = None, ++ hidden_size: int = None, ++ micro_batch_size: int = None, ++ input_shapes: list = None, + decoder_seq_length: int = None, + forward_only: bool = False, + collect_non_loss_data: bool = False, +@@ -1171,22 +1177,34 @@ def forward_backward_pipelining_without_interleaving( + model_type = get_model_type(model) + + rank = parallel_state.get_pipeline_model_parallel_rank() +- recv_tensor_shapes = get_tensor_shapes( +- rank=rank - 1, +- model_type=model_type, +- seq_length=seq_length, +- micro_batch_size=micro_batch_size, +- decoder_seq_length=decoder_seq_length, +- config=config, +- ) +- send_tensor_shapes = get_tensor_shapes( +- rank=rank, +- model_type=model_type, +- seq_length=seq_length, +- micro_batch_size=micro_batch_size, +- decoder_seq_length=decoder_seq_length, +- config=config, +- ) ++ ++ def get_recv_tensor_shapes(microbatch_id): ++ if input_shapes: ++ return [input_shapes[microbatch_id]] ++ recv_tensor_shapes = get_tensor_shapes( ++ rank=rank - 1, ++ model_type=model_type, ++ seq_length=seq_length, ++ hidden_size=hidden_size, ++ micro_batch_size=micro_batch_size, ++ decoder_seq_length=decoder_seq_length, ++ config=config, ++ ) ++ return recv_tensor_shapes ++ ++ def get_send_tensor_shapes(microbatch_id): ++ if input_shapes: ++ return [input_shapes[microbatch_id]] ++ send_tensor_shapes = get_tensor_shapes( ++ rank=rank, ++ model_type=model_type, ++ seq_length=seq_length, ++ hidden_size=hidden_size, ++ micro_batch_size=micro_batch_size, ++ decoder_seq_length=decoder_seq_length, ++ config=config, ++ ) ++ return send_tensor_shapes + + # Input, output tensors only need to be saved when doing backward passes + input_tensors = None +@@ -1207,6 +1225,7 @@ def forward_backward_pipelining_without_interleaving( + else: + checkpoint_activations_microbatch = None + ++ recv_tensor_shapes = get_recv_tensor_shapes(i) # fwd recv shape + input_tensor = recv_forward(recv_tensor_shapes, config) + output_tensor = 
forward_step( + forward_step_func, +@@ -1220,6 +1239,7 @@ def forward_backward_pipelining_without_interleaving( + checkpoint_activations_microbatch, + check_first_val_step(first_val_step, forward_only, i == 0), + ) ++ send_tensor_shapes = get_send_tensor_shapes(i) # fwd send shape + send_forward(output_tensor, send_tensor_shapes, config) + + if not forward_only: +@@ -1231,11 +1251,14 @@ def forward_backward_pipelining_without_interleaving( + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: ++ recv_tensor_shapes = get_recv_tensor_shapes(num_warmup_microbatches) # fwd recv shape + input_tensor = recv_forward(recv_tensor_shapes, config) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = i == (num_microbatches_remaining - 1) ++ next_forward_k = num_warmup_microbatches + i + 1 ++ backward_k = i + + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: +@@ -1260,13 +1283,18 @@ def forward_backward_pipelining_without_interleaving( + ), + ) + ++ send_tensor_shapes = get_send_tensor_shapes(i) # fwd send shape ++ + if forward_only: ++ send_tensor_shapes = get_send_tensor_shapes(next_forward_k - 1) # fwd send shape + send_forward(output_tensor, send_tensor_shapes, config) + + if not last_iteration: ++ recv_tensor_shapes = get_recv_tensor_shapes(next_forward_k) # fwd recv shape + input_tensor = recv_forward(recv_tensor_shapes, config) + + else: ++ send_tensor_shapes = get_send_tensor_shapes(backward_k) # bwd recv shape + output_tensor_grad = send_forward_recv_backward( + output_tensor, send_tensor_shapes, config + ) +@@ -1293,8 +1321,10 @@ def forward_backward_pipelining_without_interleaving( + + if last_iteration: + input_tensor = None ++ recv_tensor_shapes = get_recv_tensor_shapes(backward_k) # bwd send shape + send_backward(input_tensor_grad, recv_tensor_shapes, config) + else: ++ recv_tensor_shapes = get_recv_tensor_shapes(next_forward_k) # fwd recv shape + input_tensor = send_backward_recv_forward( + input_tensor_grad, recv_tensor_shapes, config + ) +@@ -1302,6 +1332,7 @@ def forward_backward_pipelining_without_interleaving( + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): ++ backward_k = num_microbatches_remaining + i + + # Enable async grad reduction in the last backward pass + # Note: If grad sync function is provided, only enable +@@ -1315,12 +1346,14 @@ def forward_backward_pipelining_without_interleaving( + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + ++ send_tensor_shapes = get_send_tensor_shapes(backward_k) # bwd recv shape + output_tensor_grad = recv_backward(send_tensor_shapes, config) + + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + ++ recv_tensor_shapes = get_recv_tensor_shapes(backward_k) # bwd send shape + send_backward(input_tensor_grad, recv_tensor_shapes, config) + + # Launch any remaining grad reductions. 
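The schedule changes above replace the single global `seq_length`/`config.hidden_size` shape with an optional per-microbatch `input_shapes` list: when it is supplied, `get_recv_tensor_shapes(i)` and `get_send_tensor_shapes(i)` simply return `[input_shapes[i]]` for microbatch `i`. A minimal sketch of what such a list looks like (hypothetical sizes, following Megatron's `(sequence, batch, hidden)` layout; not part of the patch):

```python
# Hypothetical shape list for 4 microbatches with different (padded) sequence lengths.
micro_batch_size = 8
hidden_size = 2048
per_microbatch_seq_lens = [384, 512, 448, 256]

# One (seq_len, micro_batch_size, hidden_size) tuple per microbatch; the patched
# schedule uses input_shapes[i] as both the recv and send shape of microbatch i.
input_shapes = [(s, micro_batch_size, hidden_size) for s in per_microbatch_seq_lens]
print(input_shapes)
```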
+diff --git a/megatron/core/utils.py b/megatron/core/utils.py +index 44abd182..969f0112 100644 +--- a/megatron/core/utils.py ++++ b/megatron/core/utils.py +@@ -56,7 +56,7 @@ def get_model_type(model): + + + def get_model_config(model): +- return get_attr_wrapped_model(model, 'config', allow_none=False) ++ return get_attr_wrapped_model(model, 'megatron_config', allow_none=False) + + + class GlobalMemoryBuffer: +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index 6e3ff990..e190ee1a 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -49,7 +49,19 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): + + # Parse. + if ignore_unknown_args: +- args, _ = parser.parse_known_args() ++ # mindspeed need some args & global var so far, we need a more in-depth analysis to remove this part of the code. ++ arg_list = [ ++ "--tensor-model-parallel-size", "2", ++ "--pipeline-model-parallel-size", "1", ++ "--num-layers", "16", ++ "--hidden-size", "2048", ++ "--num-attention-heads", "32", ++ "--seq-length", "512", ++ "--max-position-embeddings", "131072", ++ "--micro-batch-size", "8", ++ ] ++ args, _ = parser.parse_known_args(arg_list) ++ # args, _ = parser.parse_known_args() + else: + args = parser.parse_args() + +diff --git a/setup.py b/setup.py +index 2071a62c..3c6cde5a 100644 +--- a/setup.py ++++ b/setup.py +@@ -113,7 +113,8 @@ setuptools.setup( + 'Natural Language :: English', + 'Operating System :: OS Independent', + ], +- packages=setuptools.find_namespace_packages(include=["megatron.core", "megatron.core.*"]), ++ # make training and other megatron module available when use `pip install -e .` ++ packages=setuptools.find_namespace_packages(include=["megatron.core", "megatron.core.*", "megatron.training", "megatron.legacy", "megatron.inference"]), + ext_modules=[ + Extension( + "megatron.core.datasets.helpers", diff --git a/requirements.txt b/requirements.txt index caab0e2f..0a063421 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,6 @@ pybind11 ray>=2.38 tensordict<0.6 transformers<4.48 -vllm<=0.6.3 +# vllm<=0.6.3 wandb liger-kernel diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py index 00fb0a9c..a3f6407c 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pkg_resources import torch import time from typing import Dict, Any, Callable, Optional import torch.distributed as dist +megatron_version = pkg_resources.get_distribution('megatron_core').version def _megatron_calc_layer_map(config): """Calculate the mapping of global layer_idx to local layer_idx @@ -53,7 +55,10 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params """ import megatron from megatron.core import mpu - from megatron.utils import print_rank_0, unwrap_model + if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + from megatron.utils import print_rank_0, unwrap_model + else: + from megatron.training.utils import print_rank_0, unwrap_model from megatron.core.transformer.module import Float16Module from megatron.core import DistributedDataParallel as LocalDDP from torch.nn.parallel import DistributedDataParallel as torchDDP diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py index 0764b6fe..b24fb36b 100644 --- a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py +++ b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py @@ -14,9 +14,9 @@ import megatron from megatron.core import mpu -from megatron.utils import print_rank_0, unwrap_model from megatron.model import Float16Module from megatron.model import DistributedDataParallel as LocalDDP +import pkg_resources from torch.nn.parallel import DistributedDataParallel as torchDDP import torch import time @@ -24,6 +24,12 @@ import torch.distributed as dist from megatron import get_args +megatron_version = pkg_resources.get_distribution('megatron_core').version + +if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + from megatron.utils import print_rank_0, unwrap_model +else: + from megatron.training.utils import print_rank_0, unwrap_model def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): """given TP,DP,PP rank to get the global rank.""" diff --git a/verl/models/llama/megatron/layers/parallel_attention.py b/verl/models/llama/megatron/layers/parallel_attention.py index f14653fc..8244bfaa 100644 --- a/verl/models/llama/megatron/layers/parallel_attention.py +++ b/verl/models/llama/megatron/layers/parallel_attention.py @@ -21,7 +21,9 @@ import math from typing import Optional, Tuple +import numpy as np import torch +from importlib.util import find_spec from megatron.core import parallel_state as mpu from megatron.core import tensor_parallel from megatron.core import ModelParallelConfig @@ -32,6 +34,15 @@ from verl.utils.megatron import tensor_parallel as tp_utils +def is_package_present(package_name: str) -> bool: + try: + return find_spec(package_name) is not None + except ModuleNotFoundError: + return False + +if is_package_present("torch_npu"): + import torch_npu + class LlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): @@ -202,7 +213,7 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): self._init_rope() def _init_rope(self): - if self.config.rope_scaling is None: + if self.config.rope_scaling is None or self.config.rope_scaling.get("rope_type", None) == "llama3": self.rotary_emb = LlamaRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, @@ -289,13 +300,116 @@ def forward( from transformers.utils import is_flash_attn_2_available import torch.nn.functional as F -from einops 
import rearrange +from einops import rearrange, repeat if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - + from flash_attn.layers.rotary import apply_rotary_emb +else: + # flash-attn can't install on npu, so we need copy the logic to here. + class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather( + rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + index_first_axis = IndexFirstAxis.apply + + class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros( + first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + index_put_first_axis = IndexPutFirstAxis.apply + + def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. 
+ Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): batch_size = position_ids.shape[0] @@ -312,9 +426,6 @@ def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_l return q_embed, k_embed -from flash_attn.layers.rotary import apply_rotary_emb - - # use flash-attn rotary embeddings with rmpad # cos/sin shoudl be: (seq_length, rotary_dim / 2) def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): @@ -368,14 +479,25 @@ def forward(self, value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) - cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, - key_states, - cos, - sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen_in_batch) - # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices, + if is_flash_attn_2_available(): + cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, + key_states, + cos, + sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen_in_batch) + else: + query_states, key_states = apply_rotary_pos_emb_rmpad( + query_states, + key_states, + cos, + sin, + position_ids, + indices, + # hard coding so far + sequence_length=1024 + ) # TODO: llama does not have dropout in the config?? 
# It is recommended to use dropout with FA according to the docs @@ -393,18 +515,38 @@ def forward(self, key_states = key_states.to(torch.float16) value_states = value_states.to(torch.float16) - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_in_batch, - max_seqlen_k=max_seqlen_in_batch, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) + if is_flash_attn_2_available(): + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + dropout_p=dropout_rate, + softmax_scale=None, + causal=True, + ) + elif is_package_present("torch_npu"): + attention_mask_npu = torch.from_numpy(np.triu(np.ones([max_seqlen_in_batch, max_seqlen_in_batch]), k=1)).to(bool).npu() + head_num = query_states.shape[1] + attn_output_unpad = torch_npu.npu_fusion_attention( + query_states, + key_states, + value_states, + head_num, + atten_mask=attention_mask_npu, + scale=1.0 / math.sqrt(query_states.shape[-1]), + keep_prob=1, + input_layout="TND", + actual_seq_qlen=tuple(cu_seqlens[1:].cpu().numpy().tolist()), + actual_seq_kvlen=tuple(cu_seqlens[1:].cpu().numpy().tolist()), + pre_tockens=2147483647, + next_tockens=0 + )[0] + else: + raise RuntimeError("No flash-attn kernel available, install flash-attn on gpu or install torch_npu on npu.") attn_output_unpad = attn_output_unpad.to(input_dtype) attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() diff --git a/verl/models/llama/megatron/layers/parallel_rmsnorm.py b/verl/models/llama/megatron/layers/parallel_rmsnorm.py index 7027036b..6066d07b 100644 --- a/verl/models/llama/megatron/layers/parallel_rmsnorm.py +++ b/verl/models/llama/megatron/layers/parallel_rmsnorm.py @@ -12,15 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License.
+import importlib.util import numbers import torch from megatron.core import ModelParallelConfig from torch import nn from transformers import LlamaConfig -from apex.normalization.fused_layer_norm import fused_rms_norm_affine +apex_available = True +try: + from apex.normalization.fused_layer_norm import fused_rms_norm_affine +except ImportError: + apex_available = False + from verl.utils.megatron import sequence_parallel as sp_utils +def package_exists(package_name): + package_spec = importlib.util.find_spec(package_name) + return package_spec is not None class ParallelLlamaRMSNorm(nn.Module): @@ -37,10 +46,16 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): if megatron_config.sequence_parallel: sp_utils.mark_parameter_as_sequence_parallel(self.weight) + if not apex_available: + self.rms_norm_func = nn.RMSNorm(self.normalized_shape, eps=self.variance_epsilon) + self.rms_norm_func.weight = self.weight def forward(self, hidden_states): - return fused_rms_norm_affine(input=hidden_states, - weight=self.weight, - normalized_shape=self.normalized_shape, - eps=self.variance_epsilon, - memory_efficient=True) \ No newline at end of file + if apex_available: + return fused_rms_norm_affine(input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True) + else: + return self.rms_norm_func(hidden_states) diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py index c693f33c..e7b8d7c3 100644 --- a/verl/models/llama/megatron/modeling_llama_megatron.py +++ b/verl/models/llama/megatron/modeling_llama_megatron.py @@ -29,12 +29,13 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import CausalLMOutputWithPast +from transformers.utils import is_flash_attn_2_available from verl.utils.megatron import sequence_parallel as sp_utils from verl.utils.megatron import tensor_parallel as tp_utils from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad """ -TODO: +TODO: 1. Add weight initialization. Here we need to be careful on TP weight init. 2. Add sequence parallel 3. Load checkpoint from meta LLama pretrained checkpoint @@ -208,9 +209,10 @@ def forward( attentions=None, ) - -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +else: + from .layers.parallel_attention import pad_input, unpad_input class ParallelLlamaModelRmPad(nn.Module): """ diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index ad6bab93..9de9b0c8 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -127,8 +127,10 @@ def __init__(self, cuda_visible_devices=None) -> None: master_addr = os.environ["MASTER_ADDR"] master_port = os.environ["MASTER_PORT"] - local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) + # HCCL communication on npu need to unset ASCEND_RT_VISIBLE_DEVICES (unlike CUDA_VISIBLE_DEVICES) in ray actor. + # so we need to RAY_LOCAL_RANK (instead of always 0 in ray on gpu) as device id. 
+ local_world_size = int(os.getenv("RAY_LOCAL_WORLD_SIZE", "1")) + local_rank = int(os.getenv("RAY_LOCAL_RANK", "0")) store = { '_world_size': world_size, diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index eaa1b00d..96614a6d 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -70,7 +70,8 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None): # print(f"pg_name_prefix = {pg_name_prefix}") pg_scheme = [[{ "CPU": self.max_collocate_count, - "GPU": 1 + # "GPU": 1 + "NPU": 1 } if self.use_gpu else { "CPU": self.max_collocate_count } for _ in range(process_count)] for process_count in self._store] @@ -161,7 +162,8 @@ def __call__(self, options.update(self._options) if use_gpu: - options["num_gpus"] = num_gpus + # options["num_gpus"] = num_gpus + options["resources"] = {"NPU": num_gpus} if len(self._additional_resource) > 1: for k, v in self._additional_resource.items(): diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py index 290c8378..cf1ca126 100644 --- a/verl/third_party/vllm/__init__.py +++ b/verl/third_party/vllm/__init__.py @@ -45,6 +45,12 @@ def get_version(pkg): from .vllm_v_0_6_3.llm import LLM from .vllm_v_0_6_3.llm import LLMEngine from .vllm_v_0_6_3 import parallel_state +# this package support npu could be build via https://github.com/vllm-project/vllm/pull/8054 +elif package_version == '0.1dev3628+g06f1b1d.npu': + vllm_version = '0.6.4' + from .vllm_v_0_6_4.llm import LLM + from .vllm_v_0_6_4.llm import LLMEngine + from .vllm_v_0_6_4 import parallel_state else: raise ValueError( f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.' diff --git a/verl/third_party/vllm/vllm_v_0_6_4/__init__.py b/verl/third_party/vllm/vllm_v_0_6_4/__init__.py new file mode 100644 index 00000000..1ce90c5e --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/third_party/vllm/vllm_v_0_6_4/arg_utils.py b/verl/third_party/vllm/vllm_v_0_6_4/arg_utils.py new file mode 100644 index 00000000..365798fe --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/arg_utils.py @@ -0,0 +1,82 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py + +import os +from dataclasses import dataclass + +from transformers import PretrainedConfig +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs + +from .config import LoadConfig, ModelConfig + + +@dataclass +class EngineArgs(EngineArgs): + model_hf_config: PretrainedConfig = None # for verl + + def __post_init__(self): + pass + + def create_model_config(self) -> ModelConfig: + return ModelConfig( + hf_config=self.model_hf_config, + task=self.task, + # We know this is not None because we set it in __post_init__ + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + allowed_local_media_path=self.allowed_local_media_path, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + hf_overrides=self.hf_overrides, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + enforce_eager=self.enforce_eager, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, + config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, + override_neuron_config=self.override_neuron_config, + override_pooler_config=self.override_pooler_config, + ) + + def create_load_config(self) -> LoadConfig: + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + def create_engine_config(self) -> VllmConfig: + engine_config = super().create_engine_config() + + # NOTE[VERL]: Use the world_size set by torchrun + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + engine_config.parallel_config.world_size = world_size + + return engine_config diff --git a/verl/third_party/vllm/vllm_v_0_6_4/config.py b/verl/third_party/vllm/vllm_v_0_6_4/config.py new file mode 100644 index 00000000..4173e44d --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/config.py @@ -0,0 +1,112 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py + +import enum +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Union + +from transformers import PretrainedConfig + +# Add for verl +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.model_executor.model_loader.loader import BaseModelLoader + +logger = init_logger(__name__) + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + MEGATRON = "megatron" + HF = "hf" + DTENSOR = "dtensor" + DUMMY_HF = "dummy_hf" + DUMMY_MEGATRON = "dummy_megatron" + DUMMY_DTENSOR = "dummy_dtensor" + + +class ModelConfig(ModelConfig): + + def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None: + super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs) + self.hf_config = hf_config + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. + """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if current_platform.is_rocm( + ) and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. 
" + f"Supported load formats are " + f"{rocm_supported_load_format}") + diff --git a/verl/third_party/vllm/vllm_v_0_6_4/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_4/dtensor_weight_loaders.py new file mode 100644 index 00000000..a3042cab --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/dtensor_weight_loaders.py @@ -0,0 +1,380 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from torch.distributed._tensor import DTensor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import is_pp_missing_parameter + + +def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + stacked_name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if stacked_name.endswith(".bias") and stacked_name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[stacked_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. 
+ continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight) + + +def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +from vllm.model_executor.layers.fused_moe import FusedMoE + + +def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=vllm_model.config.n_routed_experts, + ) + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + local_loaded_weight.to(dtype=param.dtype), + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + pass + + +def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): + param_name = _process_parameter_names(name=param_name) + if parallelize_plan is not None: + assert ( + param_name + in parallelize_plan.keys()), f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + placement = parallelize_plan[param_name] + local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, + placements=placement).to_local() + else: + local_loaded_weights = loaded_weights.full_tensor() + return local_loaded_weights + + +def _process_parameter_names(name): + # Remove '.weight' if it exists at the end of the string + if name.endswith(".weight"): + name = name[:-7] + + # Remove 'model.layers.x.' or 'model.' prefix + if "model.layers" in name: + parts = name.split(".") + # Reconstruct the string without 'model.layers.x.' + name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' + elif name.startswith("model."): + name = name[6:] # Remove 'model.' + + return name + + +__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_dtensor_weight_loader, + "LlamaForCausalLM": llama_dtensor_weight_loader, + "LLaMAForCausalLM": llama_dtensor_weight_loader, + "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM + "InternLMForCausalLM": llama_dtensor_weight_loader, + "AquilaModel": llama_dtensor_weight_loader, + "AquilaForCausalLM": llama_dtensor_weight_loader, + "Phi3ForCausalLM": llama_dtensor_weight_loader, + "GemmaForCausalLM": gemma_dtensor_weight_loader, + "Gemma2ForCausalLM": gemma_dtensor_weight_loader, + "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, + "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, + "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, + "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, + "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader, +} + + +# the actor model is .state_dict() +# Load dtensor weights +def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: + return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. 
" + f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + + +# NOTE(sgm): we use per-parameter weight loader in each vllm sub +def update_dtensor_weight_loader(): + pass diff --git a/verl/third_party/vllm/vllm_v_0_6_4/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_6_4/hf_weight_loader.py new file mode 100644 index 00000000..a3e5b22b --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/hf_weight_loader.py @@ -0,0 +1,41 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + + +def update_hf_weight_loader(): + print("no hf weight loader need to be updated") + return + + +def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): + assert isinstance(actor_weights, Dict) + with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys(): + del actor_weights["lm_head.weight"] + vllm_model.load_weights(actor_weights.items()) + for _, module in vllm_model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + vllm_model = vllm_model.cuda() diff --git a/verl/third_party/vllm/vllm_v_0_6_4/llm.py b/verl/third_party/vllm/vllm_v_0_6_4/llm.py new file mode 100644 index 00000000..89d1175c --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/llm.py @@ -0,0 +1,277 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py + +from typing import (Any, Dict, List, Optional, Tuple, Type, Union) + +import json +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer +from vllm import LLM +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.utils import Counter +from vllm.config import CompilationConfig +from vllm.engine.arg_utils import (HfOverrides, PoolerConfig, TaskOption) + +from .arg_utils import EngineArgs +from .llm_engine_sp import LLMEngine + + + + +from tqdm import tqdm +from vllm import envs + + + + + + + + + + +class LLM(LLM): + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + Args: + model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. Expect valid prompt_token_ids and None for prompt + from the input. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + allowed_local_media_path: Allowing API requests to read local images + or videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq", "gptq", and "fp8" (experimental). + If None, we first check the `quantization_config` attribute in the + model config file. If that is None, we assume the model weights are + not quantized and use `dtype` to determine the data type of + the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. 
If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading + the model weights. This virtually increases the GPU memory space + you can use to hold the model weights, at the cost of CPU-GPU data + transfer for every forward pass. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. + disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig` + disable_async_output_proc: Disable async output processing. + This may result in lower performance. + hf_overrides: If a dictionary, contains arguments to be forwarded to the + HuggingFace config. If a callable, it is called to update the + HuggingFace config. + compilation_config: Either an integer or a dictionary. If it is an + integer, it is used as the level of compilation optimization. If it + is a dictionary, it can specify the full compilation configuration. + **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Note: + This class is intended to be used for offline inference. For online + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], + model_hf_config: PretrainedConfig, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + allowed_local_media_path: str = "", + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + # After positional args are removed, move this right below `model` + task: TaskOption = "generate", + override_pooler_config: Optional[PoolerConfig] = None, + compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + **kwargs, + ) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
+ ''' + + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + removed_vision_keys = ("image_token_id", "image_feature_size", "image_input_shape", "image_input_type") + if any(k in kwargs for k in removed_vision_keys): + raise TypeError("There is no need to pass vision-related arguments anymore.") + + if compilation_config is not None: + compilation_config_instance = CompilationConfig.from_cli( + json.dumps(compilation_config)) + else: + compilation_config_instance = None + + engine_args = EngineArgs( + model_hf_config=model_hf_config, + task=task, + # tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + hf_overrides=hf_overrides, + mm_processor_kwargs=mm_processor_kwargs, + override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, + **kwargs, + ) + tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer) + if not isinstance(tokenizer, tokenizer_cls): + raise ValueError( + f"Unexpected tokenizer type: {type(tokenizer)}. Must be" + "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer" + ) + self.engine_class = self.get_engine_class() + self.llm_engine = self.engine_class.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext + self.request_counter = Counter() + + @staticmethod + def get_engine_class() -> Type[LLMEngine]: + if envs.VLLM_USE_V1: + # Lazy import: the v1 package isn't distributed + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + return V1LLMEngine # type: ignore + return LLMEngine + + def init_cache_engine(self): + self.llm_engine.init_cache_engine() + + def free_cache_engine(self): + self.llm_engine.free_cache_engine() + + def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + + def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + outputs = super()._run_engine(use_tqdm=use_tqdm) + return self._post_process_outputs(outputs) + + # # NOTE(shengguangming): add for verl + # # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. 
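
For orientation, the sketch below shows roughly how this `LLM` wrapper is meant to be driven from a rollout worker. It is an illustration, not code from this patch: the model path, dtype, and `load_format` value are assumptions, and `generate()` returns padded `(token_ids, logprobs)` tensors here because the `_run_engine` override above routes results through `_post_process_outputs` (defined just below).

```python
# Hypothetical driver sketch (not part of this patch). Assumes the vendored
# wrapper at verl/third_party/vllm/vllm_v_0_6_4/llm.py is importable as below.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from vllm import SamplingParams

from verl.third_party.vllm.vllm_v_0_6_4.llm import LLM

model_path = "/path/to/Llama-3.2-1B-Instruct"            # placeholder path
tokenizer = AutoTokenizer.from_pretrained(model_path)
model_hf_config = AutoConfig.from_pretrained(model_path)
actor = AutoModelForCausalLM.from_pretrained(model_path)  # stands in for the trained actor

llm = LLM(
    model=actor.state_dict(),            # parameter dict (an nn.Module also works)
    tokenizer=tokenizer,
    model_hf_config=model_hf_config,
    tensor_parallel_size=1,
    dtype="bfloat16",
    gpu_memory_utilization=0.5,
    load_format="hf",                    # assumed format string; see model_loader.py
)

# Per PPO iteration: push the latest actor weights in, rebuild the KV cache,
# sample, then release the cache and offload the weights again.
llm.sync_model_weights(actor_weights=actor.state_dict(), load_format="hf")
llm.init_cache_engine()
token_ids, logprobs = llm.generate("1 + 1 = ", SamplingParams(max_tokens=16))
llm.free_cache_engine()
llm.offload_model_weights()
```
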
+ # def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: + # # remove the left padding in the prompt token_id + # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + # non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + # token_ids = prompt_token_ids[non_pad_index:].tolist() + # return token_ids + + # NOTE(shengguangming): add for verl + def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]: + output_token_ids = [] + logprobs = [] + for request_output in request_outputs: # List[RequestOutput] + outputs = request_output.outputs + for output in outputs: # List[CompletionOutput], usually len == 1 + output_token_ids.append(torch.tensor(output.token_ids)) + # TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits + logprobs_dicts = output.logprobs + if logprobs_dicts is not None: + logprob = [] + for logprobs_dict, id in zip(logprobs_dicts, output.token_ids): + logprob.append(logprobs_dict[id].logprob) + logprobs.append(torch.tensor(logprob)) + + pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + # pad_token_id = 128004 + output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) + if len(logprobs) > 0: + logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) + return output_token_ids, logprobs + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.llm_engine.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_6_4/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_6_4/llm_engine_sp.py new file mode 100644 index 00000000..b4d533a4 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/llm_engine_sp.py @@ -0,0 +1,423 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py + +from functools import partial +from typing import Dict, Optional, Union, Type, Callable + +import torch +import torch.nn as nn +from typing_extensions import TypeVar + +from vllm.config import (DecodingConfig, + ObservabilityConfig, + VllmConfig) + +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.logger import init_logger +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.engine.metrics import StatLoggerBase +from vllm.inputs.preprocess import InputPreprocessor +from vllm.tracing import init_tracer +from vllm.sequence import Sequence +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) +from vllm.utils import Counter, weak_bind +from vllm.engine.llm_engine import LLMEngine, SchedulerContext, SchedulerOutputState, _load_generation_config_dict +from vllm.version import __version__ as VLLM_VERSION + +from .arg_utils import EngineArgs +from .tokenizer import TokenizerGroup +from .config import ModelConfig, LoadConfig + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + +class LLMEngine(LLMEngine): + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. + + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Args: + model: the actor model initialize outside vllm (add for verl) + tokenizer: the initialized tokenizer (add for verl) + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + speculative_config (Optional): The configuration related to speculative + decoding. + executor_class: The model executor class for managing distributed + execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. 
+ """ + + def __init__( + self, + # NOTE(sgm): first two arguments are added for verl + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: nn.Module, + # NOTE(sgm): vllm original arguments + vllm_config: VllmConfig, + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + + # TODO: remove the local variables and use self.* throughout the class. + model_config = self.model_config = vllm_config.model_config + cache_config = self.cache_config = vllm_config.cache_config + lora_config = self.lora_config = vllm_config.lora_config + parallel_config = self.parallel_config = vllm_config.parallel_config + scheduler_config = self.scheduler_config = vllm_config.scheduler_config + device_config = self.device_config = vllm_config.device_config + speculative_config = self.speculative_config = vllm_config.speculative_config # noqa + load_config = self.load_config = vllm_config.load_config + decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + ) + prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, " + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s, pooler_config=%r," + "compilation_config=%r", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.override_neuron_config, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, + scheduler_config.multi_step_stream_outputs, + cache_config.enable_prefix_caching, + model_config.use_async_output_proc, + use_cached_outputs, + model_config.mm_processor_kwargs, + model_config.pooler_config, + vllm_config.compilation_config, + ) + # TODO(woosuk): Print more configs in debug mode. 
+ self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig( + ) + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + # TODO(shengguangming): maybe we can choose init here or from arguments + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer(tokenizer) + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, ("tokenizer_group cannot be None, " + "make sure skip_tokenizer_init is False") + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict( + model_config) + + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer, + mm_registry) + + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor( + model_config) + + self.model_executor = executor_class(model=model, # add for spmd_gpu_executor + vllm_config=vllm_config, ) + + if self.model_config.task != "embedding": + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) + usage_message.report_usage( + get_architecture_class_name(model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(model_config.dtype), + "tensor_parallel_size": + parallel_config.tensor_parallel_size, + "block_size": + cache_config.block_size, + "gpu_memory_utilization": + cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + model_config.quantization, + "kv_cache_dtype": + str(cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), + "enable_prefix_caching": + cache_config.enable_prefix_caching, + "enforce_eager": + model_config.enforce_eager, + "disable_custom_all_reduce": + parallel_config.disable_custom_all_reduce, + }) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. 
+ multi_step_stream_outputs) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = [ + Scheduler( + scheduler_config, cache_config, lora_config, + parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] + if model_config.use_async_output_proc else None) + for v_id in range(parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import (LoggingStatLogger, + PrometheusStatLogger) + + self.stat_loggers = { + "logging": + LoggingStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len), + } + self.stat_loggers["prometheus"].info("cache_config", + self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. 
+ self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + ), + )) + + self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + + + # TODO(sgm): add for verl but we may not tokenizer in Rollout + def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): + init_kwargs = dict(enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None) + init_kwargs.update(tokenizer_init_kwargs) + return TokenizerGroup(tokenizer, **init_kwargs) + + + def init_cache_engine(self): + # TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache + # Re-capture CUDAGraph would be time-consuming + self.model_executor.init_cache_engine() + + def free_cache_engine(self): + self.model_executor.free_cache_engine() + + # NOTE(sgm): currently, we only support GPU executor + # The GPUExecutor remove the Ray dependency + @classmethod + def _get_executor_cls(cls, engine_config: VllmConfig) -> Type[ExecutorBase]: + assert engine_config.device_config.device_type in ("cuda", "npu"), \ + "Currently, the vllm in verl only support running on GPU or NPU" + + if engine_config.parallel_config.world_size == 1: + engine_config.load_config.load_format = "dummy_hf" + + from .spmd_gpu_executor import SPMDGPUExecutor + executor_class = SPMDGPUExecutor + return executor_class + + @classmethod + def from_engine_args( + cls, + model, + tokenizer, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Initialize the cluster and specify the executor class. + assert engine_config.device_config.device_type in ("cuda", "npu"), \ + "Currently, the vllm in verl only support running on GPU and NPU" + + from .spmd_gpu_executor import SPMDGPUExecutor + executor_class = SPMDGPUExecutor + + # Create the LLM engine. + engine = cls( + model, + tokenizer, + vllm_config=engine_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + return engine + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.model_executor.offload_model_weights() \ No newline at end of file diff --git a/verl/third_party/vllm/vllm_v_0_6_4/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_4/megatron_weight_loaders.py new file mode 100644 index 00000000..1382616f --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/megatron_weight_loaders.py @@ -0,0 +1,308 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch +import torch.nn as nn +from vllm.model_executor.layers.linear import * +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding +from vllm.model_executor.models import ModelRegistry + + +# NOTE(shengguangming): replace the origin weight loader function in the class +def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Parallel Linear weight loader.""" + assert (param.size() == loaded_weight.size( + )), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format( + param.size(), loaded_weight.size()) + assert (param.data.dtype == loaded_weight.data.dtype + ), "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + assert (param.data.dtype == loaded_weight.data.dtype + ), "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. 
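
`parallel_weight_loader` above can get away with a plain tensor assignment because the Megatron-side shard is assumed to already match the vLLM parameter's shape and dtype on this rank; `update_megatron_weight_loader()` near the end of this file patches it onto vLLM's parallel layers at the class level. A minimal, self-contained sketch of that patching pattern (stand-in classes, not vLLM's):

```python
# Illustrative only: a stand-in class showing the registry/monkey-patch pattern
# used by update_megatron_weight_loader(); the real code patches vLLM's
# ColumnParallelLinear / RowParallelLinear / VocabParallelEmbedding classes.
import torch


class FakeParallelLinear:
    """Stand-in for a vLLM parallel layer with a class-level weight_loader."""

    def weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        # vLLM's stock loaders may narrow or re-shard `loaded_weight` here.
        param.data.copy_(loaded_weight)


def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Same idea as the loader above: the incoming shard is already laid out
    # exactly like `param`, so just take over its storage.
    assert param.size() == loaded_weight.size()
    param.data = loaded_weight.data


LOADER_REGISTRY = {FakeParallelLinear: parallel_weight_loader}


def update_weight_loader():
    # Class-level patch: every instance, existing or future, now routes weight
    # loading through the zero-copy assignment.
    for layer_cls, loader in LOADER_REGISTRY.items():
        layer_cls.weight_loader = loader


update_weight_loader()
layer = FakeParallelLinear()
param = torch.nn.Parameter(torch.empty(4, 4))
layer.weight_loader(param, torch.ones(4, 4))
print(param.data)  # all ones, sharing storage with the loaded tensor
```
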
+ for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + # TODO: check megatron + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name or "lm_head.weight" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def _replace_name(megatron_name, name_mapping): + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = 
megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + + +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def _replace_name(megatron_name, name_mapping): + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = 
megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + + +def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + # TODO: need to implement a general way to deal with prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { + ColumnParallelLinear: parallel_weight_loader, + MergedColumnParallelLinear: parallel_weight_loader, + QKVParallelLinear: parallel_weight_loader, + RowParallelLinear: parallel_weight_loader, + VocabParallelEmbedding: parallel_weight_loader, + ParallelLMHead: parallel_weight_loader, + # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights + # "default_weight_loader": default_weight_loader +} + +# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): +# # setattr(layer_class, 'megatron_weight_loader', weight_loader) +# layer_class.weight_loader = weight_loader + +__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_weight_loader, + "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron + "LLaMAForCausalLM": llama_megatron_weight_loader, + "MistralForCausalLM": mistral_megatron_weight_loader, +} + + +# the actor model is .state_dict() +# Load megatron weights +def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: + return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def update_megatron_weight_loader(): + for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): + layer_class.weight_loader = weight_loader diff --git a/verl/third_party/vllm/vllm_v_0_6_4/model_loader.py b/verl/third_party/vllm/vllm_v_0_6_4/model_loader.py new file mode 100644 index 00000000..64b4bf10 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/model_loader.py @@ -0,0 +1,279 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from vllm.config import (LoadConfig, LoadFormat, ModelConfig, + VllmConfig) +from vllm.model_executor.model_loader import BaseModelLoader +from vllm.model_executor.model_loader.loader import _initialize_model +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.distributed.communication_op import tensor_model_parallel_all_gather + +from .config import ModelConfig, LoadFormat, LoadConfig +from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader +from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader +from .hf_weight_loader import update_hf_weight_loader + + +def get_model(actor_model: Union[PreTrainedModel, Dict], vllm_config: VllmConfig) -> nn.Module: + load_config = vllm_config.load_config + loader = get_model_loader(load_config) + if load_config.load_format.startswith('dummy'): + return loader.load_model(vllm_config=vllm_config) + else: + return loader.load_model(actor_model=actor_model, + vllm_config=vllm_config) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.AUTO: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + # NOTE(sgm): change the weight_loader function in runtime + if load_config.load_format == LoadFormat.MEGATRON: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + if load_config.load_format == LoadFormat.HF: + update_hf_weight_loader() + return HFLoader(load_config) + + if load_config.load_format == LoadFormat.DTENSOR: + update_dtensor_weight_loader() + return DTensorLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_HF: + update_hf_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_MEGATRON: + update_megatron_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_DTENSOR: + update_dtensor_weight_loader() + return DummyModelLoader(load_config) + + raise ValueError('load format not supported in verl: {}, only support {} and {}'.format( + load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. 
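
For quick reference, `get_model_loader` above dispatches the verl-specific load formats roughly as summarized below; the format strings are assumed to be the lowercase values used elsewhere in this patch (e.g. `"dummy_hf"` in `_get_executor_cls`).

```python
# Summary of the dispatch in get_model_loader() above (illustrative, not
# imported from the patch); "auto" falls back to the Megatron path.
LOAD_FORMAT_TO_LOADER = {
    "auto": "MegatronLoader",          # also installs update_megatron_weight_loader()
    "megatron": "MegatronLoader",      # sync weights from a Megatron-sharded actor
    "hf": "HFLoader",                  # sync weights from a full HF state_dict
    "dtensor": "DTensorLoader",        # sync weights from an FSDP/DTensor actor
    "dummy_hf": "DummyModelLoader",    # random init, used before the first weight sync
    "dummy_megatron": "DummyModelLoader",
    "dummy_dtensor": "DummyModelLoader",
}
```
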
+ # initialize_dummy_weights(model) + return model.eval() + + +class MegatronLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model(self, actor_model: Union[PreTrainedModel, Dict], vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), + vllm_model=model) + else: + load_megatron_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class HFLoader(BaseModelLoader): + """Model loader that can load the model weights from model's full params.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): + if isinstance(actor_model, Dict): + return actor_model.items() + elif isinstance(actor_model, nn.Module): + return dict(actor_model.named_parameters()).items() + else: + raise ValueError(f'actor model should be Dict or nn.Module, but get {type(actor_model)}') + + def load_model(self, actor_model: Union[PreTrainedModel, Dict], vllm_config: VllmConfig) -> nn.Module: + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + # with torch.device(device_config.device): + # NOTE(sgm): init the model in cpu + model = _initialize_model(vllm_config=vllm_config) + model.load_weights(self._get_weights_iterator(actor_model)) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. 
+ if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class DTensorLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model(self, actor_model: Union[PreTrainedModel, Dict], vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), + vllm_model=model) + else: + load_dtensor_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +# FIXME(sgm): hack the _get_logits function in vllm v0.4.2 +# as they use ray, the _get_logits result will only need to return to the driver node, +# therefore gather is enough. However, we use SPMD instead of a central scheduler, +# all_gather is required (aligned with v0.2.6) +def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + + +from vllm.model_executor.layers.logits_processor import LogitsProcessor + + +def logitsprocessor_init(self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super(LogitsProcessor, self).__init__() + self.scale = scale + self.vocab_size = vocab_size + # Whether the input is logits (default is hidden states). 
+ self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_gather = False + + +LogitsProcessor.__init__ = logitsprocessor_init # use all_gather \ No newline at end of file diff --git a/verl/third_party/vllm/vllm_v_0_6_4/model_runner.py b/verl/third_party/vllm/vllm_v_0_6_4/model_runner.py new file mode 100644 index 00000000..c16b8073 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/model_runner.py @@ -0,0 +1,154 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py + +import warnings +from enum import IntEnum +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import vllm.envs as envs +from vllm.config import CompilationLevel, VllmConfig +from vllm.logger import init_logger +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models.interfaces import supports_lora +from vllm.worker.npu_model_runner import NPUModelRunner +from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.utils import DeviceMemoryProfiler, supports_dynamo +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.platforms import current_platform + +from .model_loader import get_model +from .config import ModelConfig, LoadConfig + +logger = init_logger(__name__) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. 
+ MIXED = 2 + +class ModelRunner(NPUModelRunner): + + def __init__( + self, + model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + + super().__init__( + vllm_config, + kv_cache_dtype, + is_driver_worker=True, # a hack + return_hidden_states=return_hidden_states, + input_registry=input_registry, + mm_registry=mm_registry) + + # NOTE(sgm): add for verl + self.model = model # this will be replaced by get_model() + + # NOTE(sgm): initialize model using the actor model + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + self.model = get_model(self.model, vllm_config=self.vllm_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, self.device, + self.prompt_adapter_config) + self.model = ( + self.prompt_adapter_manager.create_prompt_adapter_manager( + self.model)) + + if self.kv_cache_dtype == "fp8" and current_platform.is_rocm(): + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. + if self.model_config.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2) + self.model.load_kv_cache_scales( + self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", + self.model_config.quantization_param_path) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__) + else: + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. 
" + "This may lead to less accurate results!") + + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + backend = self.vllm_config.compilation_config.init_backend() + self.model = torch.compile( + self.model, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) \ No newline at end of file diff --git a/verl/third_party/vllm/vllm_v_0_6_4/parallel_state.py b/verl/third_party/vllm/vllm_v_0_6_4/parallel_state.py new file mode 100644 index 00000000..0150c1c6 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/parallel_state.py @@ -0,0 +1,312 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""Model and data parallel groups.""" +import os +from typing import Optional + +import torch +import torch.distributed +import vllm.distributed.parallel_state as ps +from vllm.distributed.parallel_state import ( + get_pp_group, + get_world_group, + init_distributed_environment, + init_model_parallel_group, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) +""" +This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. +- We assume the Megatron tp+dp+pp world is already established before calling this function. + +""" + +# Device mesh for using DTensor +_DEVICE_MESH = None + +# Tensor model parallel group that the current rank belongs to. +_TP = None +# Pipeline model parallel group that the current rank belongs to. +_PP = None + + +# This method is for initializing the ParallelGroup when using HybridEngine +def initialize_parallel_state( + distributed_init_method: str = "env://", + backend: str = "nccl", + tensor_model_parallel_size: int = 1, + num_tp_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. 
+ rank = int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + if torch.distributed.get_world_size() > 1: + # NOTE: build a sepearate inference group with infer tp & micro dp + initialize_model_parallel_for_vllm( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) + else: + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + "tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}") + + +# TODO(sgm): deviate from the v0.5.4, not pp now +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ps._TP is not None + # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + + +def initialize_model_parallel_for_vllm( + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: + pass + + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + + assert isinstance(tensor_model_parallel_size, int) + + # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group + # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group + + # Build the tensor model-parallel groups. 
+ assert ps._TP is None, "tensor model parallel group is already initialized" + + global _TP + + world_size: int = torch.distributed.get_world_size() + + rank = torch.distributed.get_rank() + + backend = torch.distributed.get_backend() + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + + if num_tensor_model_parallel_groups_per_train_tp == 1: + # if tensor_model_parallel_size == train_tensor_parallel_size: + # using the same tp group as Megatron/vllm + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine + else: + # initialize a micro_dp group and a tp group + # assume training tp=4, infer tp=2, then, weight is partitioned as + # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference + + # Build the inference tp groups + # train_tp = train_tensor_parallel_size + train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): + start = train_tp * i + end = train_tp * (i + 1) + for j in range(num_tensor_model_parallel_groups_per_train_tp): + ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) + for i in range(len(ranks)): + ranks[i] += j + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # Build the pipeline model-parallel groups. + # global _PIPELINE_MODEL_PARALLEL_GROUP + # global _PIPELINE_GLOBAL_RANKS + # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") + + # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() + # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. 
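
The branch above interleaves the inference TP groups inside each training TP group. A minimal standalone sketch of just the `group_ranks` computation, using illustrative sizes (8 ranks, inference TP 2, training TP 4, i.e. `num_tensor_model_parallel_groups_per_train_tp = 2`) so the resulting layout can be inspected without initializing `torch.distributed`:

```python
# Standalone reproduction of the group_ranks loop above (no process group needed).
world_size = 8
infer_tp = 2              # tensor_model_parallel_size
groups_per_train_tp = 2   # num_tensor_model_parallel_groups_per_train_tp
train_tp = groups_per_train_tp * infer_tp  # = 4

num_infer_tp_groups = world_size // infer_tp
group_ranks = []
for i in range(num_infer_tp_groups // groups_per_train_tp):
    start, end = train_tp * i, train_tp * (i + 1)
    for j in range(groups_per_train_tp):
        # Stride through one training TP group, then shift by j to interleave.
        group_ranks.append([r + j for r in range(start, end, groups_per_train_tp)])

print(group_ranks)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
```

Each inference TP group of size 2 is carved out of a single training TP group of size 4, so gathering weights for the rollout engine stays local to one training TP group, as the partitioning comment above suggests.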
+ num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + NOTE: This method is a hack from the open-sourced version without + asertion of world_size = tp * pp + + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) + + # NOTE(sgm) we don't assert world_size == tp * pp + # DP is not managed by vllm but by the VeRL WorkerGroup + # if (world_size != + # tensor_model_parallel_size * pipeline_model_parallel_size): + # raise RuntimeError( + # f"world_size ({world_size}) is not equal to " + # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + rank = torch.distributed.get_rank() + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. 
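
The same pipeline-group construction is used in both initializers (the loop above for the HybridEngine path, and the loop that follows in `initialize_model_parallel`). A quick standalone check, using the 8-GPU example from the docstring above (tensor parallel 2, pipeline parallel 4), that it yields the documented layout:

```python
# Pipeline-parallel group layout for the docstring's example: 8 ranks, tp=2, pp=4.
world_size = 8
pipeline_model_parallel_size = 4

num_pp_groups = world_size // pipeline_model_parallel_size  # = 2
pp_group_ranks = [list(range(i, world_size, num_pp_groups)) for i in range(num_pp_groups)]

print(pp_group_ranks)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```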
+ num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +""" +Device mesh utilities +""" + + +def get_device_mesh(): + assert _DEVICE_MESH is not None, "device mesh is not initialized" + return _DEVICE_MESH + + +""" +Tensor model parallel utilities +""" + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP.device_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/vllm_v_0_6_4/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_6_4/spmd_gpu_executor.py new file mode 100644 index 00000000..9e13a044 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/spmd_gpu_executor.py @@ -0,0 +1,235 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py + +import os +import socket +from typing import Any, Dict, List, Optional, Set, Tuple + +import torch +from vllm.executor.executor_base import ExecutorBase, ExecutorAsyncBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest + +from vllm.config import ParallelConfig, VllmConfig +from .config import ModelConfig, LoadConfig + +logger = init_logger(__name__) + + +class SPMDGPUExecutor(ExecutorBase): + """SPMD-based multi-GPU executor implementations.""" + + def __init__( + self, + model, # pytorch model itself or its parameter dict + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + distributed_init_method = initialize_cluster(vllm_config.parallel_config) + self._init_executor(model, distributed_init_method) + + # TODO(sgm): verl not support speculative decode now + def _init_executor(self, model, distributed_init_method) -> None: + assert (not self.speculative_config), "Speculative decoding not yet supported for multi-GPU backend." + + # Create the parallel worker for each GPU. + self._init_workers_sp(model, distributed_init_method) + + def _init_workers_sp(self, model, distributed_init_method: str): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from .worker import Worker # pylint: disable=import-outside-toplevel + + rank = int(os.getenv("RANK")) + local_rank = int(os.getenv("LOCAL_RANK")) + print(f'local rank {local_rank}') + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + self.worker = Worker( + model, + self.vllm_config, + local_rank, + rank, + distributed_init_method, + is_driver_worker=True, + model_runner_cls=None, # use the default one + ) + + # NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model() + self.worker.init_device() + self.worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self.worker.determine_num_available_blocks() + + # NOTE(shengguangming): Now we don't use a shared centralized controler but each process will + # have its own scheduler + num_gpu_blocks = num_blocks[0] + num_cpu_blocks = num_blocks[1] + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. 
+ """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + if torch.distributed.get_rank() == 0: + print( + f'before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB' + ) + self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + if torch.distributed.get_rank() == 0: + print( + f'after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB' + ) + + # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache + def init_cache_engine(self) -> None: + self.worker._init_cache_engine() + + def free_cache_engine(self) -> None: + self.worker.free_cache_engine() + + def execute_model(self, execute_model_req) -> List[SamplerOutput]: + all_outputs = self.worker.execute_model(execute_model_req=execute_model_req) + + # NOTE(sgm): + # Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs + # In vllm with ray, only the driver worker returns the sampling results. + return all_outputs + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self.worker.add_lora(lora_request=lora_request) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.remove_lora(lora_id=lora_id) + + def list_loras(self) -> Set[int]: + return self.worker.list_loras() + + def check_health(self) -> None: + # SPMDExecutor will always be healthy as long as + # it's running. + return + + # NOTE(sgm) add for verl to pass the abstract class test, not used + from vllm.prompt_adapter.request import PromptAdapterRequest + + def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.worker.add_prompt_adapter(prompt_adapter_request) + + def list_prompt_adapters(self) -> Set[int]: + return self.worker.list_prompt_adapters() + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.pin_lora(lora_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.worker.pin_prompt_adapter(prompt_adapter_id) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.worker.remove_prompt_adapter(prompt_adapter_id) + + # NOTE(sgm): add for verl + def offload_model_weights(self) -> None: + self.worker.offload_model_weights() + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + +def initialize_cluster( + parallel_config: ParallelConfig, + engine_use_ray: bool = False, + ray_address: Optional[str] = None, +) -> Tuple[str, Optional[None]]: + """Initialize the distributed cluster probably with Ray. 
+ + Args: + parallel_config: The configurations for parallel execution. + + Returns: + The `distributed_init_method` is the address for initializing the + distributed backend. + """ + + # Initialize cluster locally. + port = get_open_port() + # We need to setup the distributed init method to make sure + # the distributed megatron code (e.g., get world size) works correctly. + # distributed_init_method = f"tcp://localhost:{port}" + distributed_init_method = 'env://' + return distributed_init_method + + +def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +# TODO(sgm): not implemented async executor yet +class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): + + async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + async def check_health_async(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + self.check_health() diff --git a/verl/third_party/vllm/vllm_v_0_6_4/tokenizer.py b/verl/third_party/vllm/vllm_v_0_6_4/tokenizer.py new file mode 100644 index 00000000..b0b4d0e2 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/tokenizer.py @@ -0,0 +1,40 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py + +from typing import Optional + +from transformers import PreTrainedTokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.utils import LRUCache + + +class TokenizerGroup(TokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int]): + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = tokenizer + self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None + + # FIXME(sgm): for simplicity, we assign the special token here + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id diff --git a/verl/third_party/vllm/vllm_v_0_6_4/worker.py b/verl/third_party/vllm/vllm_v_0_6_4/worker.py new file mode 100644 index 00000000..f1e9ced0 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_4/worker.py @@ -0,0 +1,310 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py +"""A GPU worker class.""" +import gc +import os +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.distributed +import torch.nn as nn + +# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase +from vllm.worker.worker_base import WorkerBase, WorkerInput +from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype +from vllm.worker.model_runner_base import ModelRunnerInputBase + +from .model_runner import ModelRunner +from .megatron_weight_loaders import load_megatron_weights +from .hf_weight_loader import load_hf_weights +from .dtensor_weight_loaders import load_dtensor_weights +from .parallel_state import ensure_model_parallel_initialized +from .config import ModelConfig, LoadConfig, LoadFormat + + + +class Worker(Worker): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + ) -> None: + WorkerBase.__init__(self, vllm_config) + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + # if is_driver_worker: + # assert rank % self.parallel_config.tensor_parallel_size == 0, \ + # "Driver worker should be rank 0 of tensor parallel group." 
+ if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config + speculative_args = ( + {} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or + (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else { + "return_hidden_states": True + }) + + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + elif model_config.task == "embedding": + ModelRunnerClass = EmbeddingModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + model, # [VERL]: add for verl + vllm_config=self.vllm_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + **speculative_args, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] = None + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + + # NOTE(sgm): [VERL] For offloading inference engine params + self.cpu_model = None + + def init_device(self) -> None: + if self.device_config.device.type in ("cuda", "npu"): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.device = torch.device(f"cuda:{local_rank}") + if self.rank < 0: + raise ValueError("Invalid or unspecified rank.") + torch.cuda.set_device(self.device) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + self.parallel_config.world_size = world_size + + if self.device_config.device.type == "cuda": + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank, + self.device_config.device.type) + # Set random seed. + set_random_seed(self.model_config.seed) + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. 
+ """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + # torch.cuda.reset_peak_memory_stats() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = total_gpu_memory - free_gpu_memory + + assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = self.get_cache_block_size_bytes() + + # NOTE(sgm) [VERL] use the remaining memory + num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size) + # num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size) + + num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + + # NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank + num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") + num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") + + torch.distributed.all_reduce(num_gpu_blocks, + op=torch.distributed.ReduceOp.MIN, + group=get_tensor_model_parallel_group().device_group) + torch.distributed.all_reduce(num_cpu_blocks, + op=torch.distributed.ReduceOp.MIN, + group=get_tensor_model_parallel_group().device_group) + num_gpu_blocks = num_gpu_blocks.item() + num_cpu_blocks = num_cpu_blocks.item() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def _init_cache_engine(self): + if self.cache_engine is None and self.gpu_cache is None: + super()._init_cache_engine() + + def free_cache_engine(self): + # ensure `enforce_eager=True` + self.cache_engine = None + self.gpu_cache = None + + # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None + ) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) + + super().execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. 
+ if worker_input.num_seq_groups == 0: + return [] + + return self.model_runner.execute_model( + model_input, + self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, + intermediate_tensors, + ) + + # assume the input is .state_dict() + def sync_model_weights(self, actor_weights: Dict, load_format: str): + if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]: + load_megatron_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.HF: + # full model state dict without no sharding + load_hf_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.DTENSOR: + load_dtensor_weights(actor_weights, self.model_runner.model) + + def offload_model_weights(self) -> None: + if self.cpu_model == None: + self.cpu_model = {} + for name, params in self.model_runner.model.named_parameters(): + self.cpu_model[name] = torch.empty_like(params, device='cpu') + params.data = self.cpu_model[name] + else: + for name, params in self.model_runner.model.named_parameters(): + params.data = self.cpu_model[name] + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = "env://", + local_rank: int = -1, + device_type: str = None, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron + if device_type == "npu": + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, "hccl") + else: + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) + + ensure_model_parallel_initialized( + tensor_model_parallel_size=parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, + ) + + # TODO(sgm): check whether need this + # if pynccl_utils.is_initialized(): + # pynccl_world_size = pynccl_utils.get_world_size() + # if pynccl_world_size != parallel_config.world_size: + # raise RuntimeError( + # "pynccl is already initialized but the pynccl world " + # "size does not match parallel_config.world_size " + # f"({pynccl_world_size} vs. {parallel_config.world_size}).") + # elif parallel_config.world_size > 1: + # # NOTE(woosuk): We don't initialize pynccl process group when world size + # # is 1. + # # NOTE(kaichao): By default, pynccl is initialized for tp group. + # pynccl_utils.init_process_group( + # group=get_tensor_model_parallel_cpu_group()) + + # # Initialize a custom fast all-reduce implementation. + # if not parallel_config.disable_custom_all_reduce: + # init_custom_ar() + + # A small all_reduce for warmup. 
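
`offload_model_weights` above releases device memory by re-pointing every parameter's `.data` at a cached CPU tensor; the contents do not matter because `sync_model_weights` refills them from the actor before the next rollout. A toy, CPU-only sketch of the same re-pointing idea (an `nn.Linear` stands in for the inference model):

```python
import torch
import torch.nn as nn

def offload_model_weights(model: nn.Module, cpu_cache: dict) -> None:
    # Point each parameter at a (possibly uninitialized) CPU buffer; the old
    # device tensors become unreferenced and their memory can be reclaimed.
    for name, param in model.named_parameters():
        if name not in cpu_cache:
            cpu_cache[name] = torch.empty_like(param, device="cpu")
        param.data = cpu_cache[name]

model = nn.Linear(8, 8)   # toy stand-in for the vLLM model
cache: dict = {}
offload_model_weights(model, cache)
print(next(model.parameters()).device)   # cpu
```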
+ torch.distributed.all_reduce(torch.zeros(1).cuda()) + # if pynccl_utils.is_initialized(): + # pynccl_utils.all_reduce(torch.zeros(1).cuda()) diff --git a/verl/trainer/config/ppo_megatron_trainer_llama_3.2_1b.yaml b/verl/trainer/config/ppo_megatron_trainer_llama_3.2_1b.yaml new file mode 100644 index 00000000..3ba6c286 --- /dev/null +++ b/verl/trainer/config/ppo_megatron_trainer_llama_3.2_1b.yaml @@ -0,0 +1,145 @@ +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: 1312 + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: {} + enable_gradient_checkpointing: False + actor: + strategy: megatron # This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: 64 + clip_ratio: 0.2 + entropy_coeff: 0.001 + ppo_epochs: 1 + shuffle: True + optim: + lr: 1e-6 + clip_grad: 1.0 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + megatron: + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. + sequence_parallel: True + seed: 1 + load_weight: True + ref: + megatron: + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. + sequence_parallel: True + seed: 1 + load_weight: True + param_offload: False + log_prob_micro_batch_size: 32 + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # for xperf_gpt + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_megatron + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: 2 + # for hf rollout + do_sample: True + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + +critic: + strategy: megatron + optim: + lr: 1e-5 + clip_grad: 1.0 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: {} + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: False + megatron: + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. 
+ sequence_parallel: True + seed: 1 + load_weight: True + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: 2 + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + cliprange_value: 0.5 + kl_ctrl: + type: fixed + kl_coef: 0.001 + +reward_model: + enable: False + strategy: megatron + megatron: + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 1 + num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. + sequence_parallel: True + seed: 1 + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + load_weight: True + param_offload: False + micro_batch_size: 64 + max_length: null + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + total_epochs: 30 + project_name: verl_examples + experiment_name: gsm8k + logger: ['console', 'wandb'] + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + test_freq: 2 + critic_warmup: 0 + default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index f3c312a6..eddddacc 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -15,11 +15,15 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ +import importlib.util from verl import DataProto import torch from verl.utils.reward_score import gsm8k, math from verl.trainer.ppo.ray_trainer import RayPPOTrainer +# import npu adapter if exit. +if importlib.util.find_spec('mindspeed') is not None: + import mindspeed.megatron_adaptor def _default_compute_score(data_source, solution_str, ground_truth): if data_source == 'openai/gsm8k': @@ -101,7 +105,15 @@ def main(config): def run_ppo(config, compute_score=None): if not ray.is_initialized(): # this is for local ray cluster - ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}}) + ray.init(runtime_env={ + 'env_vars': { + 'TOKENIZERS_PARALLELISM': 'true', + 'NCCL_DEBUG': 'WARN', + # Unlike CUDA, Ascend NPU requires visibility of all devices that need to engage in HCCL communication. + # use this environment variable to enable visiblity of all devices. + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "true" + } + }) ray.get(main_task.remote(config, compute_score)) diff --git a/verl/utils/megatron/optimizer.py b/verl/utils/megatron/optimizer.py index 9ae70b08..f8d32be3 100644 --- a/verl/utils/megatron/optimizer.py +++ b/verl/utils/megatron/optimizer.py @@ -13,12 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
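
The `main_ppo.py` hunk above makes the NPU path opt-in: the MindSpeed adaptor is imported only when the package is installed (the import monkey-patches Megatron for Ascend as a side effect), and `RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES` keeps all NPUs visible to each Ray worker for HCCL. A minimal sketch of the optional-import pattern, which is a no-op on CUDA machines:

```python
import importlib.util

# Import the NPU adaptor only if MindSpeed is present; otherwise run unchanged on CUDA.
if importlib.util.find_spec("mindspeed") is not None:
    import mindspeed.megatron_adaptor  # noqa: F401  (side effect: patches Megatron for NPU)
```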
+import pkg_resources from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedSGD as SGD -from megatron.optimizer.distrib_optimizer import DistributedOptimizer -from megatron.optimizer.grad_scaler import ConstantGradScaler, DynamicGradScaler -from megatron.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer -from megatron.optimizer import get_param_groups + +megatron_version = pkg_resources.get_distribution('megatron_core').version + +if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + from megatron.optimizer.distrib_optimizer import DistributedOptimizer + from megatron.optimizer.grad_scaler import ConstantGradScaler, DynamicGradScaler + from megatron.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer + from megatron.optimizer import get_param_groups +else: + from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native from verl.utils.megatron.optimizer_config import OptimizerConfig @@ -33,6 +40,14 @@ def get_megatron_optimizer( overlap_param_gather=False # add for verl ): # Base optimizer. + if pkg_resources.parse_version(megatron_version) >= pkg_resources.parse_version('0.6.0'): + return get_megatron_optimizer_native( + config=config, + model_chunks=model, + no_weight_decay_cond=no_weight_decay_cond, + scale_lr_cond=scale_lr_cond, + lr_mult=lr_mult + ) param_groups = get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr_mult) if config.optimizer == 'adam': diff --git a/verl/utils/megatron/pipeline_parallel.py b/verl/utils/megatron/pipeline_parallel.py index 3a3790bb..c81db232 100644 --- a/verl/utils/megatron/pipeline_parallel.py +++ b/verl/utils/megatron/pipeline_parallel.py @@ -15,12 +15,16 @@ import torch from megatron.core import parallel_state as mpu +from transformers.utils import is_flash_attn_2_available from .sequence_parallel import pad_to_sequence_parallel def compute_transformers_input_shapes(batches, meta_info): - from flash_attn.bert_padding import unpad_input # flash 2 is a must for Megatron + if is_flash_attn_2_available(): + from flash_attn.bert_padding import unpad_input # flash 2 is a must for Megatron + else: + from ...models.llama.megatron.layers.parallel_attention import unpad_input # pre-compute input shapes for each micro-batch at each pp stage input_shapes = [] for model_inputs in batches: diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index fcb6b65a..3543d43a 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -216,7 +216,12 @@ class FakeTimers: """Disable All Megatron Timing with FakeTimers""" def __init__(self): - from megatron.timers import DummyTimer + import pkg_resources + megatron_version = pkg_resources.get_distribution('megatron_core').version + if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + from megatron.timers import DummyTimer + else: + from megatron.core.timers import DummyTimer self.dummy_timer = DummyTimer() def __call__(self, *args: Any, **kwds: Any) -> Any: diff --git a/verl/utils/memory_buffer.py b/verl/utils/memory_buffer.py index 2e07e42f..98626b0c 100644 --- a/verl/utils/memory_buffer.py +++ b/verl/utils/memory_buffer.py @@ -27,11 +27,14 @@ class MemoryBuffer: memory. It must have a unique type to support this behavior. 
""" - def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype): + def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: torch.Tensor = None): self.numel = numel self.numel_padded = numel_padded self.dtype = dtype - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device='cuda', requires_grad=False) + if source is not None: + self.data = source + else: + self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device='cuda', requires_grad=False) def zero(self): """Reset the buffer to zero.""" diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 9ada0230..8caf1901 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -19,14 +19,17 @@ Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer """ +import importlib from functools import partial +import pkg_resources from typing import Iterable, Dict import torch from torch import nn import torch.distributed +if importlib.util.find_spec('mindspeed') is not None: + import mindspeed.megatron_adaptor # from megatron import get_args -from megatron.optimizer import DistributedOptimizer from verl.utils.megatron.optimizer_config import OptimizerConfig from megatron.core import parallel_state as mpu from megatron.core import ModelParallelConfig @@ -42,6 +45,13 @@ from verl.utils.py_functional import append_to_dict from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches +megatron_version = pkg_resources.get_distribution('megatron_core').version + +if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + from megatron.optimizer import DistributedOptimizer +else: + from megatron.core.optimizer import DistributedOptimizer + __all__ = ['MegatronPPOActor'] @@ -350,14 +360,20 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> Dict: # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm for chunk in self.actor_module: # if use distributed optimizer, zero grad buffer will be handled by optimizer - chunk.zero_grad_buffer(zero_buffer=(not self.actor_optimizer_config.use_distributed_optimizer)) + if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + chunk.zero_grad_buffer(zero_buffer=(not self.actor_optimizer_config.use_distributed_optimizer)) + else: + chunk.zero_grad() metric_micro_batch = self.forward_backward_batch(data) for metric in metric_micro_batch: append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics. 
- update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step( - self.megatron_config, self.megatron_config.timers) + if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step( + self.megatron_config, self.megatron_config.timers) + else: + update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() if update_successful: # allgather already execute in optimizer.step in new megatron pass diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index f0b044ad..6c590f19 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -15,13 +15,17 @@ Implement a multiprocess PPOCritic """ +import importlib from functools import partial from typing import Iterable import torch import torch.distributed from omegaconf import OmegaConf +import pkg_resources from torch import nn +if importlib.util.find_spec('mindspeed') is not None: + import mindspeed.megatron_adaptor from verl import DataProto from verl.trainer.ppo import core_algos @@ -33,10 +37,15 @@ from verl.utils.megatron import sequence_parallel as sp_utils from verl.utils.megatron.optimizer_config import OptimizerConfig -from megatron.optimizer import DistributedOptimizer from megatron.core import parallel_state as mpu from megatron.core.pipeline_parallel import get_forward_backward_func +megatron_version = pkg_resources.get_distribution('megatron_core').version + +if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + from megatron.optimizer import DistributedOptimizer +else: + from megatron.core.optimizer import DistributedOptimizer class MegatronPPOCritic(BasePPOCritic): @@ -213,12 +222,18 @@ def update_critic(self, dataloader: Iterable[DataProto]): self.critic_optimizer.zero_grad() # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm for chunk in self.critic_module: - chunk.zero_grad_buffer(zero_buffer=(not self.critic_optimizer_config.use_distributed_optimizer)) + if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + chunk.zero_grad_buffer(zero_buffer=(not self.critic_optimizer_config.use_distributed_optimizer)) + else: + chunk.zero_grad_buffer() metric_micro_batch = self.forward_backward_batch(data) - update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step( - self.megatron_config, self.megatron_config.timers) + if pkg_resources.parse_version(megatron_version) < pkg_resources.parse_version('0.6.0'): + update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step( + self.megatron_config, self.megatron_config.timers) + else: + update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step() if update_successful: # allgather already execute in optimizer.step in new megatron pass diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index dc531f4d..709e982f 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -15,6 +15,7 @@ The main entry point to run the PPO algorithm """ +import importlib import os import logging import ray @@ -83,6 +84,16 @@ def __init__(self, config: DictConfig, role: str): if self.config.actor.megatron.sequence_parallel: os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + + # mindspeed adapter need to init global var + if importlib.util.find_spec('mindspeed') is not None: + from megatron.training.arguments 
import parse_args, validate_args + from megatron.training.global_vars import set_global_variables + + args = parse_args(ignore_unknown_args=True) + validate_args(args, {}) + set_global_variables(args, build_tokenizer=False) + mpu.initialize_model_parallel( tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size, pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size, @@ -142,6 +153,7 @@ def _build_model_optimizer(self, # Step 1: initialize the tokenizer local_path = copy_local_path_from_hdfs(model_path) self.tokenizer = hf_tokenizer(local_path) + self.tokenizer.max_token_id = max(self.tokenizer.get_vocab().values()) # Step 2: get the actor_model_config actor_model_config = AutoConfig.from_pretrained(local_path) @@ -334,7 +346,7 @@ def init_model(self): def update_actor(self, data: DataProto): assert self._is_actor - data.batch = data.batch.cuda() + data.batch = data.batch.to(torch.cuda.current_device()) log_gpu_memory_usage('Before update policy', logger=logger) @@ -353,15 +365,9 @@ def update_actor(self, data: DataProto): def generate_sequences(self, prompts: DataProto): assert self._is_rollout - prompts.batch = prompts.batch.cuda() - meta_info = { - 'eos_token_id': - self.generation_config.eos_token_id - if self.generation_config is not None else self.tokenizer.eos_token_id, - 'pad_token_id': - self.generation_config.pad_token_id - if self.generation_config is not None else self.tokenizer.pad_token_id, - } + # prompts.batch = prompts.batch.cuda() + prompts.batch = prompts.batch.to(torch.cuda.current_device()) + meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id} prompts.meta_info.update(meta_info) with self.sharding_manager: log_gpu_memory_usage('After entering sharding manager', logger=logger) diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 7014ff1e..5c69bf4e 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -83,9 +83,11 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model os.environ['MEGATRON_IMPORT_TIMERS'] = '0' train_tp = kwargs.get('train_tp', None) num_tp_per_train_tp = train_tp // tensor_parallel_size - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + if vllm_version in ('0.4.2', '0.5.4', '0.6.3', '0.6.4'): vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, - num_tp_per_train_tp=num_tp_per_train_tp) + num_tp_per_train_tp=num_tp_per_train_tp, + # hard coding? + backend="hccl") assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \ "model context length should be greater than total sequence length" @@ -115,7 +117,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model ) # we may detokenize the result all together later - if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + if vllm_version in ('0.4.2', '0.5.4', '0.6.3', '0.6.4'): kwargs['detokenize'] = False # supporting adding any sampling params from the config file diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index bc07a5a6..1187a74e 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -15,6 +15,7 @@ This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine. 
""" +import pkg_resources import torch import torch.distributed as dist @@ -31,6 +32,8 @@ get_weight_buffer_meta_from_module, ) +megatron_version = pkg_resources.get_distribution('megatron_core').version + class AllGatherPPModel: @@ -81,11 +84,24 @@ def __init__(self, model_provider) -> None: def _build_param_buffer(self, pp_rank): """Build the parameter buffer in each pp rank""" - model = self.pp_models[pp_rank] - weight_buffer_meta = get_weight_buffer_meta_from_module(model) - self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta) + if pkg_resources.parse_version(megatron_version) >= pkg_resources.parse_version('0.6.0') and pp_rank == self._pp_rank: + # in megatron>=0.6, buffer is alread set to current pp_rank model. + from verl.utils.memory_buffer import MemoryBuffer + # TODO: Of course we cant't do so many hard coding in the final version. + source = self._this_rank_models[0].buffers[0].param_data + self.memory_buffers[pp_rank] = { + # hard coding dtype so far. + torch.bfloat16: MemoryBuffer(source.numel(), source.numel(), torch.bfloat16, source) + } + else: + model = self.pp_models[pp_rank] + weight_buffer_meta = get_weight_buffer_meta_from_module(model) + self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta) def _build_param_references(self, pp_rank, maintain_weight=False): + if pkg_resources.parse_version(megatron_version) >= pkg_resources.parse_version('0.6.0') and pp_rank == self._pp_rank: + # in megatron>=0.6, megatron will set the reference to buffer, we don't want to rebuild it. + return model = self.pp_models[pp_rank] build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight) @@ -120,9 +136,15 @@ def allgather_params(self): for cur_pp_rank in range(self.pp_size): global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank) - # NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models - for memory_buffer in self.memory_buffers[cur_pp_rank].values(): - dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self.pp_group, async_op=False) + if pkg_resources.parse_version(megatron_version) >= pkg_resources.parse_version('0.6.0'): + # The param to buffer map is different in veRL and Megatron, so use param directly + # instead of buffer in megatron>=0.6 so far. + for _, param in sorted(self.pp_models[cur_pp_rank].named_parameters()): + dist.broadcast(tensor=param.data, src=global_src, group=self.pp_group, async_op=False) + else: + # NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models + for memory_buffer in self.memory_buffers[cur_pp_rank].values(): + dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self.pp_group, async_op=False) def forward(self, *inputs, **kwargs): try: