diff --git a/.github/workflows/sandbox.yml b/.github/workflows/sandbox.yml new file mode 100644 index 00000000..a1fdd7cd --- /dev/null +++ b/.github/workflows/sandbox.yml @@ -0,0 +1,42 @@ +name: sandbox + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/sandbox.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/sandbox.yml + +jobs: + sandbox: + runs-on: [self-hosted, l20-0] + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + HF_HUB_ENABLE_HF_TRANSFER: 1 + container: + image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install hf_transfer + pip3 install -e .[test] + pip3 install vllm==0.5.4 + - name: Running sandbox tests on 8 L20 GPUs + run: | + cd tests/sandbox + pytest -s -x . diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 947d4252..e8ea0588 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -292,6 +292,7 @@ Reward Model param_offload: False micro_batch_size_per_gpu: 16 max_length: null + reward_manager: naive - ``reward_model.enable``: Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions. In @@ -307,6 +308,10 @@ Reward Model - ``path``: RM's HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification. Other model types need to define their own RewardModelWorker and pass it from the code. +- ``reward_model.reward_manager``: Reward Manager. This defines the mechanism + of computing rule-based reward and handling different reward sources. Default + if ``naive``. If all verification functions are multiprocessing-safe, the reward + manager can be set to ``prime`` for parallel verification. Algorithm ~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index b201b335..7b28eee5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,8 @@ dependencies = [ "vllm<=0.6.3", "peft", "liger-kernel", + "pylatexenc", + "pyext" ] # Optional dependencies (extras_require in setup.py) diff --git a/tests/sandbox/test_sandbox.py b/tests/sandbox/test_sandbox.py new file mode 100644 index 00000000..d60474e4 --- /dev/null +++ b/tests/sandbox/test_sandbox.py @@ -0,0 +1,139 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from verl.utils.reward_score import _default_compute_score +from verl.utils.reward_score.prime_code import apps_check_correctness +import asyncio +from verl.workers.reward_manager.prime import parallel_compute_score_async + +prime_math_answers = [ + """\\begin{bmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19 \n \\end{bmatrix}""", + """\\frac{\\sqrt{505}}{7}""", """x^2 + y^2 + 4x - 6y + 13""" +] +prime_math_gts = [ + """\\begin{pmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19\n \\end{pmatrix}""", # mat test + """\\frac{\\sqrt{505}}{7}""", # frac test + """(x + 2)^2 + (y - 3)^2 """ # symbolic test +] + +prime_code_answers = [ + """import sys +from collections import deque + +def main(): + data = sys.stdin.read().split() + it = iter(data) + + # Read start and target positions + x0, y0, x1, y1 = int(next(it)), int(next(it)), int(next(it)), int(next(it)) + + n = int(next(it)) + allowed = set() + # The total number of allowed cells is at most 10^5. + for _ in range(n): + r = int(next(it)) + a = int(next(it)) + b = int(next(it)) + for c in range(a, b + 1): + allowed.add((r, c)) + + # Directions for the king (8 neighboring cells) + directions = [(-1, -1), (-1, 0), (-1, 1), + (0, -1), (0, 1), + (1, -1), (1, 0), (1, 1)] + + start = (x0, y0) + target = (x1, y1) + + # BFS initialization + queue = deque() + queue.append((x0, y0, 0)) + # Mark the starting cell as visited by removing it from allowed set. + allowed.discard(start) + + while queue: + x, y, moves = queue.popleft() + if (x, y) == target: + print(moves) + return + for dx, dy in directions: + nx, ny = x + dx, y + dy + if (nx, ny) in allowed: + allowed.remove((nx, ny)) + queue.append((nx, ny, moves + 1)) + + print(-1) + +if __name__ == '__main__': + main() +""" +] * 2 +prime_code_gts = [ + """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"2\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # A correct sample + """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""" +] # A failed sample with first several in-out passed + +prime_code_scores = [1.0, 0.9] + + +def test_parallelism(): + """ + Test if process pool works properly + """ + sequences_str = [] + ground_truth = [] + data_sources = [] + while len(sequences_str) < 32: + sequences_str.extend(prime_code_answers) + ground_truth.extend(prime_code_gts) + data_sources.extend(['codecontests'] * len(prime_code_answers)) + + sequences_str.extend(prime_math_answers) + ground_truth.extend(prime_math_gts) + data_sources.extend(['numina_aops_forum'] * len(prime_math_answers)) + + scores = asyncio.run( + parallel_compute_score_async(_default_compute_score, + sequences_str, + ground_truth, + data_sources, + num_processes=16)) + print(scores) + + +def test_prime_code(): + """ + Test PRIME code sandbox. + """ + data_source = 'codecontests' + for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores): + score = _default_compute_score(data_source, completion, ground_truth) + assert float(score) == score_ + + +def test_check_correctness(): + completion = prime_code_answers[0] + ground_truth = json.loads(prime_code_gts[0]) + ground_truth_single = {'inputs': ground_truth['inputs'][:1], 'outputs': ground_truth['outputs'][:1]} + res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False) + print(res, meta) + + +def test_prime_math(): + data_source = 'numina_aops_forum' + for completion, ground_truth in zip(prime_math_answers, prime_math_gts): + score = _default_compute_score(data_source, completion, ground_truth) + assert float(score) == 1.0 diff --git a/tests/sanity/check_license.py b/tests/sanity/check_license.py index d7ef2cec..bc306d59 100644 --- a/tests/sanity/check_license.py +++ b/tests/sanity/check_license.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -license_head = "Copyright 2024 Bytedance Ltd. and/or its affiliates" +license_head_bytedance = "Copyright 2024 Bytedance Ltd. and/or its affiliates" +license_head2_prime = "Copyright 2024 PRIME team and/or its affiliates" from pathlib import Path from argparse import ArgumentParser @@ -27,9 +28,9 @@ for path in pathlist: # because path is object not string path_in_str = str(path.absolute()) - with open(path_in_str, 'r') as f: + print(path_in_str) + with open(path_in_str, 'r', encoding='utf-8') as f: file_content = f.read() - assert license_head in file_content, f'file {path_in_str} does not contain license' - - print(path_in_str) + assert license_head_bytedance in file_content or \ + license_head2_prime in file_content, f'file {path_in_str} does not contain license' diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index a1a42cef..03fed370 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -143,6 +143,7 @@ reward_model: ulysses_sequence_parallel_size: 1 # sp size use_dynamic_bsz: ${critic.use_dynamic_bsz} forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive algorithm: gamma: 1.0 diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 628641ec..08b6a5e9 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -14,81 +14,8 @@ """ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ - -from verl import DataProto -import torch -from verl.utils.reward_score import gsm8k, math from verl.trainer.ppo.ray_trainer import RayPPOTrainer - -def _default_compute_score(data_source, solution_str, ground_truth): - if data_source == 'openai/gsm8k': - return gsm8k.compute_score(solution_str, ground_truth) - elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']: - return math.compute_score(solution_str, ground_truth) - else: - raise NotImplementedError - - -class RewardManager(): - """The reward manager. - """ - - def __init__(self, tokenizer, num_examine, compute_score=None) -> None: - self.tokenizer = tokenizer - self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or _default_compute_score - - def __call__(self, data: DataProto): - """We will expand this function gradually based on the available datasets""" - - # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if 'rm_scores' in data.batch.keys(): - return data.batch['rm_scores'] - - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) - - already_print_data_sources = {} - - for i in range(len(data)): - data_item = data[i] # DataProtoItem - - prompt_ids = data_item.batch['prompts'] - - prompt_length = prompt_ids.shape[-1] - - valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() - valid_prompt_ids = prompt_ids[-valid_prompt_length:] - - response_ids = data_item.batch['responses'] - valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - sequences = torch.cat((valid_prompt_ids, valid_response_ids)) - sequences_str = self.tokenizer.decode(sequences) - - ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] - - data_source = data_item.non_tensor_batch['data_source'] - - score = self.compute_score( - data_source=data_source, - solution_str=sequences_str, - ground_truth=ground_truth, - ) - reward_tensor[i, valid_response_length - 1] = score - - if data_source not in already_print_data_sources: - already_print_data_sources[data_source] = 0 - - if already_print_data_sources[data_source] < self.num_examine: - already_print_data_sources[data_source] += 1 - print(sequences_str) - - return reward_tensor - - import ray import hydra @@ -172,10 +99,19 @@ def main_task(config, compute_score=None): role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) mapping[Role.RewardModel] = global_pool_id - reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) + reward_manager_name = config.reward_model.get("reward_manager", "naive") + if reward_manager_name == 'naive': + from verl.workers.reward_manager import NaiveRewardManager + reward_manager_cls = NaiveRewardManager + elif reward_manager_name == 'prime': + from verl.workers.reward_manager import PrimeRewardManager + reward_manager_cls = PrimeRewardManager + else: + raise NotImplementedError + reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) # Note that we always use function-based RM for validation - val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 1ce90c5e..f53662c1 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -11,3 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# from . import gsm8k, math, prime_math, prime_code + + +def _default_compute_score(data_source, solution_str, ground_truth): + if data_source == 'openai/gsm8k': + from . import gsm8k + res = gsm8k.compute_score(solution_str, ground_truth) + elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']: + from . import math + res = math.compute_score(solution_str, ground_truth) + elif data_source in [ + 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12', + 'numina_olympiads' + ]: + from . import prime_math + res = prime_math.compute_score(solution_str, ground_truth) + elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']: + from . import prime_code + res = prime_code.compute_score(solution_str, ground_truth, continuous=True) + else: + raise NotImplementedError + + if isinstance(res, (int, float, bool)): + return float(res) + else: + return float(res[0]) diff --git a/verl/utils/reward_score/prime_code/__init__.py b/verl/utils/reward_score/prime_code/__init__.py new file mode 100644 index 00000000..49b0d30f --- /dev/null +++ b/verl/utils/reward_score/prime_code/__init__.py @@ -0,0 +1,73 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .utils import check_correctness as apps_check_correctness +import json +import re +import traceback + + +def compute_score(completion, test_cases, continuous=False): + # try to get code solution from completion. if the completion is pure code, this will not take effect. + solution = completion.split('```python')[-1].split('```')[0] + try: + try: + if not isinstance(test_cases, dict): + test_cases = json.loads(test_cases) + except Exception as e: + print(f"Error:{e}") + + # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped. + try: + res, metadata = apps_check_correctness(in_outs=test_cases, generation=solution, timeout=5, debug=False) + metadata = dict(enumerate(metadata))[0] + success = all(map(lambda x: x == True, res)) + if success: + return success, metadata + except Exception as e: + pass + + test_cases_list = [] + inputs = test_cases["inputs"] + outputs = test_cases["outputs"] + for i in range(len(inputs)): + test_cases_list.append({"inputs": [inputs[i]], "outputs": [outputs[i]]}) + + if continuous: + # per sample test: if continuous score is needed, test first 10 samples regardless of failures + # do not test all samples cuz some problems have enormous test cases + metadata_list = [] + res_list = [] + for test_case_id, test_case in enumerate(test_cases_list): + res, metadata = apps_check_correctness(in_outs=test_case, generation=solution, timeout=5, debug=False) + try: + metadata = dict(enumerate(metadata))[0] # metadata can be empty occasionally + except Exception as e: + metadata = {} + metadata["test_case"] = {} + metadata["test_case"]["input"] = str(test_case["inputs"][0]) + metadata["test_case"]["output"] = str(test_case["outputs"][0]) + metadata["test_case"]["res"] = str(res) + metadata_list.append(metadata) + res_list.extend(res) + + if test_case_id >= 9: + break + res_count = len(res_list) if len(res_list) > 0 else 1 + success = sum(map(lambda x: x == True, res_list)) / res_count + except Exception as e: + traceback.print_exc(10) + success = False + metadata_list = None + return success, metadata_list diff --git a/verl/utils/reward_score/prime_code/testing_util.py b/verl/utils/reward_score/prime_code/testing_util.py new file mode 100644 index 00000000..34845d07 --- /dev/null +++ b/verl/utils/reward_score/prime_code/testing_util.py @@ -0,0 +1,721 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import json +import sys +import faulthandler +import platform + +# used for debugging to time steps +from datetime import datetime + +# to run the solution files we're using a timing based approach +import signal + +import numpy as np + +# for capturing the stdout +from io import StringIO + +# used for testing the code that reads from input +from unittest.mock import patch, mock_open + +from pyext import RuntimeModule + +from enum import Enum + +import traceback + + +def truncatefn(s, length=300): + assert isinstance(s, str) + if len(s) <= length: + return s + + return s[:length // 2] + "...(truncated) ..." + s[-length // 2:] + + +class CODE_TYPE(Enum): + call_based = 0 + standard_input = 1 + + +# stuff for setting up signal timer +class TimeoutException(Exception): + pass + + +def timeout_handler(signum, frame): + print("alarm went off") + return + # raise TimeoutException # this is an unhandled exception. just return None is OK + + +signal.signal(signal.SIGALRM, timeout_handler) + +# timeout = 6 # seconds + + +# used to capture stdout as a list +# from https://stackoverflow.com/a/16571630/6416660 +# alternative use redirect_stdout() from contextlib +class Capturing(list): + + def __enter__(self): + self._stdout = sys.stdout + sys.stdout = self._stringio = StringIO() + # Make closing the StringIO a no-op + self._stringio.close = lambda x: 1 + return self + + def __exit__(self, *args): + self.append(self._stringio.getvalue()) + del self._stringio # free up some memory + sys.stdout = self._stdout + + +def only_int_check(val): + return isinstance(val, int) + + +def string_int_check(val): + return isinstance(val, str) and val.isdigit() + + +def combined_int_check(val): + return only_int_check(val) or string_int_check(val) + + +def clean_traceback(error_traceback): + file_start = error_traceback.find('File \"\"') + # print(file_start) + error_traceback = "Traceback (most recent call last):\n " + error_traceback[file_start:] + return error_traceback + + +def run_test(in_outs, test=None, debug=False, timeout=15): + """ + if test(generated_code) is not None it'll try to run the code. + otherwise it'll just return an input and output pair. + """ + # Disable functionalities that can make destructive changes to the test. + reliability_guard() + + if debug: + print(f"start = {datetime.now().time()}") + + if in_outs: + if in_outs.get("fn_name") is None: + which_type = CODE_TYPE.standard_input # Standard input + method_name = None + else: + which_type = CODE_TYPE.call_based # Call-based + method_name = in_outs["fn_name"] + + if debug: + print(f"loaded input_output = {datetime.now().time()}") + + if test is None: + assert False, "should not happen: test code is none" + return in_outs, {"error": "no test code provided"} + elif test is not None: + results = [] + sol = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(6*10**5)\n" + if debug: + print(f"loading test code = {datetime.now().time()}") + + if which_type == CODE_TYPE.call_based: + + sol += test + if debug: + print(f"sol = {sol}") + signal.alarm(timeout) + try: + tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) + if "class Solution" not in test: + tmp = tmp_sol + else: + tmp = tmp_sol.Solution() + signal.alarm(0) + except Exception as e: + signal.alarm(0) + error_traceback = traceback.format_exc() + if debug: + print(f"type 0 compilation error = {e}") + results.append(-2) + return results, { + "error": repr(e), + # "error_code": -1, + # "error_message": "Compilation Error", + "traceback": clean_traceback(error_traceback), + } + signal.alarm(0) + + elif which_type == CODE_TYPE.standard_input: + # sol + # if code has if __name__ == "__main__": then remove it + try: + astree = ast.parse(test) + last_block = astree.body[-1] + if isinstance(last_block, ast.If): + condition = last_block.test + if ast.unparse(condition).strip() == "__name__ == '__main__'": + test = (ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body)) + except: + pass + + tmp_test = test.split("\n") + + new_test = [] + for x in tmp_test: + if (not x.startswith("from ")) and (not x.startswith("import ")): + new_test.append("\t" + x + "\n") + else: + new_test.append(x + "\n") + tmp_test = new_test + + new_test = "" + started = False + for i in tmp_test: + if i.startswith("\t") and not started: + new_test += "stdin = sys.stdin\nstdout = sys.stdout\n" + new_test += "def code():\n" + new_test += i + started = True + elif started and ((i.startswith("from ")) or (i.startswith("import "))): + new_test += "\t" + i + else: + new_test += i + tmp_test = new_test + + sol += tmp_test + if debug: + print(f"sol = {sol}") + method_name = "code" + signal.alarm(timeout) + try: + tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) + tmp = tmp_sol + signal.alarm(0) + except Exception as e: + signal.alarm(0) + error_traceback = traceback.format_exc() + if debug: + print(f"type 1 compilation error = {e}") + results.append(-2) + return results, { + "error": repr(e), + # "error_code": -1, + # "error_message": "Compilation Error", + "traceback": clean_traceback(error_traceback), + } + signal.alarm(0) + if debug: + print(f"get method = {datetime.now().time()}") + + try: + method = getattr(tmp, method_name) # get_attr second arg must be str + except: + signal.alarm(0) + error_traceback = traceback.format_exc() + e = sys.exc_info() + print(f"unable to get function error = {e}") + results.append(-2) + return results, { + "error": repr(e), + # "error_code": -1, + # "error_message": "Unable to extract code", + "traceback": clean_traceback(error_traceback), + } + + for index, inputs in enumerate(in_outs["inputs"]): + raw_inputs = inputs + raw_outputs = in_outs["outputs"][index] + if which_type == CODE_TYPE.call_based: + inputs = [json.loads(line) for line in inputs.split("\n")] + in_outs["outputs"][index] = json.loads(in_outs["outputs"][index]) + + truncate_line_size = 300 // (raw_inputs.count("\n") + 1) + raw_inputs = "\n".join( + [truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")]) + raw_outputs = truncatefn(raw_outputs, 200) + else: + raw_inputs = truncatefn(raw_inputs) + raw_outputs = truncatefn(raw_outputs, 200) + # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list) + try: + if isinstance(inputs[0], dict): + inputs = [{int(k): v for k, v in inputs[0].items()}] + except: + True + try: + if isinstance(in_outs["outputs"][index], dict): + in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index].items()}] + except: + True + try: + if isinstance(in_outs["outputs"][index][0], dict): + in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index][0].items()}] + except: + True + + if debug: + print( + f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}" + ) + if which_type == CODE_TYPE.call_based: # Call-based + signal.alarm(timeout) + faulthandler.enable() + try: + output = method(*inputs) + raw_true_output = output + + raw_true_output_copy = json.dumps(output) + raw_true_output_copy = truncatefn(raw_true_output_copy, 200) + + # ground truth sequences are not tuples + if isinstance(output, tuple): + output = list(output) + + tmp_result = output == in_outs["outputs"][index] + if (isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]): + tmp_result = tmp_result or (output == in_outs["outputs"][index][0]) + + # ground truth sequences are not tuples + try: + if isinstance(output[0], tuple): + tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0]) + except: + True + results.append(tmp_result) + if tmp_result != True: + return results, { + "output": raw_true_output_copy, + "expected": raw_outputs, + "inputs": raw_inputs, + # "error_code": -2, + "error_message": "Wrong Answer", + } + # reset the alarm + signal.alarm(0) + except Exception as e: + signal.alarm(0) + error_traceback = traceback.format_exc() + faulthandler.disable() + if debug: + print(f"Standard input runtime error or time limit exceeded error = {e}") + results.append(-1) + if "timeoutexception" in repr(e).lower(): + return results, { + "error": repr(e), + # "error_code": -3, + # "error_message": "Time Limit Exceeded", + # "inputs": raw_inputs, + # "expected": raw_outputs, + "traceback": clean_traceback(error_traceback), + } + else: + return results, { + "error": repr(e), + # "error_code": -4, + # "error_message": "Runtime Error", + # "inputs": raw_inputs, + # "expected": raw_outputs, + "traceback": clean_traceback(error_traceback), + } + faulthandler.disable() + signal.alarm(0) + if debug: + print( + f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + elif which_type == CODE_TYPE.standard_input: # Standard input + faulthandler.enable() + passed = False + + if isinstance(inputs, list): + inputs = "\n".join(inputs) + if isinstance(in_outs["outputs"][index], list): + in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index]) + + signal.alarm(timeout) + with Capturing() as output: + try: + call_method(method, inputs) + # reset the alarm + signal.alarm(0) + passed = True + except Exception as e: + # runtime error or took too long + signal.alarm(0) + error_traceback = traceback.format_exc() + print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") + results.append(-1) + if "timeoutexception" in repr(e).lower(): + return results, { + "error": repr(e), + # "error_code": -3, + # "error_message": "Time Limit Exceeded", + # "inputs": raw_inputs, + # "expected": raw_outputs, + "traceback": clean_traceback(error_traceback), + } + else: + return results, { + "error": repr(e), + # "error_code": -4, + # "error_message": "Runtime Error", + # "inputs": raw_inputs, + # "expected": raw_outputs, + "traceback": clean_traceback(error_traceback), + } + signal.alarm(0) + raw_true_output = output[0] + raw_true_output_copy = truncatefn(raw_true_output, 200) + output = raw_true_output.splitlines() + if not passed: + if debug: + nl = "\n" + if not isinstance(inputs, list): + print( + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + else: + print( + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + continue + + if passed and debug: + print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}") + + if custom_compare_(output, in_outs["outputs"][index]): + tmp_result = True + results.append(tmp_result) + continue + + # ground truth sequences are expressed as lists not tuples + if isinstance(output, tuple): + output = list(output) + + tmp_result = False + try: + tmp_result = output == [in_outs["outputs"][index]] + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + if isinstance(output[0], str): + tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index]) + except Exception as e: + if debug: + print(f"Failed check1 exception = {e}") + pass + + if tmp_result == True: + results.append(tmp_result) + continue + + # try one more time without \n + if isinstance(in_outs["outputs"][index], list): + for tmp_index, i in enumerate(in_outs["outputs"][index]): + in_outs["outputs"][index][tmp_index] = i.split("\n") + in_outs["outputs"][index][tmp_index] = [ + x.strip() for x in in_outs["outputs"][index][tmp_index] if x + ] + else: + in_outs["outputs"][index] = in_outs["outputs"][index].split("\n") + in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index])) + in_outs["outputs"][index] = list(map(lambda x: x.strip(), in_outs["outputs"][index])) + + try: + tmp_result = output == [in_outs["outputs"][index]] + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + except Exception as e: + if debug: + print(f"Failed check2 exception = {e}") + pass + + if tmp_result == True: + results.append(tmp_result) + continue + + # try by converting the output into a split up list too + if isinstance(output, list): + output = list(filter(len, output)) + + if debug: + nl = "\n" + if not isinstance(inputs, list): + print( + f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}" + ) + else: + print( + f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}" + ) + + if tmp_result == True: + results.append(tmp_result) + continue + + if debug: + print(f"{tmp_result=} @a") + + try: + tmp_result = output == [in_outs["outputs"][index]] + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + except Exception as e: + if debug: + print(f"Failed check3 exception = {e}") + pass + + if debug: + print(f"{tmp_result=} @b") + + try: + all_ints = all( + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output, in_outs["outputs"][index])) + if not all_ints: + if debug: + print([ + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output, in_outs["outputs"][index]) + ]) + output_float = [float(e) for e in output] + gt_float = [float(e) for e in in_outs["outputs"][index]] + tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and + np.allclose(output_float, gt_float)) + except Exception as e: + pass + + if debug: + print(f"{tmp_result=} @c") + + try: + if isinstance(output[0], list): + all_ints = all( + combined_int_check(e1) and combined_int_check(e2) + for e1, e2 in zip(output[0], in_outs["outputs"][index])) + if not all_ints: + output_float = [float(e) for e in output[0]] + gt_float = [float(e) for e in in_outs["outputs"][index][0]] + tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and + np.allclose(output_float, gt_float)) + except Exception as e: + pass + + if tmp_result == True: + results.append(tmp_result) + continue + + if debug: + print(f"{tmp_result=} @d") + # try by converting the stuff into split up list + if isinstance(in_outs["outputs"][index], list): + for tmp_index, i in enumerate(in_outs["outputs"][index]): + in_outs["outputs"][index][tmp_index] = set(i.split()) + else: + in_outs["outputs"][index] = set(in_outs["outputs"][index].split()) + + if debug: + print(f"{tmp_result=} @e") + + try: + tmp_result = output == in_outs["outputs"][index] + except Exception as e: + if debug: + print(f"Failed check4 exception = {e}") + continue + + if tmp_result == True: + results.append(tmp_result) + continue + + if debug: + print(f"{tmp_result=} @f") + + # try by converting the output into a split up list too + if isinstance(output, list): + for tmp_index, i in enumerate(output): + output[tmp_index] = i.split() + output = list(filter(len, output)) + for tmp_index, i in enumerate(output): + output[tmp_index] = set(i) + else: + output = output.split() + output = list(filter(len, output)) + output = set(output) + + if debug: + print(f"{tmp_result=} @g") + + if tmp_result == True and debug: + print("PASSED") + + results.append(tmp_result) + if tmp_result != True: + return results, { + "output": raw_true_output_copy, + "expected": raw_outputs, + "inputs": raw_inputs, + # "error_code": -2, + "error_message": "Wrong Answer", + } + + if debug: + nl = "\n" + if not isinstance(inputs, list): + print( + f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + else: + print( + f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + + print(f"results = {results}") + + return results, {} + + +def custom_compare_(output, ground_truth): + + if isinstance(output, list): + output_1 = "\n".join(output) + if stripped_string_compare(output_1, ground_truth): + return True + + if isinstance(output, list): + output_2 = [o.lstrip().rstrip() for o in output] + output_2 = "\n".join(output_2) + if stripped_string_compare(output_2, ground_truth): + return True + + return False + + +def stripped_string_compare(s1, s2): + s1 = s1.lstrip().rstrip() + s2 = s2.lstrip().rstrip() + return s1 == s2 + + +def call_method(method, inputs): + + if isinstance(inputs, list): + inputs = "\n".join(inputs) + + inputs_line_iterator = iter(inputs.split("\n")) + + # sys.setrecursionlimit(10000) + + # @patch('builtins.input', side_effect=inputs.split("\n")) + @patch("builtins.open", mock_open(read_data=inputs)) + @patch("sys.stdin", StringIO(inputs)) + @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator)) + @patch("sys.stdin.readlines", lambda *args: inputs.split("\n")) + @patch("sys.stdin.read", lambda *args: inputs) + # @patch('sys.stdout.write', print) + def _inner_call_method(_method): + try: + return _method() + except SystemExit as e: + pass + finally: + pass + + return _inner_call_method(method) + + +def reliability_guard(maximum_memory_bytes=None): + """ + This disables various destructive functions and prevents the generated code + from interfering with the test (e.g. fork bomb, killing other processes, + removing filesystem files, etc.) + WARNING + This function is NOT a security sandbox. Untrusted code, including, model- + generated code, should not be blindly executed outside of one. See the + Codex paper for more information about OpenAI's code sandbox, and proceed + with caution. + """ + + if maximum_memory_bytes is not None: + import resource + + resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) + if not platform.uname().system == "Darwin": + resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + + faulthandler.disable() + + import builtins + + builtins.exit = None + builtins.quit = None + + import os + + os.environ["OMP_NUM_THREADS"] = "1" + + os.kill = None + os.system = None # 防止干扰repl评测 + os.putenv = None + os.remove = None + os.removedirs = None + os.rmdir = None + os.fchdir = None + os.setuid = None + os.fork = None + os.forkpty = None + os.killpg = None + os.rename = None + os.renames = None + os.truncate = None + os.replace = None + os.unlink = None + os.fchmod = None + os.fchown = None + os.chmod = None + os.chown = None + os.chroot = None + os.fchdir = None + os.lchflags = None + os.lchmod = None + os.lchown = None + os.getcwd = None + os.chdir = None + + import shutil + + shutil.rmtree = None + shutil.move = None + shutil.chown = None + + import subprocess + + subprocess.Popen = None # type: ignore + + __builtins__["help"] = None + + import sys + + sys.modules["ipdb"] = None + sys.modules["joblib"] = None + sys.modules["resource"] = None + sys.modules["psutil"] = None + sys.modules["tkinter"] = None diff --git a/verl/utils/reward_score/prime_code/utils.py b/verl/utils/reward_score/prime_code/utils.py new file mode 100644 index 00000000..07f3db88 --- /dev/null +++ b/verl/utils/reward_score/prime_code/utils.py @@ -0,0 +1,59 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py + +import multiprocessing +from typing import Dict, Optional +from datasets import load_dataset +from .testing_util import run_test +import traceback +import os, sys + + +def _temp_run(sample, generation, debug, result, metadata_list, timeout): + with open(os.devnull, 'w') as devnull: + sys.stdout = devnull + sys.stderr = devnull + try: + res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) + result.append(res) + metadata_list.append(metadata) + except Exception as e: + # print(e) # some tracebacks are extremely long. + traceback.print_exc(10) + result.append([-1 for i in range(len(sample['inputs']))]) + metadata_list.append({}) + + +def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True): + """Check correctness of code generation with a global timeout. + The global timeout is to catch some extreme/rare cases not handled by the timeouts + inside `run_test`""" + + manager = multiprocessing.Manager() + result = manager.list() + metadata_list = manager.list() + p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout)) + p.start() + p.join(timeout=timeout + 1) + if p.is_alive(): + p.kill() + # p.terminate() + if not result: + # consider that all tests failed + result = [[-1 for i in range(len(in_outs["inputs"]))]] + if debug: + print(f"global timeout") + return result[0], metadata_list diff --git a/verl/utils/reward_score/prime_math/__init__.py b/verl/utils/reward_score/prime_math/__init__.py new file mode 100644 index 00000000..39b2c831 --- /dev/null +++ b/verl/utils/reward_score/prime_math/__init__.py @@ -0,0 +1,402 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Answer checker API that uses sympy to simplify expressions and check for equality. + +Call grade_answer(given_answer: str, ground_truth: str). + +FROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py +""" +import re +import sympy +from pylatexenc import latex2text +from sympy.parsing import sympy_parser + +from . import math_normalize +from .grader import math_equal + +# import math_normalize +# from grader import math_equal + +# sympy might hang -- we don't care about trying to be lenient in these cases +BAD_SUBSTRINGS = ["^{", "^("] +BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] +TUPLE_CHARS = "()[]" + + +def _sympy_parse(expr: str): + """Parses an expression with sympy.""" + py_expr = expr.replace("^", "**") + return sympy_parser.parse_expr( + py_expr, + transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), + ) + + +def _parse_latex(expr: str) -> str: + """Attempts to parse latex to an expression sympy can read.""" + expr = expr.replace("\\tfrac", "\\frac") + expr = expr.replace("\\dfrac", "\\frac") + expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. + expr = latex2text.LatexNodes2Text().latex_to_text(expr) + + # Replace the specific characters that this parser uses. + expr = expr.replace("√", "sqrt") + expr = expr.replace("π", "pi") + expr = expr.replace("∞", "inf") + expr = expr.replace("∪", "U") + expr = expr.replace("·", "*") + expr = expr.replace("×", "*") + + return expr.strip() + + +def _is_float(num: str) -> bool: + try: + float(num) + return True + except ValueError: + return False + + +def _is_int(x: float) -> bool: + try: + return abs(x - int(round(x))) <= 1e-7 + except: + return False + + +def _is_frac(expr: str) -> bool: + return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) + + +def _str_is_int(x: str) -> bool: + try: + x = _strip_properly_formatted_commas(x) + x = float(x) + return abs(x - int(round(x))) <= 1e-7 + except: + return False + + +def _str_to_int(x: str) -> bool: + x = x.replace(",", "") + x = float(x) + return int(x) + + +def _inject_implicit_mixed_number(step: str): + """ + Automatically make a mixed number evalable + e.g. 7 3/4 => 7+3/4 + """ + p1 = re.compile("([0-9]) +([0-9])") + step = p1.sub("\\1+\\2", step) ## implicit mults + return step + + +def _strip_properly_formatted_commas(expr: str): + # We want to be careful because we don't want to strip tuple commas + p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") + while True: + next_expr = p1.sub("\\1\\3\\4", expr) + if next_expr == expr: + break + expr = next_expr + return next_expr + + +def _normalize(expr: str) -> str: + """Normalize answer expressions.""" + if expr is None: + return None + + # Remove enclosing `\text{}`. + m = re.search("^\\\\text\{(?P.+?)\}$", expr) + if m is not None: + expr = m.group("text") + + expr = expr.replace("\\%", "%") + expr = expr.replace("\\$", "$") + expr = expr.replace("$", "") + expr = expr.replace("%", "") + expr = expr.replace(" or ", " , ") + expr = expr.replace(" and ", " , ") + + expr = expr.replace("million", "*10^6") + expr = expr.replace("billion", "*10^9") + expr = expr.replace("trillion", "*10^12") + + for unit in [ + "degree", + "cm", + "centimeter", + "meter", + "mile", + "second", + "minute", + "hour", + "day", + "week", + "month", + "year", + "foot", + "feet", + "inch", + "yard", + "liter", + ]: + expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) + expr = re.sub(f"\^ *\\\\circ", "", expr) + + if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": + expr = expr[1:-1] + + expr = re.sub(",\\\\! *", "", expr) + if _is_float(expr) and _is_int(float(expr)): + expr = str(int(round(float(expr)))) + if "\\" in expr: + try: + expr = _parse_latex(expr) + except: + pass + + # edge case with mixed numbers and negative signs + expr = re.sub("- *", "-", expr) + + expr = _inject_implicit_mixed_number(expr) + + # don't be case sensitive for text answers + expr = expr.lower() + + if _str_is_int(expr): + expr = str(_str_to_int(expr)) + + return expr + + +def count_unknown_letters_in_expr(expr: str): + expr = expr.replace("sqrt", "") + expr = expr.replace("frac", "") + letters_in_expr = set([x for x in expr if x.isalpha()]) + return len(letters_in_expr) + + +def should_allow_eval(expr: str): + # we don't want to try parsing unknown text or functions of more than two variables + if count_unknown_letters_in_expr(expr) > 2: + return False + + for bad_string in BAD_SUBSTRINGS: + if bad_string in expr: + return False + + for bad_regex in BAD_REGEXES: + if re.search(bad_regex, expr) is not None: + return False + + return True + + +def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): + are_equal = False + try: + expr = f"({ground_truth_normalized})-({given_normalized})" + if should_allow_eval(expr): + sympy_diff = _sympy_parse(expr) + simplified = sympy.simplify(sympy_diff) + if simplified == 0: + are_equal = True + except: + pass + return are_equal + + +def split_tuple(expr: str): + """ + Split the elements in a tuple/interval, while handling well-formatted commas in large numbers + """ + expr = _strip_properly_formatted_commas(expr) + if len(expr) == 0: + return [] + if (len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and + all([ch not in expr[1:-1] for ch in TUPLE_CHARS])): + elems = [elem.strip() for elem in expr[1:-1].split(",")] + else: + elems = [expr] + return elems + + +def grade_answer(given_answer: str, ground_truth: str) -> bool: + """ + The answer will be considered correct if: + (a) it normalizes to the same string as the ground truth answer + OR + (b) sympy can simplify the difference between the expressions to 0 + """ + if given_answer is None: + return False + + ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth) + given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer) + + # be at least as lenient as mathd + if ground_truth_normalized_mathd == given_answer_normalized_mathd: + return True + + ground_truth_normalized = _normalize(ground_truth) + given_normalized = _normalize(given_answer) + + if ground_truth_normalized is None: + return False + + if ground_truth_normalized == given_normalized: + return True + + if len(given_normalized) == 0: + return False + + ground_truth_elems = split_tuple(ground_truth_normalized) + given_elems = split_tuple(given_normalized) + + if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or + ground_truth_normalized[-1] != given_normalized[-1]): + is_correct = False + elif len(ground_truth_elems) != len(given_elems): + is_correct = False + else: + for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): + if _is_frac(ground_truth_elem) and _is_frac(given_elem): + # if fractions aren't reduced, then shouldn't be marked as correct + # so, we don't want to allow sympy.simplify in this case + is_correct = ground_truth_elem == given_elem + elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): + # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) + is_correct = False + else: + is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + if not is_correct: + break + + return is_correct + + +def remove_boxed(s): + left = "\\boxed{" + try: + assert s[:len(left)] == left + assert s[-1] == "}" + return s[len(left):-1] + except: + return None + + +def _last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + left_brace_idx = None + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if left_brace_idx is None: + left_brace_idx = i + elif string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + + i += 1 + + if left_brace_idx is None or right_brace_idx is None: + return None + + return string[left_brace_idx + 1:right_brace_idx].strip() + + +def match_answer(response): + is_matched = False + for ans_marker in ['answer:', "answer is", "answers are"]: + ans_idx = response.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + response = response[ans_idx + len(ans_marker):].strip() + if response.endswith("\n"): + response = response[:-2] + + for ans_marker in ["is answer", "is the answer", "are answers", "are the answers"]: + ans_idx = response.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + response = response[:ans_idx].strip() + if response.endswith("\n"): + response = response[:-2] + + # Find boxed + ans_boxed = _last_boxed_only_string(response) + if ans_boxed: + is_matched = True + response = ans_boxed + + if ". " in response: + dot_idx = response.lower().rfind(". ") + if dot_idx != -1: + response = response[:dot_idx].strip() + + for ans_marker in ['be ', "is ", "are ", "=", ": ", "get ", 'be\n', "is\n", "are\n", ":\n", "get\n"]: + ans_idx = response.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + response = response[ans_idx + len(ans_marker):].strip() + if response.endswith("\n"): + response = response[:-2] + + is_matched = is_matched if any([c.isdigit() for c in response]) else False # answer must have a digit + # Grade + return is_matched, response + + +import math + + +def compute_score(model_output: str, ground_truth: str) -> bool: + model_output = str(model_output) + ground_truth = str(ground_truth) + + is_matched, extracted_model_output = match_answer(model_output) + format_correctness = "Step 2:" in model_output and "\\box" in model_output + + # grade simple algebra questions. if succeeded, return; otherwise, proceed to more complex grading + if grade_answer(extracted_model_output, ground_truth): + return True, True, extracted_model_output + + try: + if "\pi" in extracted_model_output or "\pi" in ground_truth: + equivs = [] + for pi in [math.pi, 3.14]: + equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)) + is_correct = any(equivs) + else: + is_correct = math_equal(extracted_model_output, ground_truth, timeout=True) + except: + is_correct = False + + return is_correct, format_correctness, extracted_model_output diff --git a/verl/utils/reward_score/prime_math/grader.py b/verl/utils/reward_score/prime_math/grader.py new file mode 100644 index 00000000..e3e1686f --- /dev/null +++ b/verl/utils/reward_score/prime_math/grader.py @@ -0,0 +1,380 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Microsoft Corporation. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE + +# Copyright (c) 2023 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright (c) 2021 Dan Hendrycks +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: +- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py +- https://github.com/microsoft/ProphetNet/tree/master/CRITIC +- https://github.com/openai/prm800k +""" + +import contextlib +import re +import signal +import math +from math import isclose +from typing import Union + +from sympy import N, simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr + + +def is_digit(s): + try: + if "{,}" in str(s): + num = float(str(s).replace("{,}", "")) + return True, num + + num = float(str(s).replace(",", "")) + return True, num + except ValueError: + return False, None + + +def normalize(answer, pi) -> str: + # checking if answer is $ and removing $ in that case to compare + if isinstance(answer, str) and bool(re.match(r'\$\d+(\.\d+)?', answer)): + return answer[1:] + + # checking if answer is % or \\% and removing % + if isinstance(answer, str) and (bool(re.match(r'^\d+(\.\d+)?%$', answer)) or + bool(re.match(r'^\d+(\.\d+)?\\%$', answer))): + return answer.replace("\\%", "").replace("%", "") + + # handle base + answer = handle_base(answer) + + # handle pi + answer = handle_pi(answer, pi) + + return answer + + +def handle_base(x) -> str: + if isinstance(x, str) and "_" in x: + # Due to base + x = x.split("_")[0] + x = float(x) + return int(x) + return x + + +def handle_pi(string, pi): + if isinstance(string, str) and "\pi" in string: + # Find the first occurrence of "\pi" + idx = string.find("\pi") + + # Iterate over the string and find all occurrences of "\pi" with a valid previous character + while idx != -1: + + if idx > 0 and string[idx - 1].isdigit(): + # Replace "\pi" with "*math.pi" if the previous character is a digit + string = string[:idx] + f"*{pi}" + string[idx + 3:] + else: + # Replace "\pi" with "1*math.pi" if the previous character is not a digit + string = string[:idx] + f"1*{pi}" + string[idx + 3:] + + # Find the next occurrence of "\pi" + idx = string.find("\pi", idx + 1) + + # Evaluate the expression using eval() function + try: + string = eval(string) + except: + pass + + return string + + +def math_equal(prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + tolerance: float = 1e-4, + timeout: float = 10.0, + pi: float = math.pi) -> bool: + """ + Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + + prediction = normalize(prediction, pi) + reference = normalize(reference, pi) + + if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases + prediction = prediction[:1000] + + # 0. string comparison + if isinstance(prediction, str) and isinstance(reference, str): + if prediction.strip().lower() == reference.strip().lower(): + return True + if prediction.replace(" ", "") == reference.replace(" ", ""): + return True + + try: # 1. numerical equal + if is_digit(prediction)[0] and is_digit(reference)[0]: + prediction = is_digit(prediction)[1] + reference = is_digit(reference)[1] + # number questions + if include_percentage: + gt_result = [reference / 100, reference, reference * 100] + else: + gt_result = [reference] + for item in gt_result: + try: + if isclose(item, prediction, rel_tol=tolerance): + return True + except Exception: + continue + return False + except Exception: + pass + + if not prediction and prediction not in [0, False]: + return False + + # 2. symbolic equal + reference = str(reference).strip() + prediction = str(prediction).strip() + + ## deal with [], (), {} + prediction = format_intervals(prediction) + + pred_str, ref_str = prediction, reference + if (prediction.startswith("[") and prediction.endswith("]") and + not reference.startswith("(")) or (prediction.startswith("(") and prediction.endswith(")") and + not reference.startswith("[")): + pred_str = pred_str.strip("[]()") + ref_str = ref_str.strip("[]()") + for s in ["{", "}", "(", ")"]: + ref_str = ref_str.replace(s, "") + pred_str = pred_str.replace(s, "") + if pred_str == ref_str: + return True + + ## [a, b] vs. [c, d], return a==c and b==d + if (prediction and reference and prediction[0] in "([" and prediction[-1] in ")]" and + prediction[0] == reference[0] and prediction[-1] == reference[-1]): + pred_parts = prediction[1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts): + if all([ + math_equal(pred_pt, ref_pt, include_percentage, tolerance) + for pred_pt, ref_pt in zip(pred_parts, ref_parts) + ]): + return True + + if "," in prediction and "," in reference: + pred_parts = [item.strip() for item in prediction.split(",")] + ref_parts = [item.strip() for item in reference.split(",")] + + if len(pred_parts) == len(ref_parts): + if all([ + math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) + for i in range(len(pred_parts)) + ]): + return True + else: + return False + + # if we have point == tuple of values + if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")": + pred_parts = prediction[prediction.find("(") + 1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts): + if all([ + math_equal(pred_pt, ref_pt, include_percentage, tolerance) + for pred_pt, ref_pt in zip(pred_parts, ref_parts) + ]): + return True + + # if reference is a matrix + if "\begin{pmatrix}" in reference and prediction.startswith("Matrix"): + try: + pred_matrix = parse_expr(prediction) + ref_matrix_items = reference.split()[1:-1:2] + if len(pred_matrix) == len(ref_matrix_items): + if all([ + math_equal(pred, ref, include_percentage, tolerance) + for ref, pred in zip(ref_matrix_items, pred_matrix) + ]): + return True + except Exception: + pass + elif "\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"): + if isinstance(eval(prediction), list): + try: + pred_matrix = eval(prediction) + # ref_matrix_items = reference.split()[1:-1:2] + ref_matrix_items = reference.lstrip("\\begin{pmatrix}").lstrip("\begin{pmatrix}").rstrip( + "\\end{pmatrix}").rstrip("\end{pmatrix}") + ref_matrix_items = ref_matrix_items.split("\\") + ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] + if len(pred_matrix) == len(ref_matrix_items): + if all([ + math_equal(pred, ref, include_percentage, tolerance) + for ref, pred in zip(ref_matrix_items, pred_matrix) + ]): + return True + except Exception: + pass + + return symbolic_equal(prediction, reference, tolerance, timeout) + + +def symbolic_equal(a, b, tolerance, timeout=10.0): + + def _parse(s): + for f in [parse_expr, parse_latex]: + try: + with time_limit(timeout): + return f(s) + except Exception: + pass + return s + + a = _parse(a) + b = _parse(b) + + try: + with time_limit(timeout): + if simplify(a - b) == 0: + return True + except Exception: + pass + + try: + with time_limit(timeout): + if isclose(N(a), N(b), rel_tol=tolerance): + return True + except Exception: + pass + return False + + +class TimeoutException(Exception): + pass + + +@contextlib.contextmanager +def time_limit(seconds: float): + + def signal_handler(signum, frame): + raise TimeoutException("Timed out!") + + signal.setitimer(signal.ITIMER_REAL, seconds) + signal.signal(signal.SIGALRM, signal_handler) + try: + yield + finally: + signal.setitimer(signal.ITIMER_REAL, 0) + + +def format_intervals(prediction): + patterns = { + "Interval(": r"^Interval\((.*)\)$", + "Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$", + "Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$", + "Interval.open(": r"^Interval\.open\((.*)\)$", + } + + for key, pattern in patterns.items(): + match = re.match(pattern, prediction) + if match: + inner_content = match.group(1) + + if key == "Interval(": # Intarval(a, b) == [a, b] + return f"[{inner_content}]" + elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b) + return f"[{inner_content})" + elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b] + return f"({inner_content}]" + elif key == "Interval.open(": # Intarval.open(a, b) == (a, b) + return f"({inner_content})" + + return prediction diff --git a/verl/utils/reward_score/prime_math/math_normalize.py b/verl/utils/reward_score/prime_math/math_normalize.py new file mode 100644 index 00000000..e9921a5a --- /dev/null +++ b/verl/utils/reward_score/prime_math/math_normalize.py @@ -0,0 +1,191 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2021 Dan Hendrycks +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +This logic is largely copied from the Hendrycks' MATH release (math_equivalence). + +From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py +""" +import re +from typing import Optional + + +def normalize_answer(answer: Optional[str]) -> Optional[str]: + if answer is None: + return None + answer = answer.strip() + try: + # Remove enclosing `\text{}`. + m = re.search("^\\\\text\{(?P.+?)\}$", answer) + if m is not None: + answer = m.group("text").strip() + return _strip_string(answer) + except: + return answer + + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: + return string + + +def _remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def _fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def _strip_string(string): + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string \ No newline at end of file diff --git a/verl/workers/reward_manager/__init__.py b/verl/workers/reward_manager/__init__.py new file mode 100644 index 00000000..3de46a77 --- /dev/null +++ b/verl/workers/reward_manager/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .naive import NaiveRewardManager +from .prime import PrimeRewardManager \ No newline at end of file diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py new file mode 100644 index 00000000..8cfae4cf --- /dev/null +++ b/verl/workers/reward_manager/naive.py @@ -0,0 +1,76 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from verl import DataProto +from verl.utils.reward_score import _default_compute_score +import torch + + +class NaiveRewardManager: + """The reward manager. + """ + + def __init__(self, tokenizer, num_examine, compute_score=None) -> None: + self.tokenizer = tokenizer + self.num_examine = num_examine # the number of batches of decoded responses to print to the console + self.compute_score = compute_score or _default_compute_score + + def __call__(self, data: DataProto): + """We will expand this function gradually based on the available datasets""" + + # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn + if 'rm_scores' in data.batch.keys(): + return data.batch['rm_scores'] + + reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + + already_print_data_sources = {} + + for i in range(len(data)): + data_item = data[i] # DataProtoItem + + prompt_ids = data_item.batch['prompts'] + + prompt_length = prompt_ids.shape[-1] + + valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + + response_ids = data_item.batch['responses'] + valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + sequences = torch.cat((valid_prompt_ids, valid_response_ids)) + sequences_str = self.tokenizer.decode(sequences) + + ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] + + data_source = data_item.non_tensor_batch['data_source'] + + score = self.compute_score( + data_source=data_source, + solution_str=sequences_str, + ground_truth=ground_truth, + ) + reward_tensor[i, valid_response_length - 1] = score + + if data_source not in already_print_data_sources: + already_print_data_sources[data_source] = 0 + + if already_print_data_sources[data_source] < self.num_examine: + already_print_data_sources[data_source] += 1 + print(sequences_str) + + return reward_tensor diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py new file mode 100644 index 00000000..f2f8856f --- /dev/null +++ b/verl/workers/reward_manager/prime.py @@ -0,0 +1,134 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from concurrent.futures import ProcessPoolExecutor +from functools import partial + +import torch + +from verl import DataProto +from verl.utils.reward_score import _default_compute_score + + +async def single_compute_score(evaluation_func, completion, reference, task, executor, timeout=300.): + loop = asyncio.get_running_loop() + try: + # Ensure process_completion is called properly + tasks = [ + asyncio.wait_for( + loop.run_in_executor( + executor, + partial(evaluation_func, task, completion, reference) # Ensure synchronous + ), + timeout=timeout) + ] + return await asyncio.gather(*tasks) + except asyncio.TimeoutError: + print(f"Timeout occurred for completion: {completion}") + return None # Default value for timed-out rows + except Exception as e: + print(f"Error processing completion: {completion[:10]}, Error: {e}") + return None # Default value for failed rows + + +async def parallel_compute_score_async(evaluation_func, completions, references, tasks, num_processes=64): + scores = [] + with ProcessPoolExecutor(max_workers=num_processes) as executor: + # Create tasks for all rows + tasks_async = [ + single_compute_score(evaluation_func, completion, reference, task, executor, timeout=300.) + for completion, reference, task in zip(completions, references, tasks) + ] + # to prevent very occasional starvation caused by some anomalous programs ( like infinite loop ), the exceptions in async programs will instantly halt the evaluation, and all summoned processes will be killed. + try: + results = await asyncio.gather(*tasks_async, return_exceptions=False) + except: + for pid, proc in executor._processes.items(): + try: + proc.kill() + except Exception as kill_err: + print('shut down failed: ' + str(kill_err)) + raise + + # Process results + for result, completion, reference, task in zip(results, completions, references, tasks): + if isinstance(result, Exception) or result is None: + # Handle failed or timed-out tasks + scores.append(0.0) + elif isinstance(result[0], (int, float, bool)): + scores.append(float(result[0])) + else: + scores.append(float(result[0][0])) + return scores + + +class PrimeRewardManager: + """ + The Reward Manager used in https://github.com/PRIME-RL/PRIME + """ + + def __init__(self, tokenizer, num_examine, compute_score=None) -> None: + self.tokenizer = tokenizer + self.num_examine = num_examine # the number of batches of decoded responses to print to the console + self.compute_score = compute_score or _default_compute_score + + def __call__(self, data: DataProto): + """We will expand this function gradually based on the available datasets""" + + # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn + if 'rm_scores' in data.batch.keys(): + return data.batch['rm_scores'] + + reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + + already_print_data_sources = {} + + # batched scoring + prompt_ids = data.batch['prompts'] + prompt_length = prompt_ids.shape[-1] + + response_ids = data.batch['responses'] + valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(dim=-1) + sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) + ground_truth = [data_item.non_tensor_batch['reward_model']['ground_truth'] for data_item in data] + data_sources = data.non_tensor_batch['data_source'] + + assert len(sequences_str) == len(ground_truth) == len(data_sources) + try: + scores = asyncio.run( + parallel_compute_score_async(self.compute_score, + sequences_str, + ground_truth, + data_sources, + num_processes=64)) + except asyncio.TimeoutError as e: + print('Global timeout in reward computing! Setting all as 0.') + scores = [0. for _ in range(len(sequences_str))] + except Exception as e: + print(f"Unexpected error in batched reward computing. Setting all as 0.: {e}") + scores = [0. for _ in range(len(sequences_str))] + + for i in range(len(data)): + data_source = data_sources[i] + reward_tensor[i, valid_response_length[i].item() - 1] = scores[i] + + if data_source not in already_print_data_sources: + already_print_data_sources[data_source] = 0 + + if already_print_data_sources[data_source] < self.num_examine: + already_print_data_sources[data_source] += 1 + print(sequences_str) + + return reward_tensor