From 1cc70e7b1c386d2ed60ed5a77b8bd93fa9bdd4fc Mon Sep 17 00:00:00 2001 From: Shayekh Islam Date: Tue, 25 Feb 2025 15:32:36 +0600 Subject: [PATCH] Skywork-o1-Open-PRM-Qwen-2.5 PRMs (#37) * Add Qwen2.5-1.5B-Instruct recipes * Add Skywork/Skywork-o1-Open-PRM-Qwen-2.5 PRMs * Fixed style for Skywork PRM --- recipes/README.md | 17 + src/sal/models/reward_models.py | 75 ++ src/sal/models/skywork_o1_prm/io_utils.py | 56 ++ .../models/skywork_o1_prm/modeling_base.py | 669 ++++++++++++++++++ src/sal/models/skywork_o1_prm/prm_model.py | 260 +++++++ 5 files changed, 1077 insertions(+) create mode 100644 src/sal/models/skywork_o1_prm/io_utils.py create mode 100644 src/sal/models/skywork_o1_prm/modeling_base.py create mode 100644 src/sal/models/skywork_o1_prm/prm_model.py diff --git a/recipes/README.md b/recipes/README.md index 26154f1..772b8a8 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -73,6 +73,23 @@ python scripts/test_time_compute.py $CONFIG \ --dataset_split=train ``` +Moreover, to override the choice of PRM, include it in the command line arguments as follows: + +```shell +# Define variables +export CONFIG=recipes/Qwen2.5-1.5B-Instruct/best_of_n.yaml +export PRM=Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B + +# Run test-time compute +python scripts/test_time_compute.py $CONFIG --prm_path=$PRM +``` + +> Currently supported PRMs:
+`RLHFlow/Llama3.1-8B-PRM-Deepseek-Data` (default)
+`peiyi9979/math-shepherd-mistral-7b-prm`
+`Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B`
+`Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B` + ## Replicating the blog post results To replicate the results from our blog post, there are two main steps: diff --git a/src/sal/models/reward_models.py b/src/sal/models/reward_models.py index ad11b89..7b1e1e5 100644 --- a/src/sal/models/reward_models.py +++ b/src/sal/models/reward_models.py @@ -24,6 +24,12 @@ ) from sal.config import Config +from sal.models.skywork_o1_prm.io_utils import ( + derive_step_rewards, + prepare_batch_input_for_model, + prepare_input, +) +from sal.models.skywork_o1_prm.prm_model import SkyworkPRMModel CANDIDATE_TOKENS = [648, 387] STEP_TAG_ID = 12902 @@ -271,6 +277,69 @@ def _score_batched( return reshaped_output_scores +class SkyworkO1(PRM): + @classmethod + def _load_model_and_tokenizer( + cls, prm_model_path, **model_kwargs + ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: + tokenizer = AutoTokenizer.from_pretrained( + prm_model_path, trust_remote_code=True + ) + model = SkyworkPRMModel.from_pretrained( + prm_model_path, + device_map="auto", + torch_dtype=torch.bfloat16, + **model_kwargs, + ).eval() + + return model, tokenizer + + def score( + self, questions: list[str], outputs: list[list[str]] + ) -> list[list[float]]: + # reference code: https://huggingface.co/Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B#huggingface-inference + all_scores = [] + for question, answers in zip(questions, outputs): + processed_data = [ + prepare_input( + question, answer, tokenizer=self.tokenizer, step_token="\n" + ) + for answer in answers + ] + input_ids, steps, reward_flags = zip(*processed_data) + input_ids, attention_mask, reward_flags = prepare_batch_input_for_model( + input_ids, reward_flags, self.tokenizer.pad_token_id + ) + device = self.model.pretrained_model.device + with torch.no_grad(): + _, _, rewards = self.model( + input_ids=input_ids.to(device), + attention_mask=attention_mask.to(device), + return_probs=True, + ) + all_step_scores = derive_step_rewards( + rewards.detach().to("cpu", dtype=torch.float32), reward_flags + ) + all_scores.append(all_step_scores) + return all_scores + + +class SkyworkO1_1_5B(SkyworkO1): + def load_model_and_tokenizer( + self, **model_kwargs + ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: + prm_model_path = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B" + return SkyworkO1._load_model_and_tokenizer(prm_model_path, **model_kwargs) + + +class SkyworkO1_7B(SkyworkO1): + def load_model_and_tokenizer( + self, **model_kwargs + ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: + prm_model_path = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B" + return SkyworkO1._load_model_and_tokenizer(prm_model_path, **model_kwargs) + + def load_prm(config: Config) -> PRM: if config.prm_path == "peiyi9979/math-shepherd-mistral-7b-prm": return MathShepherd(config) @@ -278,4 +347,10 @@ def load_prm(config: Config) -> PRM: if config.prm_path == "RLHFlow/Llama3.1-8B-PRM-Deepseek-Data": return RLHFFlow(config) + if config.prm_path == "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B": + return SkyworkO1_1_5B(config) + + if config.prm_path == "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B": + return SkyworkO1_7B(config) + raise NotImplementedError(f"PRM {config.prm_path} not implemented") diff --git a/src/sal/models/skywork_o1_prm/io_utils.py b/src/sal/models/skywork_o1_prm/io_utils.py new file mode 100644 index 0000000..134ac76 --- /dev/null +++ b/src/sal/models/skywork_o1_prm/io_utils.py @@ -0,0 +1,56 @@ +# Source: https://github.com/SkyworkAI/skywork-o1-prm-inference +import numpy as np +import torch + + +def prepare_input(problem, 
response, tokenizer, step_token): + prompt_ids = tokenizer.encode(tokenizer.bos_token + problem + "\n") + response_ids = [] + steps = [] + reward_flags = [0] * len(prompt_ids) + step_token_id = tokenizer.encode(step_token)[-1] + for idx, step in enumerate(response.split(step_token)): + if step != "": + step_ids = tokenizer.encode(step) + else: + step_ids = [] + step_ids += [step_token_id] + step = step + step_token + flag = [0] * len(step_ids) + flag[-1] = 1 + response_ids.extend(step_ids) + reward_flags.extend(flag) + steps.append(step) + input_ids = prompt_ids + response_ids + return input_ids, steps, reward_flags + + +def prepare_batch_input_for_model(input_ids, reward_flags, pad_token_id): + padded_input_ids = torch.nn.utils.rnn.pad_sequence( + [torch.LongTensor(ids) for ids in input_ids], + batch_first=True, + padding_value=pad_token_id, + ) + padded_attention_mask = torch.nn.utils.rnn.pad_sequence( + [torch.LongTensor([1] * len(ids)) for ids in input_ids], + batch_first=True, + padding_value=0, + ) + padded_reward_flags = torch.nn.utils.rnn.pad_sequence( + [torch.LongTensor(reward_flag) for reward_flag in reward_flags], + batch_first=True, + padding_value=0, + ) + return padded_input_ids, padded_attention_mask, padded_reward_flags + + +def derive_step_rewards(rewards, reward_flags): + batch_size = rewards.shape[0] + batch_step_rewards = [] + for i in range(batch_size): + rewards_indices = torch.nonzero(reward_flags[i] == 1).view(-1) + step_rewards = [ + rewards[i][rewards_indices[j]].item() for j in range(len(rewards_indices)) + ] + batch_step_rewards.append(step_rewards) + return batch_step_rewards diff --git a/src/sal/models/skywork_o1_prm/modeling_base.py b/src/sal/models/skywork_o1_prm/modeling_base.py new file mode 100644 index 0000000..62e4ad5 --- /dev/null +++ b/src/sal/models/skywork_o1_prm/modeling_base.py @@ -0,0 +1,669 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
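+#
+# Note: this module appears to be adapted from TRL's `PreTrainedModelWrapper`
+# (via the Skywork o1 PRM inference repository linked below). It is vendored
+# here and used by `SkyworkPRMModel` in `prm_model.py` to wrap a causal LM
+# and attach a scalar value head for process reward scoring.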
+# Source: https://github.com/SkyworkAI/skywork-o1-prm-inference +import json +import logging +import os +import sys +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +from accelerate import PartialState +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + HFValidationError, + LocalEntryNotFoundError, + RepositoryNotFoundError, +) +from safetensors.torch import load_file as safe_load_file +from transformers import PreTrainedModel + +if sys.version_info < (3, 8): + _is_python_greater_3_8 = False +else: + _is_python_greater_3_8 = True + + +def is_transformers_greater_than(current_version: str) -> bool: + if _is_python_greater_3_8: + from importlib.metadata import version + + _transformers_version = version("transformers") + else: + import pkg_resources + + _transformers_version = pkg_resources.get_distribution("transformers").version + return _transformers_version > current_version + + +if is_transformers_greater_than("4.33.0"): + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + +LAYER_PATTERNS = [ + "transformer.h.{layer}", + "model.decoder.layers.{layer}", + "gpt_neox.layers.{layer}", + "model.layers.{layer}", +] + + +class PreTrainedModelWrapper(nn.Module): + r""" + A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the + (`~transformers.PreTrained`) class in order to keep some attributes and methods of the + (`~transformers.PreTrainedModel`) class. + + Attributes: + pretrained_model: (`transformers.PreTrainedModel`) + The model to be wrapped. + parent_class: (`transformers.PreTrainedModel`) + The parent class of the model to be wrapped. + supported_args: (`list`) + The list of arguments that are supported by the wrapper class. + """ + + transformers_parent_class = None + supported_args = None + supported_modules = ("v_head",) + supported_rm_modules = ("score",) + supported_pretrained_model_architectures = PreTrainedModel + + def __init__( + self, + pretrained_model=None, + score_module=None, + supports_rm_adapter=False, + rm_adapter_name=None, + **kwargs, + ): + super().__init__() + self.pretrained_model = pretrained_model + + self.config = pretrained_model.config + self.prepare_inputs_for_generation = ( + pretrained_model.prepare_inputs_for_generation + ) + self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False) + self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False) + self.is_sequential_parallel = False + + if hasattr(pretrained_model, "gradient_checkpointing_disable"): + self.gradient_checkpointing_disable = ( + pretrained_model.gradient_checkpointing_disable + ) + + if hasattr(pretrained_model, "gradient_checkpointing_enable"): + self.gradient_checkpointing_enable = ( + pretrained_model.gradient_checkpointing_enable + ) + + if hasattr(pretrained_model, "enable_input_require_grads"): + self.enable_input_require_grads = ( + pretrained_model.enable_input_require_grads + ) + + self.supports_rm_adapter = supports_rm_adapter + self.rm_adapter_name = rm_adapter_name + self.policy_adapter_name = "default" + if score_module is not None: + self.score = score_module + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Instantiates a new model from a pretrained model from `transformers`. 
The + pretrained model is loaded using the `from_pretrained` method of the + `transformers.PreTrainedModel` class. The arguments that are specific to the + `transformers.PreTrainedModel` class are passed along this method and filtered + out from the `kwargs` argument. + + + Args: + pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`): + The path to the pretrained model or its name. + *model_args (`list`, *optional*)): + Additional positional arguments passed along to the underlying model's + `from_pretrained` method. + **kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's + `from_pretrained` method. We also pre-process the kwargs to extract + the arguments that are specific to the `transformers.PreTrainedModel` + class and the arguments that are specific to trl models. The kwargs + also support `prepare_model_for_kbit_training` arguments from + `peft` library. + """ + if kwargs is not None: + peft_config = kwargs.pop("peft_config", None) + reward_adapter = kwargs.pop("reward_adapter", None) + reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter") + is_trainable = kwargs.pop("is_trainable", False) + trl_model_args, pretrained_kwargs, peft_quantization_kwargs = ( + cls._split_kwargs(kwargs) + ) + token = pretrained_kwargs.get("token", None) + else: + peft_config = None + is_trainable = False + trl_model_args = {} + pretrained_kwargs = {} + peft_quantization_kwargs = {} + token = None + + if reward_adapter is not None and not isinstance(reward_adapter, str): + raise ValueError( + "The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter." + ) + + is_peft_model = False + + current_device = cls._get_current_device() + if isinstance(pretrained_model_name_or_path, str): + is_loaded_in_8bit = ( + pretrained_kwargs["load_in_8bit"] + if "load_in_8bit" in pretrained_kwargs + else False + ) + is_loaded_in_4bit = ( + pretrained_kwargs["load_in_4bit"] + if "load_in_4bit" in pretrained_kwargs + else False + ) + else: + is_loaded_in_8bit = getattr( + pretrained_model_name_or_path, "is_loaded_in_8bit", False + ) + is_loaded_in_4bit = getattr( + pretrained_model_name_or_path, "is_loaded_in_4bit", False + ) + + if ( + is_loaded_in_8bit or is_loaded_in_4bit + ) and "device_map" not in pretrained_kwargs: + # warn users + logging.warning( + "The `device_map` argument is not provided. We will override the device_map argument." + " to set the entire" + " model on the current device. If you want to set the model on multiple devices, please provide" + " a custom `device_map` argument." 
+ ) + pretrained_kwargs["device_map"] = {"": current_device} + + # First, load the pre-trained model using the parent-class + # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` + if isinstance(pretrained_model_name_or_path, str): + remote_adapter_config = None + local_adapter_present = os.path.exists( + os.path.join(pretrained_model_name_or_path, "adapter_config.json") + ) + pretrained_model = cls.transformers_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **pretrained_kwargs + ) + + elif isinstance( + pretrained_model_name_or_path, cls.supported_pretrained_model_architectures + ): + pretrained_model = pretrained_model_name_or_path + else: + raise ValueError( + "pretrained_model_name_or_path should be a string or a PreTrainedModel, " + f"but is {type(pretrained_model_name_or_path)}" + ) + + # Add reward modeling adapter if specified + if not is_peft_model and reward_adapter is not None: + raise ValueError("reward_adapter can only be used with a PeftModel. ") + elif is_peft_model and reward_adapter is not None: + score_module = cls.add_and_load_reward_modeling_adapter( + pretrained_model, reward_adapter, reward_adapter_name, token=token + ) + multi_adapter_args = { + "score_module": score_module, + "supports_rm_adapter": True, + "rm_adapter_name": reward_adapter_name, + } + else: + multi_adapter_args = {"supports_rm_adapter": False} + + # Then, create the full model by instantiating the wrapper class + model = cls(pretrained_model, **multi_adapter_args, **trl_model_args) + + # if resume_training, load the state_dict again - this is ok since the + # state_dict is removed from the model after loading it. + is_resuming_training = True + if isinstance(pretrained_model_name_or_path, str): + safe_filename = os.path.join( + pretrained_model_name_or_path, "model.safetensors" + ) + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + + sharded_index_filename = os.path.join( + pretrained_model_name_or_path, "pytorch_model.bin.index.json" + ) + safe_sharded_index_filename = os.path.join( + pretrained_model_name_or_path, "model.safetensors.index.json" + ) + is_sharded = False + use_safe = os.path.exists(safe_filename) + + if not (os.path.exists(filename) or os.path.exists(safe_filename)): + # Try with `pytorch_model.bin` + filename, files_to_download, is_sharded, is_resuming_training = ( + cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + sharded_index_filename, + token=token, + ) + ) + # Try with safetensors + if filename is None and files_to_download is None: + ( + safe_filename, + files_to_download, + is_sharded, + is_resuming_training, + ) = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + safe_sharded_index_filename, + token=token, + model_name="model.safetensors", + model_index_name="model.safetensors.index.json", + ) + use_safe = True + else: + use_safe = False + + loading_func = safe_load_file if use_safe else torch.load + load_kwargs = {} if use_safe else {"map_location": "cpu"} + + if is_resuming_training: + if is_sharded: + # download each file and add it to the state_dict + state_dict = {} + + for shard_file in files_to_download: + filename = hf_hub_download( + pretrained_model_name_or_path, + shard_file, + token=token, + ) + state_dict.update(loading_func(filename, **load_kwargs)) + else: + state_dict = loading_func( + filename if not use_safe else safe_filename, **load_kwargs + ) + + else: + state_dict = pretrained_model_name_or_path.state_dict() + + 
model.is_peft_model = is_peft_model + model.current_device = current_device + + if is_resuming_training: + model.post_init(state_dict=state_dict) + + return model + + @classmethod + def _get_checkpoint_from_hub( + cls, + pretrained_model, + pretrained_model_name_or_path, + index_filename, + token=None, + model_name="pytorch_model.bin", + model_index_name="pytorch_model.bin.index.json", + ): + files_to_download = None + filename = None + is_resuming_training = True + is_sharded = False + + try: + filename = hf_hub_download( + pretrained_model_name_or_path, + model_name, + token=token, + ) + # sharded + except ( + EntryNotFoundError, + LocalEntryNotFoundError, + HFValidationError, + RepositoryNotFoundError, + ): + if os.path.exists(index_filename): + index_file_name = index_filename + else: + try: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + model_index_name, + token=token, + ) + except ( + EntryNotFoundError, + LocalEntryNotFoundError, + HFValidationError, + RepositoryNotFoundError, + ): + # not continue training, do not have v_head weight + is_resuming_training = False + logging.warning( + f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " + f"and no v_head weight is found. This IS expected if you are not resuming PPO training." + ) + # load json + if is_resuming_training: + with open(index_file_name) as f: + index = json.load(f) + # check filename with `v_head` or any known extra module: + files_to_download = set() + for k, v in index["weight_map"].items(): + if any(module in k for module in cls.supported_modules): + files_to_download.add(v) + is_sharded = True + + return filename, files_to_download, is_sharded, is_resuming_training + + @classmethod + def _get_current_device(cls): + r""" + Get the current device. For GPU, we return the local process index using the `accelerate.PartialState` + object to handle corner cases when running scripts in distributed environments. + + Returns: + current_device (`Union[int, str]`): + The current device. + """ + state = PartialState() + return state.local_process_index if torch.cuda.is_available() else "cpu" + + @classmethod + def _split_kwargs(cls, kwargs): + """ + Separate the kwargs from the arguments that we support inside + `supported_args` and the ones that we don't. + """ + check_peft_kwargs = False + + supported_kwargs = {} + unsupported_kwargs = {} + peft_kwargs = {} + + for key, value in kwargs.items(): + if key in cls.supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + + if check_peft_kwargs: + if key in prepare_model_for_kbit_training.__code__.co_varnames: + peft_kwargs[key] = value + if key in unsupported_kwargs: + unsupported_kwargs.pop(key) + + return supported_kwargs, unsupported_kwargs, peft_kwargs + + @classmethod + def add_and_load_reward_modeling_adapter( + cls, + pretrained_model, + adapter_model_id, + adapter_name="reward_model_adapter", + token=None, + ): + r""" + Add and load a reward modeling adapter. This method can only be used if the + model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id` + argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the + score head in order to produce the reward. 
+ """ + pretrained_model.load_adapter( + adapter_model_id, adapter_name, is_trainable=False + ) + pretrained_model.train() + + filename = os.path.join(adapter_model_id, "adapter_model.bin") + safe_loading = False + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.bin", + token=token, + ) + except Exception: + filename = os.path.join(adapter_model_id, "adapter_model.safetensors") + safe_loading = True + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.safetensors", + token=token, + ) + except Exception as exc: + raise ValueError( + "Could not find adapter model in the Hub, " + "make sure you have the correct adapter model id." + ) from exc + else: + local_filename = filename + else: + local_filename = filename + + loading_func = safe_load_file if safe_loading else torch.load + load_kwargs = {} if safe_loading else {"map_location": "cpu"} + + adapter_state_dict = loading_func(local_filename, **load_kwargs) + + for score_name_candidate in cls.supported_rm_modules: + if any(score_name_candidate in name for name in adapter_state_dict.keys()): + score_name = score_name_candidate + # we have found the correct head name and can break + break + + score_dict = {} + + for name, param in adapter_state_dict.items(): + if score_name in name: + key_name = ".".join(name.split(".")[-1:]) + score_dict[key_name] = param.to(cls._get_current_device()) + + num_labels, hidden_dim = score_dict["weight"].shape + has_bias = any("bias" in name for name in adapter_state_dict.keys()) + + score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to( + device=cls._get_current_device(), + dtype=pretrained_model.dtype, + ) + score.load_state_dict(score_dict) + for param in score.parameters(): + param.requires_grad = False + + return score + + def push_to_hub(self, *args, **kwargs): + r""" + Push the pretrained model to the hub. This method is a wrapper around + `transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation + of `transformers.PreTrainedModel.push_to_hub` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `push_to_hub` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `push_to_hub` method. + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + r""" + Save the pretrained model to a directory. This method is a wrapper around + `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation + of `transformers.PreTrainedModel.save_pretrained` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `save_pretrained` method. 
+ """ + state_dict = kwargs.get("state_dict") + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + # if it is a peft model only save the `v_head` state_dict and + # pop the `state_dict` from the kwargs to avoid slient bugs with `peft` + if self.is_peft_model: + save_path = args[0] + save_path = os.path.join(save_path, "pytorch_model.bin") + torch.save(state_dict, save_path) + _ = kwargs.pop("state_dict", None) + + return self.pretrained_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Return the state_dict of the pretrained model. + """ + raise NotImplementedError + + def post_init(self, *args, **kwargs): + r""" + Post initialization method. This method is called after the model is + instantiated and loaded from a checkpoint. It can be used to perform + additional operations such as loading the state_dict. + """ + raise NotImplementedError + + def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): + r""" + Computes the reward score for a given input. The method has first to enable the adapter + and then compute the reward score. After that the model disables the reward modeling + adapter and enables the default ppo adapter again. + """ + if not self.supports_rm_adapter: + raise ValueError("This model does not support reward modeling adapter.") + + # enable rm adapter + self.pretrained_model.set_adapter(self.rm_adapter_name) + self.pretrained_model.eval() + + with torch.no_grad(): + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + + last_hidden_states = base_model_output.hidden_states[-1] + scores = self.score(last_hidden_states) + + self.pretrained_model.set_adapter(self.policy_adapter_name) + self.pretrained_model.eval() + + return scores + + +def create_reference_model( + model: PreTrainedModelWrapper, + num_shared_layers: Optional[int] = None, + pattern: Optional[str] = None, +) -> PreTrainedModelWrapper: + """ + Creates a static reference copy of a model. Note that model will be in `.eval()` mode. + + Args: + model (`PreTrainedModelWrapper`): The model to be copied. + num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and kept frozen. + pattern (`str`, *optional*): The shared layers are selected with a string pattern + (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here. + + Returns + `PreTrainedModelWrapper` + """ + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoCausalLM.from_pretrained()`." 
+ ) + + parameter_names = [n for n, _ in model.named_parameters()] + ref_model = deepcopy(model) + + # if no layers are shared, return copy of model + if num_shared_layers is None: + for param_name in parameter_names: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + return ref_model.eval() + + # identify layer name pattern + if pattern is not None: + pattern = pattern.format(layer=num_shared_layers) + else: + for pattern_candidate in LAYER_PATTERNS: + pattern_candidate = pattern_candidate.format(layer=num_shared_layers) + if any(pattern_candidate in name for name in parameter_names): + pattern = pattern_candidate + break + + if pattern is None: + raise ValueError("Layer pattern could not be matched.") + + # divide parameters in shared and unshared parameter lists + shared_param_list = [] + unshared_param_list = [] + + shared_parameter = True + for name, _param in model.named_parameters(): + if pattern in name: + shared_parameter = False + if shared_parameter: + shared_param_list.append(name) + else: + unshared_param_list.append(name) + + # create reference of the original parameter if they are shared + for param_name in shared_param_list: + param = model.get_parameter(param_name) + param.requires_grad = False + + _ref_param = ref_model.get_parameter(param_name) + + # for all other parameters just make sure they don't use gradients + for param_name in unshared_param_list: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + + if pattern is not None and len(unshared_param_list) == 0: + logging.warning( + "Pattern passed or found, but no layers matched in the model. Check for a typo." + ) + + return ref_model.eval() diff --git a/src/sal/models/skywork_o1_prm/prm_model.py b/src/sal/models/skywork_o1_prm/prm_model.py new file mode 100644 index 0000000..907afe8 --- /dev/null +++ b/src/sal/models/skywork_o1_prm/prm_model.py @@ -0,0 +1,260 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Source: https://github.com/SkyworkAI/skywork-o1-prm-inference +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM + +from .modeling_base import PreTrainedModelWrapper + + +class ValueHead(nn.Module): + r""" + The ValueHead class implements a head for GPT2 that returns a scalar for each output token. + """ + + def __init__(self, config, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + + self.dropout = ( + nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + ) + + # some models such as OPT have a projection layer before the word embeddings - e.g. 
OPT-350m + if hasattr(config, "hidden_size"): + hidden_size = config.hidden_size + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + elif hasattr(config, "is_encoder_decoder"): + if config.is_encoder_decoder and hasattr(config, "decoder"): + if hasattr(config.decoder, "hidden_size"): + hidden_size = config.decoder.hidden_size + + self.summary = nn.Linear(hidden_size, 1) + + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + + # For now force upcast in fp32 if needed. Let's keep the + # output in fp32 for numerical stability. + if output.dtype != self.summary.weight.dtype: + output = output.to(self.summary.weight.dtype) + + output = self.summary(output) + return output + + +class SkyworkPRMModel(PreTrainedModelWrapper): + transformers_parent_class = AutoModelForCausalLM + lm_head_namings = ["lm_head", "embed_out"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + r""" + Initializes the model. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForCausalLM` class. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. + """ + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + + if not any( + hasattr(self.pretrained_model, attribute) + for attribute in self.lm_head_namings + ): + raise ValueError( + "The model does not have a language model head, please use a model that has one." + ) + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _init_weights(self, **kwargs): + r""" + Initializes the weights of the value head. The default initialization strategy is random. + Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument + when calling `.from_pretrained`. Supported strategies are: + - `normal`: initializes the weights with a normal distribution. + + Args: + **kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. These arguments + can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range` + argument. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + return_past_key_values=False, + return_probs=False, + **kwargs, + ): + r""" + Applies a forward pass to the wrapped model and returns the logits of the value head. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `past_key_values` input) to speed up sequential decoding. 
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the wrapped model. + """ + kwargs["output_hidden_states"] = ( + True # this had already been set in the LORA / PEFT examples + ) + kwargs["past_key_values"] = past_key_values + + if ( + self.is_peft_model + and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING" + ): + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = base_model_output.hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + if last_hidden_state.device != self.v_head.summary.weight.device: + last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) + + value = self.v_head(last_hidden_state).squeeze(-1) # logits_diff + + if return_probs: + value = torch.nn.functional.sigmoid(value) # convert logits_diff_to_Probs + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + if return_past_key_values: + return (lm_logits, loss, value, base_model_output.past_key_values) + else: + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + A simple wrapper around the `generate` method of the wrapped model. + Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) + method of the wrapped model for more information about the supported arguments. + + Args: + *args (`list`, *optional*): + Positional arguments passed to the `generate` method of the wrapped model. + **kwargs (`dict`, *optional*): + Keyword arguments passed to the `generate` method of the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict( + *args, **kwargs + ) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + self.pretrained_model.v_head = self.v_head + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." 
in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + first_device = list(set(self.pretrained_model.hf_device_map.values()))[0] + if isinstance(first_device, int): + first_device = f"cuda:{first_device}" + self.v_head = self.v_head.to(first_device) + + def set_device_hook(module, input, outputs): + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(first_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + + self.is_sequential_parallel = True
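
For reference, the sketch below shows how the PRMs added in this patch could be driven directly from Python, mirroring what `scripts/test_time_compute.py --prm_path=...` does through `load_prm`. It is a minimal sketch, not part of the patch: it assumes that `Config` (from `sal.config`) can be instantiated with its defaults and its `prm_path` field assigned directly, and that a GPU with enough memory for the chosen checkpoint is available; the example question and candidate solutions are purely illustrative.

```python
# Minimal usage sketch (not part of this patch). Assumes `Config` can be
# constructed with defaults and that assigning `prm_path` is sufficient to
# select the PRM; the question/answers below are illustrative only.
from sal.config import Config
from sal.models.reward_models import load_prm

config = Config()
config.prm_path = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
prm = load_prm(config)  # dispatches to SkyworkO1_1_5B

# One question with two candidate solutions; solution steps are separated by
# "\n", matching the step_token="\n" used inside SkyworkO1.score.
questions = ["Janet has 3 apples and buys 2 more. How many apples does she have?"]
outputs = [
    [
        "Janet starts with 3 apples.\nShe buys 2 more, so 3 + 2 = 5.\nThe answer is 5.",
        "3 + 2 = 6.\nThe answer is 6.",
    ]
]

scores = prm.score(questions, outputs)
# For each question, SkyworkO1.score appends one list of per-step rewards per
# candidate (computed by derive_step_rewards), so scores[0][0] holds the step
# rewards for the first candidate of the first question.
for candidate_rewards in scores[0]:
    print(candidate_rewards)
```

Note that although `score` is annotated as returning `list[list[float]]`, the Skywork implementation appends one list of step rewards per candidate answer, so each element of the returned list is itself a list of per-step scores.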