Skywork-o1-Open-PRM-Qwen-2.5 PRMs (#37)
* Add Qwen2.5-1.5B-Instruct recipes

* Add Skywork/Skywork-o1-Open-PRM-Qwen-2.5 PRMs

* Fixed style for Skywork PRM
ShayekhBinIslam authored Feb 25, 2025
1 parent 1f4604d commit 1cc70e7
Showing 5 changed files with 1,077 additions and 0 deletions.
17 changes: 17 additions & 0 deletions recipes/README.md
@@ -73,6 +73,23 @@ python scripts/test_time_compute.py $CONFIG \
--dataset_split=train
```

Moreover, to override the choice of PRM, pass it as a command-line argument as follows:

```shell
# Define variables
export CONFIG=recipes/Qwen2.5-1.5B-Instruct/best_of_n.yaml
export PRM=Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B

# Run test-time compute
python scripts/test_time_compute.py $CONFIG --prm_path=$PRM
```

> Currently supported PRMs: <br>
`RLHFlow/Llama3.1-8B-PRM-Deepseek-Data` (default) <br>
`peiyi9979/math-shepherd-mistral-7b-prm`<br>
`Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B`<br>
`Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B`
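
The same selection can also be made programmatically. The sketch below is a minimal illustration, assuming `Config` can be constructed with `prm_path` as a keyword argument and defaults for its other fields; it mirrors the `load_prm` dispatch added to `src/sal/models/reward_models.py` in this commit.

```python
from sal.config import Config
from sal.models.reward_models import load_prm

# Illustrative sketch: select the Skywork 1.5B PRM by its Hub id
# (assumes Config accepts prm_path as a keyword argument).
config = Config(prm_path="Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B")
prm = load_prm(config)

# score() takes a list of questions and, for each question, a list of
# candidate answers; it returns per-step reward scores for every answer.
step_scores = prm.score(
    questions=["What is 1 + 1?"],
    outputs=[["Step 1: 1 + 1 = 2.\nThe answer is 2."]],
)
```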

## Replicating the blog post results

To replicate the results from our blog post, there are two main steps:
75 changes: 75 additions & 0 deletions src/sal/models/reward_models.py
@@ -24,6 +24,12 @@
)

from sal.config import Config
from sal.models.skywork_o1_prm.io_utils import (
    derive_step_rewards,
    prepare_batch_input_for_model,
    prepare_input,
)
from sal.models.skywork_o1_prm.prm_model import SkyworkPRMModel

CANDIDATE_TOKENS = [648, 387]
STEP_TAG_ID = 12902
@@ -271,11 +277,80 @@ def _score_batched(
        return reshaped_output_scores


class SkyworkO1(PRM):
    @classmethod
    def _load_model_and_tokenizer(
        cls, prm_model_path, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        tokenizer = AutoTokenizer.from_pretrained(
            prm_model_path, trust_remote_code=True
        )
        model = SkyworkPRMModel.from_pretrained(
            prm_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            **model_kwargs,
        ).eval()

        return model, tokenizer

    def score(
        self, questions: list[str], outputs: list[list[str]]
    ) -> list[list[float]]:
        # reference code: https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B#huggingface-inference
        all_scores = []
        for question, answers in zip(questions, outputs):
            processed_data = [
                prepare_input(
                    question, answer, tokenizer=self.tokenizer, step_token="\n"
                )
                for answer in answers
            ]
            input_ids, steps, reward_flags = zip(*processed_data)
            input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
                input_ids, reward_flags, self.tokenizer.pad_token_id
            )
            device = self.model.pretrained_model.device
            with torch.no_grad():
                _, _, rewards = self.model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    return_probs=True,
                )
                all_step_scores = derive_step_rewards(
                    rewards.detach().to("cpu", dtype=torch.float32), reward_flags
                )
            all_scores.append(all_step_scores)
        return all_scores


class SkyworkO1_1_5B(SkyworkO1):
    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        prm_model_path = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
        return SkyworkO1._load_model_and_tokenizer(prm_model_path, **model_kwargs)


class SkyworkO1_7B(SkyworkO1):
    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        prm_model_path = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B"
        return SkyworkO1._load_model_and_tokenizer(prm_model_path, **model_kwargs)


def load_prm(config: Config) -> PRM:
    if config.prm_path == "peiyi9979/math-shepherd-mistral-7b-prm":
        return MathShepherd(config)

    if config.prm_path == "RLHFlow/Llama3.1-8B-PRM-Deepseek-Data":
        return RLHFFlow(config)

    if config.prm_path == "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B":
        return SkyworkO1_1_5B(config)

    if config.prm_path == "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B":
        return SkyworkO1_7B(config)

    raise NotImplementedError(f"PRM {config.prm_path} not implemented")
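
For clarity on the structure returned by `SkyworkO1.score` above, a small illustrative sketch (the numbers are made up): the result nests one entry per question, one list per candidate answer, and one probability per step.

```python
# Illustrative only: made-up output of SkyworkO1.score for one question
# with two candidate answers.
scores = [
    [
        [0.91, 0.84, 0.77],  # answer 1: one reward per "\n"-delimited step
        [0.55, 0.12],        # answer 2
    ]
]
```
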
56 changes: 56 additions & 0 deletions src/sal/models/skywork_o1_prm/io_utils.py
@@ -0,0 +1,56 @@
# Source: https://github.com/SkyworkAI/skywork-o1-prm-inference
import numpy as np
import torch


def prepare_input(problem, response, tokenizer, step_token):
    prompt_ids = tokenizer.encode(tokenizer.bos_token + problem + "\n")
    response_ids = []
    steps = []
    reward_flags = [0] * len(prompt_ids)
    step_token_id = tokenizer.encode(step_token)[-1]
    for idx, step in enumerate(response.split(step_token)):
        if step != "":
            step_ids = tokenizer.encode(step)
        else:
            step_ids = []
        step_ids += [step_token_id]
        step = step + step_token
        flag = [0] * len(step_ids)
        flag[-1] = 1
        response_ids.extend(step_ids)
        reward_flags.extend(flag)
        steps.append(step)
    input_ids = prompt_ids + response_ids
    return input_ids, steps, reward_flags


def prepare_batch_input_for_model(input_ids, reward_flags, pad_token_id):
    padded_input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.LongTensor(ids) for ids in input_ids],
        batch_first=True,
        padding_value=pad_token_id,
    )
    padded_attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.LongTensor([1] * len(ids)) for ids in input_ids],
        batch_first=True,
        padding_value=0,
    )
    padded_reward_flags = torch.nn.utils.rnn.pad_sequence(
        [torch.LongTensor(reward_flag) for reward_flag in reward_flags],
        batch_first=True,
        padding_value=0,
    )
    return padded_input_ids, padded_attention_mask, padded_reward_flags


def derive_step_rewards(rewards, reward_flags):
    batch_size = rewards.shape[0]
    batch_step_rewards = []
    for i in range(batch_size):
        rewards_indices = torch.nonzero(reward_flags[i] == 1).view(-1)
        step_rewards = [
            rewards[i][rewards_indices[j]].item() for j in range(len(rewards_indices))
        ]
        batch_step_rewards.append(step_rewards)
    return batch_step_rewards
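
For intuition about the flag convention above (a 1 marks the final token of each step), here is a minimal synthetic example of `derive_step_rewards` with made-up reward probabilities:

```python
import torch

from sal.models.skywork_o1_prm.io_utils import derive_step_rewards

# Made-up per-token reward probabilities for a batch of one sequence (7 tokens).
rewards = torch.tensor([[0.1, 0.2, 0.9, 0.3, 0.4, 0.8, 0.7]])
# A 1 marks the last token of each step (here, steps end at positions 2 and 5).
reward_flags = torch.tensor([[0, 0, 1, 0, 0, 1, 0]])

print(derive_step_rewards(rewards, reward_flags))
# -> approximately [[0.9, 0.8]]: one score per step for the single sequence.
```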