Add stronger reward verification sandbox (#233)
Add stronger verification support, as used in
https://github.com/PRIME-RL/PRIME:

- [x] Batched verification
- [x] Python interpreter
- [x] Stronger math verifier
- [x] Continuous score for code test

Re-opening #207 to trigger
automatic workflows
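
A minimal usage sketch of the batched verification entry point added here, mirroring tests/sandbox/test_sandbox.py below. The toy completion and ground truth are hypothetical; real code ground truths are JSON strings with "inputs"/"outputs" lists:

    # hedged sketch: toy data for illustration, not part of this commit
    import asyncio

    from verl.utils.reward_score import _default_compute_score
    from verl.workers.reward_manager.prime import parallel_compute_score_async

    completions = ['print(input())']                                       # model outputs to verify
    ground_truths = ['{"inputs": ["hello\\n"], "outputs": ["hello\\n"]}']  # per-sample references
    data_sources = ['codecontests']                                        # routes each sample to a verifier

    scores = asyncio.run(
        parallel_compute_score_async(_default_compute_score,
                                     completions,
                                     ground_truths,
                                     data_sources,
                                     num_processes=16))
    print(scores)  # one float per sample, e.g. [1.0] if all hidden tests pass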
ZefanW authored Feb 10, 2025
1 parent 16b1984 commit 5a66ed2
Showing 17 changed files with 2,284 additions and 80 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/sandbox.yml
@@ -0,0 +1,42 @@
name: sandbox

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/sandbox.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/sandbox.yml

jobs:
  sandbox:
    runs-on: [self-hosted, l20-0]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install -e .[test]
          pip3 install vllm==0.5.4
      - name: Running sandbox tests on 8 L20 GPUs
        run: |
          cd tests/sandbox
          pytest -s -x .
5 changes: 5 additions & 0 deletions docs/examples/config.rst
@@ -292,6 +292,7 @@ Reward Model
param_offload: False
micro_batch_size_per_gpu: 16
max_length: null
reward_manager: naive
- ``reward_model.enable``: Whether to enable reward model. If False, we
compute the reward only with the user-defined reward functions. In
@@ -307,6 +308,10 @@ Reward Model
- ``path``: RM's HDFS path or local path. Note that RM only supports
AutoModelForSequenceClassification. Other model types need to define
their own RewardModelWorker and pass it from the code.
- ``reward_model.reward_manager``: The reward manager. This defines the mechanism
  for computing rule-based rewards and handling different reward sources. The
  default is ``naive``. If all verification functions are multiprocessing-safe,
  the reward manager can be set to ``prime`` for parallel verification.
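  For example, training can be launched with the Hydra override
  ``reward_model.reward_manager=prime`` to use the batched, process-pool
  verification path (this assumes every configured scoring function can be
  pickled into worker processes).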

Algorithm
~~~~~~~~~
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -43,6 +43,8 @@ dependencies = [
"vllm<=0.6.3",
"peft",
"liger-kernel",
"pylatexenc",
"pyext"
]

# Optional dependencies (extras_require in setup.py)
139 changes: 139 additions & 0 deletions tests/sandbox/test_sandbox.py
@@ -0,0 +1,139 @@
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

from verl.utils.reward_score import _default_compute_score
from verl.utils.reward_score.prime_code import apps_check_correctness
import asyncio
from verl.workers.reward_manager.prime import parallel_compute_score_async

prime_math_answers = [
"""\\begin{bmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19 \n \\end{bmatrix}""",
"""\\frac{\\sqrt{505}}{7}""", """x^2 + y^2 + 4x - 6y + 13"""
]
prime_math_gts = [
"""\\begin{pmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19\n \\end{pmatrix}""", # mat test
"""\\frac{\\sqrt{505}}{7}""", # frac test
"""(x + 2)^2 + (y - 3)^2 """ # symbolic test
]

prime_code_answers = [
    """import sys
from collections import deque
def main():
    data = sys.stdin.read().split()
    it = iter(data)
    # Read start and target positions
    x0, y0, x1, y1 = int(next(it)), int(next(it)), int(next(it)), int(next(it))
    n = int(next(it))
    allowed = set()
    # The total number of allowed cells is at most 10^5.
    for _ in range(n):
        r = int(next(it))
        a = int(next(it))
        b = int(next(it))
        for c in range(a, b + 1):
            allowed.add((r, c))
    # Directions for the king (8 neighboring cells)
    directions = [(-1, -1), (-1, 0), (-1, 1),
                  (0, -1), (0, 1),
                  (1, -1), (1, 0), (1, 1)]
    start = (x0, y0)
    target = (x1, y1)
    # BFS initialization
    queue = deque()
    queue.append((x0, y0, 0))
    # Mark the starting cell as visited by removing it from allowed set.
    allowed.discard(start)
    while queue:
        x, y, moves = queue.popleft()
        if (x, y) == target:
            print(moves)
            return
        for dx, dy in directions:
            nx, ny = x + dx, y + dy
            if (nx, ny) in allowed:
                allowed.remove((nx, ny))
                queue.append((nx, ny, moves + 1))
    print(-1)
if __name__ == '__main__':
    main()
"""
] * 2
prime_code_gts = [
"""{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"2\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # A correct sample
"""{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}"""
] # A failed sample with first several in-out passed

prime_code_scores = [1.0, 0.9]
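# The first sample is expected to pass every hidden test (score 1.0); the second,
# whose fourth expected output is deliberately wrong, should still receive partial
# credit (0.9) because code tests are scored continuously rather than pass/fail.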


def test_parallelism():
    """
    Test if process pool works properly
    """
    sequences_str = []
    ground_truth = []
    data_sources = []
    while len(sequences_str) < 32:
        sequences_str.extend(prime_code_answers)
        ground_truth.extend(prime_code_gts)
        data_sources.extend(['codecontests'] * len(prime_code_answers))

    sequences_str.extend(prime_math_answers)
    ground_truth.extend(prime_math_gts)
    data_sources.extend(['numina_aops_forum'] * len(prime_math_answers))

    scores = asyncio.run(
        parallel_compute_score_async(_default_compute_score,
                                     sequences_str,
                                     ground_truth,
                                     data_sources,
                                     num_processes=16))
    print(scores)


def test_prime_code():
    """
    Test PRIME code sandbox.
    """
    data_source = 'codecontests'
    for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores):
        score = _default_compute_score(data_source, completion, ground_truth)
        assert float(score) == score_


def test_check_correctness():
    completion = prime_code_answers[0]
    ground_truth = json.loads(prime_code_gts[0])
    ground_truth_single = {'inputs': ground_truth['inputs'][:1], 'outputs': ground_truth['outputs'][:1]}
    res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False)
    print(res, meta)


def test_prime_math():
    data_source = 'numina_aops_forum'
    for completion, ground_truth in zip(prime_math_answers, prime_math_gts):
        score = _default_compute_score(data_source, completion, ground_truth)
        assert float(score) == 1.0
11 changes: 6 additions & 5 deletions tests/sanity/check_license.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

license_head = "Copyright 2024 Bytedance Ltd. and/or its affiliates"
license_head_bytedance = "Copyright 2024 Bytedance Ltd. and/or its affiliates"
license_head2_prime = "Copyright 2024 PRIME team and/or its affiliates"

from pathlib import Path
from argparse import ArgumentParser
@@ -27,9 +28,9 @@
for path in pathlist:
    # because path is object not string
    path_in_str = str(path.absolute())
    with open(path_in_str, 'r') as f:
        print(path_in_str)
    with open(path_in_str, 'r', encoding='utf-8') as f:
        file_content = f.read()

    assert license_head in file_content, f'file {path_in_str} does not contain license'

    print(path_in_str)
    assert license_head_bytedance in file_content or \
        license_head2_prime in file_content, f'file {path_in_str} does not contain license'
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -143,6 +143,7 @@ reward_model:
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
  reward_manager: naive

algorithm:
  gamma: 1.0
86 changes: 11 additions & 75 deletions verl/trainer/main_ppo.py
@@ -14,81 +14,8 @@
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""

from verl import DataProto
import torch
from verl.utils.reward_score import gsm8k, math
from verl.trainer.ppo.ray_trainer import RayPPOTrainer


def _default_compute_score(data_source, solution_str, ground_truth):
    if data_source == 'openai/gsm8k':
        return gsm8k.compute_score(solution_str, ground_truth)
    elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']:
        return math.compute_score(solution_str, ground_truth)
    else:
        raise NotImplementedError


class RewardManager():
    """The reward manager.
    """

    def __init__(self, tokenizer, num_examine, compute_score=None) -> None:
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
        self.compute_score = compute_score or _default_compute_score

    def __call__(self, data: DataProto):
        """We will expand this function gradually based on the available datasets"""

        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
        if 'rm_scores' in data.batch.keys():
            return data.batch['rm_scores']

        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)

        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch['prompts']

            prompt_length = prompt_ids.shape[-1]

            valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch['responses']
            valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)

            ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']

            data_source = data_item.non_tensor_batch['data_source']

            score = self.compute_score(
                data_source=data_source,
                solution_str=sequences_str,
                ground_truth=ground_truth,
            )
            reward_tensor[i, valid_response_length - 1] = score

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print(sequences_str)

        return reward_tensor


import ray
import hydra

@@ -172,10 +99,19 @@ def main_task(config, compute_score=None):
        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)
    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    if reward_manager_name == 'naive':
        from verl.workers.reward_manager import NaiveRewardManager
        reward_manager_cls = NaiveRewardManager
    elif reward_manager_name == 'prime':
        from verl.workers.reward_manager import PrimeRewardManager
        reward_manager_cls = PrimeRewardManager
    else:
        raise NotImplementedError
    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)

    # Note that we always use function-based RM for validation
    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)
    val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
26 changes: 26 additions & 0 deletions verl/utils/reward_score/__init__.py
@@ -11,3 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from . import gsm8k, math, prime_math, prime_code


def _default_compute_score(data_source, solution_str, ground_truth):
    if data_source == 'openai/gsm8k':
        from . import gsm8k
        res = gsm8k.compute_score(solution_str, ground_truth)
    elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']:
        from . import math
        res = math.compute_score(solution_str, ground_truth)
    elif data_source in [
            'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
            'numina_olympiads'
    ]:
        from . import prime_math
        res = prime_math.compute_score(solution_str, ground_truth)
    elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']:
        from . import prime_code
        res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
    else:
        raise NotImplementedError

    if isinstance(res, (int, float, bool)):
        return float(res)
    else:
        return float(res[0])
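# For instance (values taken from tests/sandbox/test_sandbox.py above), the PRIME
# math verifier checks symbolic equivalence rather than string equality, so the
# following call is expected to return 1.0:
#
#     _default_compute_score('numina_aops_forum',
#                            'x^2 + y^2 + 4x - 6y + 13',
#                            '(x + 2)^2 + (y - 3)^2 ')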