[rollout]: fix incorrect response_attention_mask in vLLM rollout (#213)
This PR addresses issue #212.

The changes include:
- Read eos_token_id from generation_config to ensure alignment with vLLM.
- Modify the get_eos_mask function to accept either an int or a list of ints for the eos_token parameter (see the sketch below).
kinman0224 authored Feb 7, 2025
1 parent 27484a7 commit 3140cc2
Showing 5 changed files with 57 additions and 11 deletions.
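
For context, a minimal sketch of the underlying problem (all token ids below are made up): vLLM stops generation on any id listed in generation_config.json, which may contain several stop ids, while the old response_attention_mask was built only from the single tokenizer.eos_token_id, so padding written after one of the other stop ids stayed unmasked.

import torch

tokenizer_eos = 1                                 # the single id the tokenizer reports
generation_eos = [1, 2]                           # what vLLM actually honours; handled by the fixed get_eos_mask below
response_id = torch.tensor([[42, 7, 2, 0, 0]])    # vLLM stopped on id 2, then padded with 0

# Old behaviour: the mask only looks for tokenizer_eos, never sees id 2,
# so the trailing pad positions remain marked as valid tokens.
hit = response_id.eq(tokenizer_eos).long()
old_mask = torch.logical_not((torch.cumsum(hit, dim=1) - hit).bool()).long()
# old_mask == tensor([[1, 1, 1, 1, 1]])  <- the padding after the stop token is not masked out

The diffs below read the full id list from generation_config and teach get_eos_mask to accept it.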
21 changes: 19 additions & 2 deletions verl/utils/model.py
@@ -16,12 +16,12 @@
"""
import os
import warnings
from typing import Dict, Type
from typing import Dict, Type, Optional

import numpy as np
import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification, GenerationConfig
from verl.models.registry import ModelRegistry


@@ -55,6 +55,23 @@ def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, t
return module_config


def get_generation_config(
model: str,
trust_remote_code: bool = False,
) -> Optional[GenerationConfig]:
try:
return GenerationConfig.from_pretrained(model)
except OSError: # Not found
try:
config = get_huggingface_actor_config(
model,
trust_remote_code=trust_remote_code,
)
return GenerationConfig.from_model_config(config)
except OSError: # Not found
return None


def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
"""
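
A brief usage sketch of the new helper (the model path below is a placeholder): it first tries generation_config.json, then falls back to deriving a GenerationConfig from the model config, and returns None when neither is available, in which case the workers below fall back to the tokenizer ids.

from transformers import AutoTokenizer
from verl.utils.model import get_generation_config

local_path = "/path/to/actor/model"          # placeholder path
tokenizer = AutoTokenizer.from_pretrained(local_path)

gen_config = get_generation_config(local_path, trust_remote_code=False)
# Mirrors the worker changes below: prefer the generation config's ids,
# otherwise use the tokenizer's single ids.
eos_token_id = gen_config.eos_token_id if gen_config is not None else tokenizer.eos_token_id
pad_token_id = gen_config.pad_token_id if gen_config is not None else tokenizer.pad_token_id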
14 changes: 11 additions & 3 deletions verl/utils/torch_functional.py
@@ -138,13 +138,21 @@ def masked_whiten(values, mask, shift_mean=True):
return whitened


def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64):
def get_eos_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
'''
e.g. end of sentence token=1
end of sentence token can be int or list: 1 or [1, 2]
e.g. eos_token=1
response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
'''
eos_mask = response_id.eq(eos_token).long()
if isinstance(eos_token, int):
eos_token = [eos_token]

eos_mask = torch.zeros_like(response_id, dtype=torch.bool)
for token in eos_token:
eos_mask |= response_id.eq(token)

eos_mask = eos_mask.long()
eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
eos_mask = torch.logical_not(eos_mask).to(dtype)
return eos_mask
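
A doctest-style sketch of the updated function (token ids are made up): with a list, the mask now ends at the first occurrence of any of the given ids.

import torch
from verl.utils.torch_functional import get_eos_mask

response_id = torch.tensor([[0, 0, 9, 42, 3, 5, 2, 7, 7]])

get_eos_mask(response_id, eos_token=2)
# tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0]])   keeps everything through the first 2

get_eos_mask(response_id, eos_token=[2, 3])
# tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0]])   the 3 at position 4 now also ends the mask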
13 changes: 11 additions & 2 deletions verl/workers/fsdp_workers.py
@@ -147,7 +147,7 @@ def _build_model_optimizer(self,
trust_remote_code=False,
use_liger=False,
role='actor'):
from verl.utils.model import print_model_size, update_model_config
from verl.utils.model import print_model_size, update_model_config, get_generation_config
from verl.utils.torch_dtypes import PrecisionType
from transformers import AutoModelForCausalLM, AutoConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload
@@ -171,6 +171,8 @@
# override model kwargs
actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)

self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)

if use_remove_padding:
from verl.models.registry import check_model_support_rmpad
check_model_support_rmpad(actor_model_config.model_type)
@@ -445,7 +447,14 @@ def generate_sequences(self, prompts: DataProto):
load_grad=self._is_offload_grad)

prompts.batch = prompts.batch.cuda()
meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
meta_info = {
'eos_token_id':
self.generation_config.eos_token_id
if self.generation_config is not None else self.tokenizer.eos_token_id,
'pad_token_id':
self.generation_config.pad_token_id
if self.generation_config is not None else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
with self.rollout_sharding_manager:
log_gpu_memory_usage('After entering rollout sharding manager', logger=logger)
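
With this change, meta_info['eos_token_id'] may now be a list rather than a single int (for example when a model's generation_config lists several stop ids). A simplified sketch of how that value flows into the response attention mask; the call site is an assumption about the vLLM rollout worker, which is not touched by this diff:

# Assumed, simplified rollout-side usage:
eos_token_id = prompts.meta_info['eos_token_id']        # int or list of ints, set above
response_attention_mask = get_eos_mask(response_id=response,
                                       eos_token=eos_token_id,
                                       dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)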
15 changes: 12 additions & 3 deletions verl/workers/megatron_workers.py
@@ -135,9 +135,9 @@ def _build_model_optimizer(self,
enable_gradient_checkpointing=False):
from verl.utils.megatron.optimizer import get_megatron_optimizer
from megatron.core.models.gpt.gpt_model import ModelType
from verl.utils.model import print_model_size, update_model_config
from verl.utils.model import print_model_size, update_model_config, get_generation_config
from verl.utils.megatron_utils import get_model, init_megatron_optim_config
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, GenerationConfig

# Step 1: initialize the tokenizer
local_path = copy_local_path_from_hdfs(model_path)
@@ -146,6 +146,8 @@
# Step 2: get the actor_model_config
actor_model_config = AutoConfig.from_pretrained(local_path)

self.generation_config = get_generation_config(local_path)

override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
@@ -352,7 +354,14 @@ def generate_sequences(self, prompts: DataProto):
assert self._is_rollout

prompts.batch = prompts.batch.cuda()
meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
meta_info = {
'eos_token_id':
self.generation_config.eos_token_id
if self.generation_config is not None else self.tokenizer.eos_token_id,
'pad_token_id':
self.generation_config.pad_token_id
if self.generation_config is not None else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
with self.sharding_manager:
log_gpu_memory_usage('After entering sharding manager', logger=logger)
5 changes: 4 additions & 1 deletion verl/workers/rollout/naive/naive_rollout.py
@@ -57,6 +57,8 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:

# used to construct attention_mask
eos_token_id = prompts.meta_info['eos_token_id']
if isinstance(eos_token_id, int):
    eos_token_id = [eos_token_id]

batch_size = idx.size(0)
prompt_length = idx.size(1)
@@ -90,7 +92,8 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:

attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)

prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
for token_id in eos_token_id:
prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask.bool())
prev_attention_mask.to(attention_mask.dtype)

position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)
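
A small standalone sketch of the loop added above (batch contents and eos ids are illustrative): applying the update once per eos id masks a sequence as soon as its sampled token matches any of them.

import torch

idx_next = torch.tensor([[5], [2], [3]])            # sampled next token per sequence
prev_attention_mask = torch.ones_like(idx_next)     # carried over from the previous step
for token_id in [2, 3]:                             # eos ids taken from meta_info
    prev_attention_mask = torch.logical_and(idx_next != token_id,
                                            prev_attention_mask.bool())
# prev_attention_mask is now tensor([[True], [False], [False]]):
# matching either eos id switches the mask off for that sequence.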
