From 3140cc2f9f055f4761825a453de347db43dab304 Mon Sep 17 00:00:00 2001
From: Kinman Lei
Date: Sat, 8 Feb 2025 00:55:04 +0800
Subject: [PATCH] [rollout]: fix incorrect response_attention_mask in vLLM rollout (#213)

This PR addresses issue https://github.com/volcengine/verl/issues/212.

The changes include:
- Read eos_token_id from generation_config to ensure alignment with vLLM.
- Modify get_eos_mask to accept both int and list types for the eos_token parameter.
---
 verl/utils/model.py                         | 21 +++++++++++++++++++--
 verl/utils/torch_functional.py              | 14 +++++++++++---
 verl/workers/fsdp_workers.py                | 13 +++++++++++--
 verl/workers/megatron_workers.py            | 15 ++++++++++++---
 verl/workers/rollout/naive/naive_rollout.py |  5 ++++-
 5 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/verl/utils/model.py b/verl/utils/model.py
index 9002451a..f319e400 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -16,12 +16,12 @@
 """
 import os
 import warnings
-from typing import Dict, Type
+from typing import Dict, Type, Optional
 
 import numpy as np
 import torch
 from torch import nn
-from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification, GenerationConfig
 
 from verl.models.registry import ModelRegistry
 
@@ -55,6 +55,23 @@ def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, t
     return module_config
 
 
+def get_generation_config(
+    model: str,
+    trust_remote_code: bool = False,
+) -> Optional[GenerationConfig]:
+    try:
+        return GenerationConfig.from_pretrained(model)
+    except OSError:  # Not found
+        try:
+            config = get_huggingface_actor_config(
+                model,
+                trust_remote_code=trust_remote_code,
+            )
+            return GenerationConfig.from_model_config(config)
+        except OSError:  # Not found
+            return None
+
+
 def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
     """
diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py
index cd9547ff..51c7b08d 100644
--- a/verl/utils/torch_functional.py
+++ b/verl/utils/torch_functional.py
@@ -138,13 +138,21 @@ def masked_whiten(values, mask, shift_mean=True):
     return whitened
 
 
-def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64):
+def get_eos_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
     '''
-    e.g. end of sentence token=1
+    end of sentence token can be int or list: 1 or [1, 2]
+    e.g. eos_token=1
     response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
     eos_mask:    [1, 1, 1, 1, 1, 1, 1, 0, 0]
     '''
-    eos_mask = response_id.eq(eos_token).long()
+    if isinstance(eos_token, int):
+        eos_token = [eos_token]
+
+    eos_mask = torch.zeros_like(response_id, dtype=torch.bool)
+    for token in eos_token:
+        eos_mask |= response_id.eq(token)
+
+    eos_mask = eos_mask.long()
     eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
     eos_mask = torch.logical_not(eos_mask).to(dtype)
     return eos_mask
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 9a56d07e..357c1415 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -147,7 +147,7 @@ def _build_model_optimizer(self,
                                trust_remote_code=False,
                                use_liger=False,
                                role='actor'):
-        from verl.utils.model import print_model_size, update_model_config
+        from verl.utils.model import print_model_size, update_model_config, get_generation_config
         from verl.utils.torch_dtypes import PrecisionType
         from transformers import AutoModelForCausalLM, AutoConfig
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload
@@ -171,6 +171,8 @@ def _build_model_optimizer(self,
 
         # override model kwargs
         actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
+        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)
+
         if use_remove_padding:
             from verl.models.registry import check_model_support_rmpad
             check_model_support_rmpad(actor_model_config.model_type)
@@ -445,7 +447,14 @@ def generate_sequences(self, prompts: DataProto):
                                       load_grad=self._is_offload_grad)
 
         prompts.batch = prompts.batch.cuda()
-        meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
+        meta_info = {
+            'eos_token_id':
+                self.generation_config.eos_token_id
+                if self.generation_config is not None else self.tokenizer.eos_token_id,
+            'pad_token_id':
+                self.generation_config.pad_token_id
+                if self.generation_config is not None else self.tokenizer.pad_token_id,
+        }
         prompts.meta_info.update(meta_info)
         with self.rollout_sharding_manager:
             log_gpu_memory_usage('After entering rollout sharding manager', logger=logger)
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 3871dd21..dc531f4d 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -135,9 +135,9 @@ def _build_model_optimizer(self,
                                enable_gradient_checkpointing=False):
         from verl.utils.megatron.optimizer import get_megatron_optimizer
         from megatron.core.models.gpt.gpt_model import ModelType
-        from verl.utils.model import print_model_size, update_model_config
+        from verl.utils.model import print_model_size, update_model_config, get_generation_config
         from verl.utils.megatron_utils import get_model, init_megatron_optim_config
-        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, GenerationConfig
 
         # Step 1: initialize the tokenizer
         local_path = copy_local_path_from_hdfs(model_path)
@@ -146,6 +146,8 @@ def _build_model_optimizer(self,
 
         # Step 2: get the actor_model_config
         actor_model_config = AutoConfig.from_pretrained(local_path)
+        self.generation_config = get_generation_config(local_path)
+
         override_config_kwargs = {
             'bos_token_id': self.tokenizer.bos_token_id,
             'eos_token_id': self.tokenizer.eos_token_id,
@@ -352,7 +354,14 @@ def generate_sequences(self, prompts: DataProto):
         assert self._is_rollout
 
         prompts.batch = prompts.batch.cuda()
-        meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
+        meta_info = {
+            'eos_token_id':
+                self.generation_config.eos_token_id
+                if self.generation_config is not None else self.tokenizer.eos_token_id,
+            'pad_token_id':
+                self.generation_config.pad_token_id
+                if self.generation_config is not None else self.tokenizer.pad_token_id,
+        }
         prompts.meta_info.update(meta_info)
         with self.sharding_manager:
             log_gpu_memory_usage('After entering sharding manager', logger=logger)
diff --git a/verl/workers/rollout/naive/naive_rollout.py b/verl/workers/rollout/naive/naive_rollout.py
index 6f2e8d59..5fbb93da 100644
--- a/verl/workers/rollout/naive/naive_rollout.py
+++ b/verl/workers/rollout/naive/naive_rollout.py
@@ -57,6 +57,8 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
 
         # used to construct attention_mask
         eos_token_id = prompts.meta_info['eos_token_id']
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
 
         batch_size = idx.size(0)
         prompt_length = idx.size(1)
@@ -90,7 +92,8 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
 
             attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)
 
-            prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
+            for token_id in eos_token_id:
+                prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask.bool())
             prev_attention_mask.to(attention_mask.dtype)
 
             position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)
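
For reference, the following is a minimal standalone sketch (not part of the patch): it copies the updated get_eos_mask from verl/utils/torch_functional.py above and exercises it with the token ids from its docstring, showing both the old int form and the new list form of eos_token.

# Minimal sketch: reproduces the patched get_eos_mask to illustrate the
# new list-of-eos-tokens behavior. Token ids are the docstring example values.
import torch
from typing import List, Union


def get_eos_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
    # Normalize a single eos id to a list so both call styles share one code path.
    if isinstance(eos_token, int):
        eos_token = [eos_token]

    # Mark every position that holds any of the eos ids.
    eos_mask = torch.zeros_like(response_id, dtype=torch.bool)
    for token in eos_token:
        eos_mask |= response_id.eq(token)

    # Keep everything up to and including the first eos hit, zero out the rest.
    eos_mask = eos_mask.long()
    eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
    return torch.logical_not(eos_mask).to(dtype)


response_id = torch.tensor([[0, 0, 2, 42, 3, 5, 1, 0, 0]])
print(get_eos_mask(response_id, eos_token=1))       # tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0]])
print(get_eos_mask(response_id, eos_token=[1, 2]))  # tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0]])

With a list, the mask is cut at the earliest occurrence of any configured eos id, which matches how the rollout workers now forward generation_config.eos_token_id (an int or a list) through meta_info.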