Commit 8e6c9a4 (1 parent: de282dd)
Showing 3 changed files with 174 additions and 0 deletions.
applications/ColossalChat/coati/distributed/reward/reward_fn.py (51 additions, 0 deletions)
@@ -0,0 +1,51 @@
import torch

from .reward_utils import extract_solution, validate_response_structure


def math_reward_fn(input_ids, **kwargs):
    # Apply verifiable reward:
    # reward 10 points if the final answer is correct, reward 1 point if the format is correct.
    gt_answer = kwargs["gt_answer"]
    tokenizer = kwargs["tokenizer"]
    s, e = kwargs["response_start"], kwargs["response_end"]
    reward = torch.tensor(0.0).to(input_ids.device)
    if gt_answer is None:
        return reward
    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
    final_answer, processed_str = extract_solution(decoded_final_answer)

    format_valid = validate_response_structure(processed_str, kwargs["tags"])
    if not format_valid:
        return reward
    else:
        reward += 1.0
        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
            reward = reward + 9.0
        return reward


def gsm8k_reward_fn(input_ids, **kwargs):
    gt_answer = kwargs["gt_answer"]
    tokenizer = kwargs["tokenizer"]
    s, e = kwargs["response_start"], kwargs["response_end"]
    reward = torch.tensor(0.0).to(input_ids.device)
    if gt_answer is None:
        return reward
    decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True)
    final_answer, processed_str = extract_solution(decoded_final_answer)
    # The extracted GSM8K answer must parse as an integer
    is_valid = True
    try:
        int(final_answer.strip())
    except Exception:
        is_valid = False

    format_valid = validate_response_structure(processed_str, kwargs["tags"])
    if not is_valid or not format_valid:
        return reward
    else:
        reward += 1.0
        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
            reward = reward + 9.0
        return reward
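For reference, a minimal usage sketch of math_reward_fn (not part of the commit). The ToyTokenizer, the character-level encoding, and the import path are illustrative assumptions; any tokenizer exposing decode() works the same way:

import torch

from coati.distributed.reward.reward_fn import math_reward_fn  # import path assumed from this commit


class ToyTokenizer:
    """Hypothetical character-level tokenizer, used only to keep the sketch self-contained."""

    def decode(self, ids, skip_special_tokens=True):
        return "".join(chr(i) for i in ids.tolist())


text = "<think>2 + 2 = 4</think><answer>4</answer>"
input_ids = torch.tensor([ord(c) for c in text])

reward = math_reward_fn(
    input_ids,
    tokenizer=ToyTokenizer(),
    gt_answer="4",
    response_start=0,                    # inclusive start of the response span
    response_end=input_ids.size(0) - 1,  # inclusive end of the response span
    tags=None,                           # fall back to the default <think>/<answer> tags
)
print(reward)  # tensor(10.): 1 point for valid format + 9 for the correct answer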
applications/ColossalChat/coati/distributed/reward/reward_utils.py (76 additions, 0 deletions)
@@ -0,0 +1,76 @@
# Copyright Unakar
# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Dict, Optional, Tuple


def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
    """Performs comprehensive validation of response structure.

    Args:
        processed_str: Processed response string from the model
        tags: Mapping from tag name to {"text": ..., "num_occur": ...}; defaults to
            the <think>/<answer> tag set below

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    validation_passed = True
    # Check required tags
    if tags is None:
        tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
    positions = {}
    for tag_name, tag_info in tags.items():
        tag_str = tag_info["text"]
        expected_count = tag_info["num_occur"]
        count = processed_str.count(tag_str)
        positions[tag_name] = processed_str.find(tag_str)
        if count != expected_count:
            validation_passed = False
    # Verify tag order
    if (
        positions["think_start"] > positions["think_end"]
        or positions["think_end"] > positions["answer_start"]
        or positions["answer_start"] > positions["answer_end"]
    ):
        validation_passed = False
    # The response must end right after the closing answer tag (no trailing text)
    if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]):
        validation_passed = False
    return validation_passed


def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
    """Extracts the final answer from the model's response string.

    Args:
        solution_str: Raw response string from the language model

    Returns:
        Tuple containing (extracted_answer, processed_string)
    """
    # Extract final answer using XML-style tags
    answer_pattern = r"<answer>(.*?)</answer>"
    matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))

    if not matches:
        return None, solution_str

    final_answer = matches[-1].group(1).strip()
    return final_answer, solution_str
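A small sketch of how the two helpers compose (illustrative, not part of the commit; the import path is assumed):

from coati.distributed.reward.reward_utils import extract_solution, validate_response_structure

response = "<think>Half of 10 is 5.</think><answer>5</answer>"
final_answer, processed_str = extract_solution(response)
print(final_answer)  # "5"

# tags=None falls back to the default <think>/<answer> tag set
print(validate_response_structure(processed_str, tags=None))  # True

# Any text after </answer> fails the trailing-text check
print(validate_response_structure(response + " extra", tags=None))  # False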
applications/ColossalChat/coati/distributed/reward/verifiable_reward.py (47 additions, 0 deletions)
@@ -0,0 +1,47 @@
"""
Function-based reward verification module.
"""

from typing import Any, Callable, Dict, List

import torch


class VerifiableReward:
    def __init__(self, reward_fn: List[Callable], reward_args: List[Dict[str, Any]]):
        # One static kwargs dict per reward function, matched by position
        self.reward_fn_list = reward_fn
        self.reward_args_list = reward_args

    def __call__(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        response_start: List[int] = None,
        response_end: List[int] = None,
        gt_answer: List[str] = None,
    ) -> torch.Tensor:
        # Get batch size
        bs = input_ids.size(0)
        # Initialize rewards
        rewards = torch.zeros(bs, device=input_ids.device)

        # Loop through reward functions and accumulate their scores
        for reward_fn, reward_args in zip(self.reward_fn_list, self.reward_args_list):
            # Apply the reward function to each sample and stack the per-sample scores
            reward_batch = torch.stack(
                [
                    reward_fn(
                        input_ids[i],
                        # Keyword so reward functions taking (input_ids, **kwargs) still work
                        attention_mask=attention_mask[i],
                        response_start=response_start[i],
                        response_end=response_end[i],
                        gt_answer=gt_answer[i],
                        **reward_args,
                    )
                    for i in range(bs)
                ],
                dim=0,
            )

            rewards += reward_batch
        return rewards
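Finally, a minimal sketch of how VerifiableReward might be wired up with math_reward_fn (illustrative, not part of the commit; the toy tokenizer, batch construction, and import paths are assumptions):

import torch

from coati.distributed.reward.reward_fn import math_reward_fn            # import path assumed
from coati.distributed.reward.verifiable_reward import VerifiableReward  # import path assumed


class ToyTokenizer:
    """Hypothetical character-level tokenizer, used only to keep the sketch self-contained."""

    def decode(self, ids, skip_special_tokens=True):
        return "".join(chr(i) for i in ids.tolist())


texts = [
    "<think>2 + 2 = 4</think><answer>4</answer>",
    "<think>3 * 3 = 9</think><answer>8</answer>",  # wrong final answer, valid format
]
# Both strings have the same length here, so the batch stacks without padding.
input_ids = torch.tensor([[ord(c) for c in t] for t in texts])
attention_mask = torch.ones_like(input_ids)

reward_model = VerifiableReward(
    reward_fn=[math_reward_fn],
    reward_args=[{"tokenizer": ToyTokenizer(), "tags": None}],  # one static kwargs dict per function
)
rewards = reward_model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    response_start=[0, 0],
    response_end=[len(t) - 1 for t in texts],  # inclusive end index per sample
    gt_answer=["4", "9"],
)
print(rewards)  # tensor([10., 1.]): full reward vs. format-only reward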