From 7510441dcde1dc9feb5c3bfe47ef76f4fb7fbba0 Mon Sep 17 00:00:00 2001 From: sunyi001 <1659275352@qq.com> Date: Fri, 21 Feb 2025 09:14:19 +0800 Subject: [PATCH] support ASCEND NPU --- examples/grpo_trainer/run_qwen2-7b_npu.sh | 41 ++++ pyproject.toml | 1 - requirements-npu.txt | 18 ++ verl/bert_padding.py | 220 ++++++++++++++++++ verl/single_controller/ray/base.py | 12 +- .../vllm/vllm_spmd/dtensor_weight_loaders.py | 3 +- verl/trainer/fsdp_sft_trainer.py | 31 +-- verl/utils/device.py | 75 ++++++ verl/utils/flops_counter.py | 3 +- verl/utils/fsdp_utils.py | 93 ++++---- verl/workers/actor/dp_actor.py | 8 +- verl/workers/critic/dp_critic.py | 8 +- verl/workers/fsdp_workers.py | 151 +++++++----- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 6 +- verl/workers/sharding_manager/fsdp_vllm.py | 35 +-- 15 files changed, 549 insertions(+), 156 deletions(-) create mode 100644 examples/grpo_trainer/run_qwen2-7b_npu.sh create mode 100644 requirements-npu.txt create mode 100644 verl/bert_padding.py create mode 100644 verl/utils/device.py diff --git a/examples/grpo_trainer/run_qwen2-7b_npu.sh b/examples/grpo_trainer/run_qwen2-7b_npu.sh new file mode 100644 index 00000000..e9a2d8f3 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2-7b_npu.sh @@ -0,0 +1,41 @@ +set -x + + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=32 \ + data.val_batch_size=1312 \ + data.max_prompt_length=64 \ + data.max_response_length=128 \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3a447e05..8f154335 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ dependencies = [ "ray>=2.10", "tensordict<0.6", "transformers", - "vllm<=0.6.3", 'wandb', ] diff --git a/requirements-npu.txt b/requirements-npu.txt new file mode 100644 index 00000000..0ad7f301 --- /dev/null +++ b/requirements-npu.txt @@ -0,0 +1,18 @@ +# requirements.txt records the full set of dependencies for development +accelerate +codetiming +datasets +dill +hydra-core +numpy +pandas 
+peft +pyarrow>=15.0.0 +pybind11 +pylatexenc +ray +tensordict<0.6 +transformers +wandb +vllm +vllm-ascend diff --git a/verl/bert_padding.py b/verl/bert_padding.py new file mode 100644 index 00000000..d7584beb --- /dev/null +++ b/verl/bert_padding.py @@ -0,0 +1,220 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather( + rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros( + first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. 
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): + """ + Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). + The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). 
+ + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: + ``` + [ + [2, 3, 0, 0, 0, 0], + [3, 2, 0, 0, 0, 0], + [6, 0, 0, 0, 0, 0] + ] + ``` + , which refers to the 3D-attention mask: + ``` + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1] + ] + ] + ```. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + length = attention_mask_in_length.sum(dim=-1) + seqlen = attention_mask_in_length.size(-1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), + seqlen) < length.unsqueeze( + 1) + real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() + seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] + indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... 
-> b s ...", b=batch) \ No newline at end of file diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 4763a20d..d2c81eff 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -22,6 +22,7 @@ from ray.experimental.state.api import get_actor from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker +from verl.utils.device import is_cuda_available __all__ = ['Worker'] @@ -68,9 +69,10 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None): pg_name_prefix = name if name else \ f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" # print(f"pg_name_prefix = {pg_name_prefix}") + device_name = "GPU" if is_cuda_available else "NPU" pg_scheme = [[{ "CPU": self.max_collocate_count, - "GPU": 1 + device_name: 1 } if self.use_gpu else { "CPU": self.max_collocate_count } for _ in range(process_count)] for process_count in self._store] @@ -160,8 +162,10 @@ def __call__(self, } options.update(self._options) - if use_gpu: + if use_gpu and is_cuda_available: options["num_gpus"] = num_gpus + if use_gpu and not is_cuda_available: + options["resources"] = {"NPU": num_gpus} if len(self._additional_resource) > 1: for k, v in self._additional_resource.items(): @@ -379,7 +383,7 @@ def world_size(self): def _bind_workers_method_to_parent(cls, key, user_defined_cls): """ - Binds the methods of each worker to the WorkerDict. + Binds the methods of each worker to the WorkerDict. Note that we only bind public methods that are decorated by register """ for method_name in dir(user_defined_cls): @@ -419,7 +423,7 @@ def _unwrap_ray_remote(cls): def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): """ - This function should return a class instance that delegates the calls to every + This function should return a class instance that delegates the calls to every cls in cls_dict """ cls_dict = {} diff --git a/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py index a3042cab..2fb9f893 100644 --- a/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py +++ b/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py @@ -19,6 +19,7 @@ from torch.distributed._tensor import DTensor from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import is_pp_missing_parameter +from verl.utils.device import get_device_name def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: @@ -365,7 +366,7 @@ def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): weight_loader(actor_weights, vllm_model) # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu # after init, and we need this after sync model weights for in first iter. 
- vllm_model = vllm_model.cuda() + vllm_model = vllm_model.to(get_device_name()) def _get_model_weight_loader(arch: str): diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index b715c8cd..7148cb8b 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -35,13 +35,14 @@ from verl.utils.torch_functional import get_cosine_schedule_with_warmup from tensordict import TensorDict from torch.utils.data import DataLoader, DistributedSampler -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.bert_padding import pad_input, unpad_input, rearrange, index_first_axis from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager from verl.utils.dataset import SFTDataset from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.tracking import Tracking from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group +from verl.utils.device import get_device_name from torch.distributed.device_mesh import DeviceMesh import verl.utils.hdfs_io as hdfs_io @@ -51,6 +52,7 @@ from verl.workers.sharding_manager import FSDPUlyssesShardingManager from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl import DataProto +from verl.utils.device import is_cuda_available logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN')) @@ -106,6 +108,7 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM # TODO: add checkpoint manager if self.device_mesh.get_rank() == 0: print(self.config) + self.device = get_device_name() def _normalize_config_bsz(self): dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) @@ -257,7 +260,8 @@ def _build_model_optimizer(self): mixed_precision=mixed_precision, device_mesh=self.device_mesh, sync_module_states=True, - device_id=torch.cuda.current_device(), + device_id=torch.cuda.current_device() if is_cuda_available else + torch.npu.current_device(), cpu_offload=cpu_offload, use_orig_params=False) @@ -289,16 +293,16 @@ def _compute_loss_and_backward(self, batch, do_backward=True): use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 # Move inputs to GPU and prepare loss mask - input_ids = batch['input_ids'].cuda() - attention_mask = batch['attention_mask'].cuda() - position_ids = batch['position_ids'].cuda() - loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda() + input_ids = batch['input_ids'].to(self.device) + attention_mask = batch['attention_mask'].to(self.device) + position_ids = batch['position_ids'].to(self.device) + loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).to(self.device) loss_fct = nn.CrossEntropyLoss(reduction='none') # Context manager for sequence parallel if needed context = self.sharding_manager if use_sp else nullcontext() with context: - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device, dtype=torch.bfloat16): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() @@ -412,7 +416,7 @@ def training_step(self, batch: TensorDict): log_gpu_memory_usage('After offload weights', logger=logger) - step_loss = torch.tensor(step_loss).cuda() + step_loss = torch.tensor(step_loss).to(self.device) torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) return 
{'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} @@ -468,7 +472,7 @@ def fit(self): for data in tqdm(self.train_dataloader, total=self.steps_per_epoch, desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"): - data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() + data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device) metric = self.training_step(data) if rank == 0: tracking.log(data=metric, step=global_step) @@ -479,7 +483,7 @@ def fit(self): # Perform final validation val_losses = [] for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device) val_loss = self.validation_step(val_data) val_losses.append(val_loss) if rank == 0: @@ -495,7 +499,7 @@ def fit(self): # validation val_losses = [] for data in self.val_dataloader: - data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device) val_loss = self.validation_step(data) val_losses.append(val_loss) if rank == 0: @@ -520,9 +524,9 @@ def fit(self): def main(config): local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) + device_mesh = init_device_mesh(device_type=get_device_name(), mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type='cuda', + ulysses_device_mesh = init_device_mesh(device_type=get_device_name(), mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=('dp', 'sp')) trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) @@ -531,3 +535,4 @@ def main(config): if __name__ == '__main__': main() +s \ No newline at end of file diff --git a/verl/utils/device.py b/verl/utils/device.py new file mode 100644 index 00000000..55344e5d --- /dev/null +++ b/verl/utils/device.py @@ -0,0 +1,75 @@ +# This code is inspired by the torchtune. +# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py + +import os +import logging +from enum import Enum +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + import torch_npu # noqa: F401 + + return torch.npu.is_available() + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = "cuda" + elif is_npu_available: + device = "npu" + else: + device = "cpu" + return device + + +def get_device(device_name: Optional[str] = None) -> torch.device: + """Function that takes an optional device string, verifies it's correct and available given the machine and + distributed settings, and returns a :func:`~torch.device`. If device string is not provided, this function will + infer the device based on the environment. + If CUDA-like is available and being used, this function also sets the CUDA-like device. 
+ Args: + device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu". + Example: + >>> device = get_device("cuda") + >>> device + device(type='cuda', index=0) + Returns: + torch.device: Device + """ + if device_name is None: + device_name = get_device_name() + device = torch.device(device_name) + return device + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning( + f"Device namespace '{device_name}' not found in torch, try to load torch.cuda." + ) + return torch.cuda \ No newline at end of file diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py index 3c5ac1a9..0b668a03 100644 --- a/verl/utils/flops_counter.py +++ b/verl/utils/flops_counter.py @@ -14,6 +14,7 @@ import torch from transformers import PretrainedConfig, Qwen2Config, LlamaConfig +from verl.utils.device import is_cuda_available VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig) @@ -30,7 +31,7 @@ def unit_convert(number, level): ptr += 1 return number - device_name = torch.cuda.get_device_name() + device_name = torch.cuda.get_device_name() if is_cuda_available else torch.npu.get_device_name() flops = float("inf") # INF flops for unkown gpu type if "H100" in device_name or "H800" in device_name: flops = 989e12 diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 26b7dbd5..96d2ad8c 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -20,18 +20,18 @@ import os from contextlib import contextmanager from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp._runtime_utils import _lazy_init from transformers.trainer_pt_utils import get_module_class_from_name import torch import torch.nn as nn import torch.distributed as dist +from verl.utils.device import is_cuda_available def init_fn(x: torch.nn.Module): if not torch.distributed.get_rank() == 0: - x = x.to_empty(device=torch.cuda.current_device(), recurse=False) - torch.cuda.empty_cache() + x = x.to_empty(device=torch.cuda.current_device() if is_cuda_available else torch.npu.current_device(), + recurse=False) + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() return x @@ -49,7 +49,7 @@ def get_init_weight_context_manager(use_meta_tensor=True): # Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py def get_fsdp_wrap_policy(module, config=None, is_lora=False): """Get FSDP wrap policy for the module. 
- + Args: module: The module to get wrap policy for config: Configuration for wrap policy @@ -107,65 +107,58 @@ def lambda_policy_fn(module): return auto_wrap_policy -@torch.no_grad() -def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): - assert isinstance(model, FSDP) - # lazy init FSDP model - _lazy_init(model, model) - assert model._is_root, f"Only support root model offloading to CPU" - for handle in model._all_handles: - if handle._offload_params: - continue - flat_param = handle.flat_param - assert flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() and \ - id(flat_param.data) != id(flat_param._local_shard) and \ - flat_param.data.size() == flat_param._local_shard.size() - handle.flat_param_to(torch.device("cpu"), non_blocking=True) - # the following still keeps id(._local_shard) != id(.data) - flat_param._local_shard = flat_param.data - assert id(flat_param._local_shard) != id(flat_param.data) - if empty_cache: - torch.cuda.empty_cache() - - -@torch.no_grad() -def load_fsdp_model_to_gpu(model: FSDP): - assert isinstance(model, FSDP) - # lazy init FSDP model - _lazy_init(model, model) - assert model._is_root, f"Only support root model loading to GPU" - device_id = torch.cuda.current_device() - for handle in model._all_handles: - if handle._offload_params: - continue - flat_param = handle.flat_param - handle.flat_param_to(torch.device(f"cuda:{device_id}"), non_blocking=True) - # the following still keeps id(._local_shard) != id(.data) - flat_param._local_shard = flat_param.data - - -@torch.no_grad() +def offload_fsdp_grad(module): + for _, param in module.named_parameters(): + if param.grad is not None: + param.grad = param.grad.to("cpu", non_blocking=True) + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() if is_cuda_available else torch.npu.empty_cache() + + +def load_fsdp_grad(module, device_id): + for _, param in module.named_parameters(): + if param.grad is not None: + param.grad = param.grad.to(device_id, non_blocking=True) + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() + + +def offload_fsdp_param_and_grad(module, offload_grad=False): + for _, param in module.named_parameters(): + if hasattr(param, "_local_shard"): + param._local_shard = param._local_shard.to("cpu", non_blocking=True) + param.data = param.data.to('cpu', non_blocking=True) + if offload_grad and param.grad is not None: + param.grad = param.grad.to("cpu", non_blocking=True) + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() + + +def load_fsdp_param_and_grad(module, device_id, load_grad=False): + for _, param in module.named_parameters(): + if hasattr(param, "_local_shard"): + param._local_shard = param._local_shard.to(device_id, non_blocking=True) + param.data = param.data.to(device_id, non_blocking=True) + if load_grad and param.grad is not None: + param.grad = param.grad.to(device_id, non_blocking=True) + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() + + def offload_fsdp_optimizer(optimizer): - if not optimizer.state: - return for param_group in optimizer.param_groups: for param in param_group['params']: state = optimizer.state[param] for key, value in state.items(): if isinstance(value, torch.Tensor): state[key] = value.to("cpu", non_blocking=True) + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() -@torch.no_grad() def load_fsdp_optimizer(optimizer, device_id): - if not optimizer.state: - return for param_group in optimizer.param_groups: for 
param in param_group['params']: state = optimizer.state[param] for key, value in state.items(): if isinstance(value, torch.Tensor): state[key] = value.to(device_id, non_blocking=True) + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() @contextmanager @@ -241,7 +234,7 @@ def parallel_load_safetensors(filepath): ckpt_chunks = [ckpt_chunks[rank * size:rank * size + size] for rank in range(world_size)] shard_states = {} - device = torch.cuda.current_device() + device = torch.cuda.current_device() if is_cuda_available else torch.npu.current_device() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: @@ -280,7 +273,7 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" - device = torch.cuda.current_device() + device = torch.cuda.current_device() if is_cuda_available else torch.npu.current_device() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 4db326a1..b64e2903 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -30,8 +30,9 @@ from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F +from verl.utils.device import get_device_name -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.bert_padding import pad_input, unpad_input, rearrange, index_first_axis __all__ = ['DataParallelPPOActor'] @@ -54,6 +55,7 @@ def __init__( self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) + self.device = get_device_name() def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -62,7 +64,7 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, log_probs: # (bs, response_len) """ response_length = micro_batch['responses'].size(-1) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device, dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch_size, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] @@ -230,7 +232,7 @@ def update_policy(self, data: DataProto): self.actor_optimizer.zero_grad() for data in micro_batches: - data = data.cuda() # actor device is cpu when using offload + data = data.to(self.device) # actor device is cpu when using offload responses = data['responses'] response_length = responses.size(1) attention_mask = data['attention_mask'] diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index f2eb44c2..88b6bf1b 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -30,8 +30,9 @@ from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx +from verl.utils.device import get_device_name -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.bert_padding import 
pad_input, unpad_input, rearrange, index_first_axis __all__ = ['DataParallelPPOCritic'] @@ -46,10 +47,11 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt print(f'Critic use_remove_padding={self.use_remove_padding}') self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.device = get_device_name() def _forward_micro_batch(self, micro_batch): response_length = micro_batch['responses'].size(-1) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device, dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] @@ -164,7 +166,7 @@ def update_critic(self, data: DataProto): self.critic_optimizer.zero_grad() for data in micro_batches: - data = data.cuda() # critic device is cpu when using offload + data = data.to(self.device) # critic device is cpu when using offload input_ids = data['input_ids'] responses = data['responses'] attention_mask = data['attention_mask'] diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 1660bd06..7ee2538a 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -30,29 +30,31 @@ from verl.utils import hf_tokenizer from verl.utils.debug import log_gpu_memory_usage from verl.utils.fs import copy_local_path_from_hdfs -from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager -from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \ - load_fsdp_model_to_gpu +from verl.utils.fsdp_utils import get_fsdp_wrap_policy, offload_fsdp_grad, init_fn, get_init_weight_context_manager +from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_param_and_grad, load_fsdp_optimizer, \ + load_fsdp_param_and_grad from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask from verl.utils.flops_counter import FlopsCounter from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager +from verl.utils.device import get_device_name, is_cuda_available from codetiming import Timer logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +DEVICE = get_device_name() def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + device_mesh = init_device_mesh(DEVICE, mesh_shape=(world_size,), mesh_dim_names=['fsdp']) else: raise ValueError( 'HSDP is not supported yet because it produces incorrect results for now. 
Please set fsdp_size=-1') assert world_size % fsdp_size == 0 - device_mesh = init_device_mesh('cuda', + device_mesh = init_device_mesh(DEVICE, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=['ddp', 'fsdp']) return device_mesh @@ -80,7 +82,7 @@ def __init__(self, config: DictConfig, role: str): self.config = config import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -92,7 +94,7 @@ def __init__(self, config: DictConfig, role: str): self.ulysses_sequence_parallel_size = self.config.actor.get('ulysses_sequence_parallel_size', 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', + self.ulysses_device_mesh = init_device_mesh(DEVICE, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=['dp', 'sp']) @@ -106,9 +108,11 @@ def __init__(self, config: DictConfig, role: str): self._is_ref = self.role in ['ref', 'actor_rollout_ref'] self._is_offload_param = False + self._is_offload_grad = False self._is_offload_optimizer = False if self._is_actor: self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False) + self._is_offload_grad = self.config.actor.fsdp_config.get('grad_offload', False) self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False) elif self._is_ref: # TODO: it seems that manual offload is slowly than FSDP offload @@ -197,7 +201,7 @@ def _build_model_optimizer(self, actor_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=actor_model_config, - attn_implementation='flash_attention_2', + # attn_implementation='flash_attention_2', trust_remote_code=trust_remote_code) # Apply Liger kernel to the model if use_liger is set to True if use_liger: @@ -250,7 +254,7 @@ def _build_model_optimizer(self, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=torch.npu.current_device(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, @@ -289,7 +293,7 @@ def _build_rollout(self): infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}' - rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp']) + rollout_device_mesh = init_device_mesh(DEVICE, mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp']) if self.config.rollout.name == 'hf': from verl.workers.rollout import HFRollout @@ -360,6 +364,10 @@ def init_model(self): # get the original unwrapped module self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + if self._is_offload_param: + # param is require during state_dict in sharding manager + offload_fsdp_grad(module=self.actor_module_fsdp) + log_gpu_memory_usage('After offload actor grad during init', logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) log_gpu_memory_usage('After offload actor optimizer during init', logger=logger) @@ -397,19 +405,21 @@ def init_model(self): lr_scheduler=self.actor_lr_scheduler, 
tokenizer=self.tokenizer) - torch.cuda.empty_cache() + torch.npu.empty_cache() @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): - data = data.to('cuda') + data = data.to(DEVICE) assert self._is_actor if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) + load_fsdp_param_and_grad(module=self.actor_module_fsdp, + device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.npu.current_device()) - data.batch = data.batch.cuda() + data.batch = data.batch.to(DEVICE) log_gpu_memory_usage('Before update policy', logger=logger) @@ -436,21 +446,23 @@ def update_actor(self, data: DataProto): output = output.to('cpu') if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) + offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - torch.cuda.empty_cache() + torch.npu.empty_cache() return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): - prompts = prompts.to('cuda') + prompts = prompts.to(DEVICE) assert self._is_rollout if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) + load_fsdp_param_and_grad(module=self.actor_module_fsdp, + device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) - prompts.batch = prompts.batch.cuda() + prompts.batch = prompts.batch.to(DEVICE) meta_info = { 'eos_token_id': self.generation_config.eos_token_id @@ -461,13 +473,6 @@ def generate_sequences(self, prompts: DataProto): } prompts.meta_info.update(meta_info) with self.rollout_sharding_manager: - - # after parameters sync with rollout, offload actor model to CPU - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage('After entering rollout sharding manager', logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) @@ -479,8 +484,11 @@ def generate_sequences(self, prompts: DataProto): output = output.to('cpu') + if self._is_offload_param: + # NOTE(sgm): the grad is already in CPU, only offload param here + offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad) # clear kv cache - torch.cuda.empty_cache() + torch.npu.empty_cache() log_gpu_memory_usage('After recompute log prob', logger=logger) return output @@ -488,8 +496,10 @@ def generate_sequences(self, prompts: DataProto): def compute_log_prob(self, data: DataProto): assert self._is_actor if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - data = data.to('cuda') + load_fsdp_param_and_grad(module=self.actor_module_fsdp, + device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) + data = data.to(DEVICE) # we should always recompute old_log_probs when it is HybridEngine data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -511,10 +521,11 @@ def compute_log_prob(self, data: DataProto): self.actor.actor_module._handle.reshard(True) if self._is_offload_param: - 
offload_fsdp_model_to_cpu(self.actor_module_fsdp) + # NOTE(sgm): the grad is already in CPU, only offload param here + offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad) # clear kv cache - torch.cuda.empty_cache() + torch.npu.empty_cache() log_gpu_memory_usage('After compute_log_prob', logger=logger) return output @@ -522,7 +533,7 @@ def compute_log_prob(self, data: DataProto): def compute_ref_log_prob(self, data: DataProto): assert self._is_ref - data = data.to('cuda') + data = data.to(DEVICE) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info['micro_batch_size'] = micro_batch_size @@ -542,7 +553,7 @@ def compute_ref_log_prob(self, data: DataProto): if self.world_size > 1: self.ref_policy.actor_module._handle.reshard(True) - torch.cuda.empty_cache() + torch.npu.empty_cache() return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) @@ -551,7 +562,9 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_prev assert self._is_actor import torch if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) + load_fsdp_param_and_grad(module=self.actor_module_fsdp, + device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, @@ -560,17 +573,19 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_prev torch.distributed.barrier() if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) + offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, path, del_local_after_load=False): if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) + load_fsdp_param_and_grad(module=self.actor_module_fsdp, + device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load) if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) + offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad) class CriticWorker(Worker): @@ -579,7 +594,7 @@ def __init__(self, config): super().__init__() import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -593,7 +608,7 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', + self.ulysses_device_mesh = init_device_mesh(DEVICE, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=['dp', 'sp']) @@ -601,6 +616,7 @@ def __init__(self, config): # set FSDP offload params self._is_offload_param = self.config.model.fsdp_config.param_offload + self._is_offload_grad = self.config.model.fsdp_config.grad_offload self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config @@ -704,7 +720,7 @@ def _build_critic_model_optimizer(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + 
device_id=torch.npu.current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, @@ -741,7 +757,7 @@ def init_model(self): self.config) if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) @@ -755,14 +771,16 @@ def init_model(self): lr_scheduler=self.critic_lr_scheduler, tokenizer=self.tokenizer) - torch.cuda.empty_cache() + torch.npu.empty_cache() @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_values(self, data: DataProto): - data = data.to('cuda') + data = data.to(DEVICE) if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) + load_fsdp_param_and_grad(module=self.critic_module, + device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) micro_batch_size = self.config.forward_micro_batch_size_per_gpu data.meta_info['micro_batch_size'] = micro_batch_size data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu @@ -776,16 +794,19 @@ def compute_values(self, data: DataProto): output = output.to('cpu') if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad) + torch.npu.empty_cache() return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): - data = data.to('cuda') + data = data.to(DEVICE) if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) + load_fsdp_param_and_grad(module=self.critic_module, + device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.npu.current_device()) # perform forward computation with self.ulysses_sharding_manager: @@ -807,10 +828,10 @@ def update_critic(self, data: DataProto): output = self.ulysses_sharding_manager.postprocess_data(data=output) if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) - torch.cuda.empty_cache() + torch.npu.empty_cache() output = output.to('cpu') return output @@ -818,7 +839,9 @@ def update_critic(self, data: DataProto): def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False): import torch if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) + load_fsdp_param_and_grad(module=self.critic_module, + device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, @@ -827,19 +850,21 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_prev torch.distributed.barrier() if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, path, del_local_after_load=True): import torch if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) + load_fsdp_param_and_grad(module=self.critic_module, + 
device_id=torch.npu.current_device(), + load_grad=self._is_offload_grad) self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load) torch.distributed.barrier() if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad) # TODO(sgm): we may need to extract it to dp_reward_model.py @@ -852,7 +877,7 @@ def __init__(self, config): super().__init__() import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -866,7 +891,7 @@ def __init__(self, config): self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', + self.ulysses_device_mesh = init_device_mesh(DEVICE, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=['dp', 'sp']) @@ -931,7 +956,7 @@ def _build_model(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=torch.npu.current_device(), sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), @@ -945,13 +970,13 @@ def init_model(self): # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get('external_lib', None)) self.reward_module = self._build_model(config=self.config) - torch.cuda.empty_cache() + torch.npu.empty_cache() def _forward_micro_batch(self, micro_batch): from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad - with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast(device_type=DEVICE, dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch_size, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] @@ -1077,11 +1102,11 @@ def _switch_chat_template(self, data: DataProto): def compute_rm_score(self, data: DataProto): import itertools from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx - data = data.to('cuda') + data = data.to(DEVICE) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) - rm_data.batch = rm_data.batch.cuda() + rm_data.batch = rm_data.batch.to(DEVICE) # perform forward computation with self.ulysses_sharding_manager: @@ -1116,5 +1141,5 @@ def compute_rm_score(self, data: DataProto): self.reward_module._handle.reshard(True) output = output.to('cpu') - torch.cuda.empty_cache() + torch.npu.empty_cache() return output diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index bcee3544..88bfc174 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -38,6 +38,7 @@ from vllm.distributed import parallel_state as vllm_ps from vllm import LLM, SamplingParams from verl.third_party.vllm import vllm_version +from verl.utils.device import is_cuda_available # TODO # 1. 
support pp in vllm @@ -91,7 +92,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf self.inference_engine = LLM( model=model_path, - enable_sleep_mode=True, + enable_sleep_mode=True if is_cuda_available else False, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend="external_launcher", dtype=config.dtype, @@ -106,7 +107,8 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf ) # Offload vllm model to reduce peak memory usage - self.inference_engine.sleep(level=1) + if is_cuda_available: + self.inference_engine.sleep(level=1) kwargs = dict( n=1, diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index c79d3031..7e776043 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -25,6 +25,7 @@ from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors) from verl.utils.debug import log_gpu_memory_usage from verl.third_party.vllm import vllm_version +from verl.utils.device import is_cuda_available from .base import BaseShardingManager @@ -57,13 +58,15 @@ def __init__(self, state_dict_config=ShardedStateDictConfig()) # Note that torch_random_states may be different on each dp rank - self.torch_random_states = torch.cuda.get_rng_state() + self.torch_random_states = torch.cuda.get_rng_state() if is_cuda_available else torch.npu.get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh['dp'].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + torch.cuda.manual_seed(gen_dp_rank + 1000) if is_cuda_available else \ + torch.npu.manual_seed(gen_dp_rank + 1000)# make sure all tp ranks have the same random states + self.gen_random_states = torch.cuda.get_rng_state() if is_cuda_available else torch.npu.get_rng_state() + torch.cuda.set_rng_state(self.torch_random_states) if is_cuda_available else \ + torch.npu.set_rng_state(self.torch_random_states) else: self.gen_random_states = None @@ -76,7 +79,7 @@ def __enter__(self): if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): self.inference_engine.sync_model_weights(params, load_format=load_format) else: - self.inference_engine.wake_up() + # self.inference_engine.wake_up() # TODO(ZSL): deal with 'hf' format if load_format == 'dtensor': from verl.third_party.vllm import load_dtensor_weights @@ -87,27 +90,28 @@ def __enter__(self): log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger) del params - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger) # TODO: offload FSDP model weights # self.module.cpu() - # torch.cuda.empty_cache() + # torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() # if torch.distributed.get_rank() == 0: # print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') # important: need to manually set the random states of each tp to be identical. 
if self.device_mesh is not None: - self.torch_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.gen_random_states) + self.torch_random_states = torch.cuda.get_rng_state() if is_cuda_available else torch.npu.get_rng_state() + torch.cuda.set_rng_state(self.gen_random_states) if is_cuda_available else \ + torch.npu.set_rng_state(self.gen_random_states) def __exit__(self, exc_type, exc_value, traceback): log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger) # TODO(ZSL): check this if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): self.inference_engine.offload_model_weights() - else: - self.inference_engine.sleep(level=1) + # else: + # self.inference_engine.sleep(level=1) log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger) # self.module.to('cuda') @@ -117,12 +121,13 @@ def __exit__(self, exc_type, exc_value, traceback): self.module.train() # add empty cache after each compute - torch.cuda.empty_cache() + torch.cuda.empty_cache() if is_cuda_available else torch.npu.empty_cache() # restore random states if self.device_mesh is not None: - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + self.gen_random_states = torch.cuda.get_rng_state() if is_cuda_available else torch.npu.get_rng_state() + torch.cuda.set_rng_state(self.torch_random_states) if is_cuda_available else \ + torch.npu.set_rng_state(self.torch_random_states) def preprocess_data(self, data: DataProto) -> DataProto: # TODO: Current impl doesn't consider FSDP with torch micro-dp @@ -156,4 +161,4 @@ def postprocess_data(self, data: DataProto) -> DataProto: # TODO: shall we build a micro_dp group for vllm when integrating with vLLM? local_prompts = data.chunk(chunks=tp_size) data = local_prompts[dp_rank % tp_size] - return data + return data \ No newline at end of file
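
The patch centralizes the CUDA-versus-Ascend-NPU switch in the new verl/utils/device.py (get_device_name, get_device, get_torch_device, is_cuda_available). As a usage note, the snippet below is a minimal sketch, not part of the patch itself, of how those helpers are intended to replace the hard-coded torch.cuda.* calls and 'cuda' strings touched throughout the diff; it assumes the patch has been applied so that verl.utils.device is importable.

```python
# Sketch only (assumes this patch is applied and verl is on the PYTHONPATH):
# using the device helpers added in verl/utils/device.py to keep call sites
# backend-agnostic instead of branching on is_cuda_available at every site.
import torch
from verl.utils.device import get_device_name, get_torch_device, is_cuda_available

device_name = get_device_name()    # "cuda", "npu", or "cpu"
torch_device = get_torch_device()  # torch.cuda or torch.npu namespace (torch.cuda as fallback)

# Replaces tensor.cuda() and torch.autocast(device_type='cuda', ...):
x = torch.randn(4, 4).to(device_name)
with torch.autocast(device_type=device_name, dtype=torch.bfloat16):
    y = x @ x

# Replaces torch.cuda.empty_cache() / torch.cuda.current_device() style calls:
if device_name != "cpu":
    torch_device.empty_cache()
    print("current device index:", torch_device.current_device())

# Backend selection as done in fsdp_workers.py when initializing torch.distributed:
backend = "nccl" if is_cuda_available else "hccl"
print("distributed backend:", backend)
```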