Commit

make quality & make style
ShikaiChen committed Feb 7, 2025
1 parent 9c38d95 commit 1221ccd
Showing 1 changed file with 52 additions and 46 deletions.
98 changes: 52 additions & 46 deletions rewardbench/models/lenovo.py
@@ -1,31 +1,30 @@
import torch
from dataclasses import dataclass
from typing import Optional, List, Tuple
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import Gemma2PreTrainedModel, Gemma2Model
from transformers.utils import ModelOutput
from accelerate import infer_auto_device_map, dispatch_model
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from huggingface_hub import snapshot_download
from transformers import Gemma2Model, Gemma2PreTrainedModel
from transformers.utils import ModelOutput


class MultiOutputNN(nn.Module):
def __init__(self, input_dim, output_dim, hidden_dims=[4096, 4096]):
super(MultiOutputNN, self).__init__()

layers = []

layers.append(nn.Linear(input_dim, hidden_dims[0]))
layers.append(nn.LeakyReLU())

for i in range(1, len(hidden_dims)):
layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
layers.append(nn.LeakyReLU())

layers.append(nn.Linear(hidden_dims[-1], output_dim))

self.network = nn.Sequential(*layers)
self.softmax = nn.Softmax(dim=-1)

@@ -35,12 +34,13 @@ def forward(self, x):


class GatingNN(nn.Module):
def __init__(self, input_dim, output_dim, hidden_dim=4096, num_layers=2, temperature=1.0, dropout_prob=0.0, softmax=False):
def __init__(
self, input_dim, output_dim, hidden_dim=4096, num_layers=2, temperature=1.0, dropout_prob=0.0, softmax=False
):
super(GatingNN, self).__init__()
self.temperature = temperature
self.softmax = softmax
layers = []

layers.append(nn.Linear(input_dim, hidden_dim))
layers.append(nn.LeakyReLU())
layers.append(nn.Dropout(dropout_prob))
@@ -51,7 +51,6 @@ def __init__(self, input_dim, output_dim, hidden_dim=4096, num_layers=2, tempera
layers.append(nn.Dropout(dropout_prob))

layers.append(nn.Linear(hidden_dim, output_dim))

self.network = nn.Sequential(*layers)

def forward(self, x):
@@ -60,9 +59,9 @@ def forward(self, x):
x = F.softmax(x / self.temperature, dim=1)
return x


@dataclass
class CustomOutput(ModelOutput):

rewards: torch.FloatTensor = None
hidden_state: Optional[Tuple[torch.FloatTensor, ...]] = None
score: Optional[torch.FloatTensor] = None
@@ -79,23 +78,25 @@ def __init__(self, config):
config_dict = config.to_dict()
self.num_objectives = config_dict.get("num_objectives", 220)
self.regression_layer = MultiOutputNN(config.hidden_size, self.num_objectives)
self.gating_layer = GatingNN(config.hidden_size,
self.num_objectives // 10,
temperature=config_dict.get("temperature", 1.0),
softmax=config_dict.get("softmax", False))
self.gating_layer = GatingNN(
config.hidden_size,
self.num_objectives // 10,
temperature=config_dict.get("temperature", 1.0),
softmax=config_dict.get("softmax", False),
)

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> CustomOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -136,9 +137,12 @@ def forward(
with torch.autocast(device_type=hidden_states.device.type, dtype=torch.float32):
rewards = self.regression_layer(hidden_states)
weights = self.gating_layer(hidden_states)
weights = weights.unsqueeze(1)
weights = weights.unsqueeze(1)
total_reward_distribution = torch.bmm(weights, rewards).squeeze(1)
score = (total_reward_distribution * torch.linspace(0, 1, total_reward_distribution.size(-1)).to(tokens_hidden_states.device)).sum(-1)
score = (
total_reward_distribution
* torch.linspace(0, 1, total_reward_distribution.size(-1)).to(tokens_hidden_states.device)
).sum(-1)
return CustomOutput(
rewards=rewards,
weights=weights,
@@ -147,44 +151,47 @@ def forward(
score=score,
logits=score,
)

def save_pretrained(self, save_directory: str):
self.model.save_pretrained(save_directory, dtype=torch.bfloat16)
torch.save(self.regression_layer.state_dict(), os.path.join(save_directory, "regression_layer.pt"))
torch.save(self.gating_layer.state_dict(), os.path.join(save_directory, "gating_layer.pt"))
self.config.save_pretrained(save_directory)


@classmethod
def from_pretrained(cls, load_directory, device_map=None, *model_args, **kwargs):
if not os.path.exists(load_directory):
cached_dir = snapshot_download(repo_id=load_directory)
else:
cached_dir = load_directory
model = super(LDLRewardModel27B, cls).from_pretrained(cached_dir, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
model = super(LDLRewardModel27B, cls).from_pretrained(
cached_dir, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)

model.regression_layer = model.regression_layer.float()
regression_layer_path = os.path.join(cached_dir, "regression_layer.pt")
regression_layer_state_dict = torch.load(regression_layer_path, map_location="cpu")
model.regression_layer.load_state_dict(regression_layer_state_dict)

model.gating_layer = model.gating_layer.float()
gating_layer_path = os.path.join(cached_dir, "gating_layer.pt")
gating_layer_state_dict = torch.load(gating_layer_path, map_location="cpu")
model.gating_layer.load_state_dict(gating_layer_state_dict)
if device_map == 'auto' or device_map == 'balanced':
max_memory = get_balanced_memory(model, no_split_module_classes=["Gemma2DecoderLayer", 'Gemma2RMSNorm'])

if device_map == "auto" or device_map == "balanced":
max_memory = get_balanced_memory(model, no_split_module_classes=["Gemma2DecoderLayer", "Gemma2RMSNorm"])
device_map = infer_auto_device_map(
model,
no_split_module_classes=["Gemma2DecoderLayer", 'Gemma2RMSNorm'],
max_memory=max_memory,)
model,
no_split_module_classes=["Gemma2DecoderLayer", "Gemma2RMSNorm"],
max_memory=max_memory,
)
model = dispatch_model(model, device_map=device_map)
elif device_map is not None:
raise NotImplementedError("Write your own device map")

return model


class LenovoPipeline:
def __init__(self, task, model, tokenizer):
self.task = task
@@ -227,10 +234,9 @@ def __call__(self, samples, return_inputs=False, **kwargs):
input_ids[torch.arange(input_ids.size(0)), seq_second] == bos_token_id
)


if double_bos_mask.any():
inputs['attention_mask'] = inputs['attention_mask'][:,1:]
inputs['input_ids'] = inputs['input_ids'][:,1:]
inputs["attention_mask"] = inputs["attention_mask"][:, 1:]
inputs["input_ids"] = inputs["input_ids"][:, 1:]

with torch.no_grad():
outputs = self.model(**inputs)

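For context, a minimal, hedged usage sketch of the two classes this commit touches, LDLRewardModel27B and LenovoPipeline. The checkpoint id, task string, tokenizer source, and sample format below are assumptions; none of them are visible in this diff.

# Minimal usage sketch based on this diff, not an official example.
from transformers import AutoTokenizer

from rewardbench.models.lenovo import LDLRewardModel27B, LenovoPipeline

checkpoint = "local/dir-or-hub-repo-id"  # hypothetical placeholder

# Per the diff, from_pretrained resolves a Hub repo id via snapshot_download
# when the path does not exist locally, loads the backbone in bfloat16 with
# flash_attention_2 (CUDA and flash-attn required), restores the float32
# regression/gating heads, and accepts device_map="auto" or "balanced".
model = LDLRewardModel27B.from_pretrained(checkpoint, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # assumption: tokenizer ships with the checkpoint

pipe = LenovoPipeline(task="reward-model", model=model, tokenizer=tokenizer)  # task string is a guess

# The expected sample format is hidden in the collapsed part of the file;
# a list of chat-template-formatted strings is assumed here.
samples = ["<chat-template-formatted conversation>"]
scores = pipe(samples)  # forward runs under torch.no_grad() inside __call__
print(scores)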