Add Remote Reward Server Feature #419

Open · wants to merge 11 commits into main
44 changes: 44 additions & 0 deletions examples/ppo_trainer/run_ppo.sh
@@ -0,0 +1,44 @@
set -x

export NCCL_DEBUG=WARN
export WANDB_API_KEY=''
export VLLM_ATTENTION_BACKEND=XFORMERS
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export TOKENIZERS_PARALLELISM=true

PROJECT_NAME=
EXPERIMENT_NAME=
DATA_PATH=
SFT_MODEL_PATH=
CKPT_PATH=

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.val_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=1e-5 \
critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
+trainer.val_before_train=False \
reward_model.reward_manager=generative \
trainer.default_hdfs_dir=null \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
trainer.save_freq=10 \
trainer.test_freq=10 \
trainer.logger=['console','wandb'] \
trainer.project_name=$PROJECT_NAME \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.total_epochs=15 2>&1 | tee verl_demo.log
19 changes: 19 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -142,6 +142,25 @@ reward_model:
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
  reward_manager: naive
  reward_manager_config:
    tokenizer: null
    server_url: null
    model_name: null
    scoring_prompt: null
    num_samples: 4
    sc_mode: mean
    api_key: null
    timeout: 5
    temperature: 1.0
    initial_concurrency: 50
    min_concurrency: 5
    max_concurrency: 200
    default_score: 0.001
    max_retries: 4
    max_tokens: 16
    min_score: 0
    max_score: 5
    apply_chat_template: False
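
For orientation, here is a minimal sketch of how these fields could be filled in to point the generative reward manager at an OpenAI-compatible judge server. All concrete values below (server URL, model name, prompt path) are illustrative placeholders, not defaults shipped by this PR.

```python
# Illustrative only: building the same reward_manager_config programmatically with OmegaConf.
from omegaconf import OmegaConf

overrides = OmegaConf.create({
    "reward_model": {
        "reward_manager": "generative",
        "reward_manager_config": {
            "server_url": "http://localhost:8000/v1",    # any OpenAI-compatible endpoint (placeholder)
            "model_name": "Qwen/Qwen2.5-7B-Instruct",    # judge model served remotely (placeholder)
            "scoring_prompt": "verl/utils/prompts.txt",  # the scoring prompt added in this PR
            "api_key": "EMPTY",
            "num_samples": 4,
            "sc_mode": "mean",
            "apply_chat_template": False,
        },
    },
})
print(OmegaConf.to_yaml(overrides))
```

The same keys should also be settable as dotted overrides on the training command line, in the same style as the other options in run_ppo.sh.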

algorithm:
  gamma: 1.0
20 changes: 15 additions & 5 deletions verl/trainer/main_ppo.py
@@ -98,20 +98,30 @@ def main_task(config, compute_score=None):
            raise NotImplementedError
        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    if reward_manager_name == 'naive':
        from verl.workers.reward_manager import NaiveRewardManager
        reward_manager_cls = NaiveRewardManager
    elif reward_manager_name == 'prime':
        from verl.workers.reward_manager import PrimeRewardManager
        reward_manager_cls = PrimeRewardManager
    elif reward_manager_name == 'generative':
        from verl.workers.reward_manager import GenerativeRewardManager
        reward_manager_cls = GenerativeRewardManager
    else:
        raise NotImplementedError

    # Note that we always use function-based RM for validation
    if reward_manager_name == 'generative':
        reward_fn = reward_manager_cls(tokenizer=tokenizer,
                                       num_examine=0,
                                       compute_score=compute_score,
                                       config=config.reward_model.reward_manager_config)
        val_reward_fn = reward_manager_cls(tokenizer=tokenizer,
                                           num_examine=1,
                                           compute_score=compute_score,
                                           config=config.reward_model.reward_manager_config)
    else:
        reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)
        val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)

resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
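
The GenerativeRewardManager class constructed above lives in verl/workers/reward_manager/generative.py, which is not shown in this excerpt. As a reading aid, the sketch below shows the constructor signature implied by the call sites and a plausible __call__ shape following the existing NaiveRewardManager convention; the method internals and class name are assumptions, not the PR's actual implementation.

```python
# Sketch only: interface implied by the call sites above, not the PR's actual class.
import asyncio
import torch
from verl import DataProto
from verl.utils.reward_score.generative import process_data_async


class GenerativeRewardManagerSketch:

    def __init__(self, tokenizer, num_examine, compute_score=None, config=None):
        self.tokenizer = tokenizer
        self.num_examine = num_examine      # how many decoded samples to print for inspection
        self.compute_score = compute_score  # unused here; kept for interface parity
        self.config = config                # reward_model.reward_manager_config

    def __call__(self, data: DataProto) -> torch.Tensor:
        # Decode each response, score it via the remote server, and place the scalar
        # score on the last valid response token (token-level reward tensor).
        responses = data.batch['responses']
        reward_tensor = torch.zeros_like(responses, dtype=torch.float32)
        solution_strs = [self.tokenizer.decode(r, skip_special_tokens=True) for r in responses]
        ground_truths = [item.non_tensor_batch['reward_model'].get('ground_truth') for item in data]
        data_sources = [item.non_tensor_batch.get('data_source') for item in data]
        scores = asyncio.run(
            process_data_async(data_sources, solution_strs, ground_truths,
                               [None] * len(solution_strs), self.config))
        response_lengths = data.batch['attention_mask'][:, -responses.shape[-1]:].sum(dim=-1)
        for i in range(len(solution_strs)):
            reward_tensor[i, int(response_lengths[i]) - 1] = scores[i]
        return reward_tensor
```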

11 changes: 6 additions & 5 deletions verl/trainer/ppo/ray_trainer.py
@@ -490,8 +490,9 @@ def _create_dataloader(self):
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation='error')
        # evaluate with the configured validation batch size instead of the whole dataset
        self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                         batch_size=self.config.data.val_batch_size,
                                         shuffle=True,
                                         drop_last=True,
                                         collate_fn=collate_fn)
@@ -577,7 +578,9 @@ def _validate(self):

for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)

            # log an example input from the current validation batch
print('validation generation start')
print('input:', self.tokenizer.decode(test_batch.batch['input_ids'][1], skip_special_tokens=True))
# we only do validation on rule-based rm
if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
return {}
@@ -844,7 +847,6 @@ def fit(self):
timing_raw = {}

batch: DataProto = DataProto.from_single_dict(batch_dict)

# pop those keys for generation
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

@@ -868,13 +870,12 @@
batch.batch['reward_baselines'] = reward_baseline_tensor

del gen_baseline_batch, gen_baseline_output

batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

# print shape of the first batch
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
15 changes: 15 additions & 0 deletions verl/utils/prompts.txt
@@ -0,0 +1,15 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Score the following solution to the question from 1 to 5, where 1 is the worst and 5 is the best. Return the score in the format \bold{score} without any additional formatting.
177 changes: 177 additions & 0 deletions verl/utils/reward_score/generative.py
@@ -0,0 +1,177 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import asyncio
import numpy as np
import torch
from typing import List, Dict, Any
from openai import AsyncOpenAI
from transformers import AutoTokenizer


def extract_output(solution_text: str):
    # Match everything inside the last \bold{} tag in the solution text
boxed_pattern = r'\\bold{(.*)}'
matches = re.findall(boxed_pattern, solution_text)
if matches:
return matches[-1].strip()
return None
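
A quick usage check of the parsing convention above (the import path assumes this file lands at verl/utils/reward_score/generative.py, as in the diff header):

```python
# Illustrative: extract_output pulls the score out of the requested \bold{} wrapper.
from verl.utils.reward_score.generative import extract_output

print(extract_output(r"The solution is mostly correct. \bold{4}"))  # -> "4"
print(extract_output("no score in this reply"))                     # -> None
```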


async def _query_openai_with_semaphore(semaphore: asyncio.Semaphore, client: AsyncOpenAI, sequence_str: str,
config: Dict[str, Any]) -> float:
"""
Request method with semaphore.
"""
async with semaphore:
return await _query_openai_async(client, sequence_str, config)


async def _query_openai_async(client: AsyncOpenAI, sequence_str: str, config) -> float:
"""
Query OpenAI API asynchronously.
"""
max_retries = config.max_retries # Maximum number of retries
retry_count = 0
    # read the judge prompt once per request
    with open(config.scoring_prompt, "r") as f:
        scoring_prompt = f.read()
min_score = config.min_score # Minimum valid score
max_score = config.max_score # Maximum valid score
while retry_count < max_retries:
messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": scoring_prompt + '\n' + sequence_str
},
]
        if config.tokenizer and config.apply_chat_template:
            # Render the chat template locally; this assumes the remote endpoint accepts
            # a pre-templated prompt string in place of a structured message list.
            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
            messages = tokenizer.apply_chat_template(messages, tokenize=False)
try:
response = await client.chat.completions.create(
model=config.model_name,
messages=messages,
max_tokens=config.max_tokens,
temperature=config.temperature,
n=config.num_samples,
)
if config.num_samples == 1:
# Handle single sample case
try:
scores = response.choices[0].message.content.strip()
score = float(extract_output(scores))

# Check if the score is within the valid range
if min_score <= score <= max_score:
return score
else:
print(f"Score {score} out of range [{min_score}, {max_score}]. Retrying...")
retry_count += 1
if retry_count >= max_retries:
print("Max retries reached. Returning default score.")
return config.default_score
continue # Retry the request
except Exception as e:
print(f"Processing error: {e}")
retry_count += 1
if retry_count >= max_retries:
print("Max retries reached. Returning default score.")
return config.default_score
continue # Retry the request
else:
# Handle multiple samples case
raw_scores = [choice.message.content.strip() for choice in response.choices]
valid_scores = []
for score in raw_scores:
try:
extracted_score = float(extract_output(score))
# Check if the score is within the valid range
if min_score <= extracted_score <= max_score:
valid_scores.append(extracted_score)
else:
print(f"Score {extracted_score} out of range [{min_score}, {max_score}]. Skipping...")
except Exception as e:
print(f"Processing error: {e}")
if valid_scores: # If there are any valid scores
if config.sc_mode == "mean":
return float(np.mean(valid_scores))
elif config.sc_mode == "median":
return float(np.median(valid_scores))
elif config.sc_mode == "majority":
return float(np.round(np.mean(valid_scores)))
else:
raise ValueError(f"Unknown consistency mode: {config.sc_mode}")
else:
# No valid scores, retry the request
retry_count += 1
print("No valid scores found. Retrying...")
if retry_count >= max_retries:
print("Max retries reached. Returning default score.")
return config.default_score
continue # Retry the request
except Exception as e:
print(f"Error querying OpenAI API: {e}")
retry_count += 1
if retry_count >= max_retries:
print("Max retries reached. Returning default score.")
return config.default_score
continue # Retry the request
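
To make the three sc_mode settings concrete, a small worked example of how the aggregation above combines sampled judge scores (the sample values are made up):

```python
# Worked example of the self-consistency modes implemented above; inputs are made up.
import numpy as np

valid_scores = [2.0, 4.0, 4.0, 5.0]
print(float(np.mean(valid_scores)))            # sc_mode == "mean"     -> 3.75
print(float(np.median(valid_scores)))          # sc_mode == "median"   -> 4.0
print(float(np.round(np.mean(valid_scores))))  # sc_mode == "majority" -> 4.0 (rounded mean)
```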


async def process_data_async(data_source: List[str], solution_str: List[str], ground_truth: List[str],
extra_info: List[Dict[str, Any]], config) -> torch.Tensor:
"""
Process data asynchronously using OpenAI API.
"""
reward_tensor = torch.zeros(len(solution_str), dtype=torch.float32)
client = AsyncOpenAI(api_key=config.api_key, base_url=config.server_url)

remaining_tasks = list(range(len(solution_str)))
while remaining_tasks:
# Dynamic semaphore creation
semaphore = asyncio.Semaphore(config.initial_concurrency)
current_batch = remaining_tasks[:config.initial_concurrency]
remaining_tasks = remaining_tasks[config.initial_concurrency:]
tasks = []

for i in current_batch:
prompt = solution_str[i]
response = ground_truth[i]
if response is None:
sequence_str = prompt
else:
sequence_str = f"{prompt}\nReference:\n{response}"

task = _query_openai_with_semaphore(semaphore, client, sequence_str, config)
tasks.append((i, sequence_str, task, data_source[i]))

# Execute tasks in parallel
results = await asyncio.gather(*[task for _, _, task, _ in tasks])

# Adjust concurrency based on success rate
success_rate = sum(1 for r in results if r != config.default_score) / len(results)
if success_rate > 0.6:
config.initial_concurrency = min(config.max_concurrency, int(config.initial_concurrency * 2))
        else:
            # back off, but never drop below the configured minimum concurrency
            config.initial_concurrency = max(config.min_concurrency, int(config.initial_concurrency / 2))

# Update reward tensor
for idx, (i, _, _, _) in enumerate(tasks):
reward_tensor[i] = results[idx]

return reward_tensor
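
Finally, a hedged end-to-end sketch of calling the scorer directly, assuming an OmegaConf config shaped like reward_manager_config in ppo_trainer.yaml; the endpoint, model name, and inputs below are placeholders:

```python
# Illustrative driver for process_data_async; all concrete values are placeholders.
import asyncio
from omegaconf import OmegaConf
from verl.utils.reward_score.generative import process_data_async

config = OmegaConf.create({
    "server_url": "http://localhost:8000/v1",  # placeholder OpenAI-compatible endpoint
    "model_name": "judge-model",               # placeholder judge model name
    "scoring_prompt": "verl/utils/prompts.txt",
    "api_key": "EMPTY",
    "tokenizer": None,
    "apply_chat_template": False,
    "num_samples": 1,
    "sc_mode": "mean",
    "temperature": 1.0,
    "max_tokens": 16,
    "max_retries": 2,
    "default_score": 0.001,
    "min_score": 0,
    "max_score": 5,
    "initial_concurrency": 2,
    "min_concurrency": 1,
    "max_concurrency": 4,
    "timeout": 5,
})

solutions = ["Question: 2 + 2 = ?\nAnswer: 4"]
rewards = asyncio.run(process_data_async(["gsm8k"], solutions, [None], [{}], config))
print(rewards)  # 1-element float tensor; default_score (0.001) if the server could not be reached
```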
3 changes: 2 additions & 1 deletion verl/workers/reward_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.

from .naive import NaiveRewardManager
from .prime import PrimeRewardManager
from .generative import GenerativeRewardManager