elo_train_ppo.py
# This file uses trlx and PPO to apply RLHF to a language model.
from collections import defaultdict
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import trlx
from trlx.data.configs import TRLConfig
from critic_models import GPTSentimentELOCritic, T5SentimentELOCritic
from ppo_utils import elo_schedule
import pandas as pd
from typing import List
if __name__ == "__main__":

    def correct_string(string):
        # wrap each product name in the prompt template used for generation
        return "Product name: " + string + "\nProduct review: "

    # load the dataset of product names
    df = pd.read_csv("datasets/product_names.csv")
    # convert to a list and map each name into a prompt
    prompts = list(map(correct_string, df["product"].tolist()))
    # splits: last 5% of prompts for evaluation, first 95% for training
    eval_prompts = prompts[-int(len(prompts) * 0.05):]
    prompts = prompts[:int(len(prompts) * 0.95)]

    # load the critic model used to judge pairwise matches between samples
    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base").cuda()
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    prompt_dir = "datasets/prompts_reprocessed.csv"
    suffix = "positive"
    critic_model = T5SentimentELOCritic(model, tokenizer, prompt_dir, suffix=suffix)

    # curry the critic's match function into a standalone callable
    def match_function(prior, player1, player2):
        return critic_model.match_function(prior, player1, player2)

    # initialize the reward_fn
    def reward_fn(samples: List[str]) -> List[float]:
        """
        samples: list of generated strings to score.
        Returns a list of rewards, one per sample.
        """
        # for each sample, keep only the text after "Product review:"
        samples = [sample.split("Product review:")[1] for sample in samples]
        # run the Elo schedule with no prior and take the final ratings
        rewards = torch.tensor(list(elo_schedule(None, samples, match_function)[-1].values()))
        # normalize the scores using std and mean
        rewards = (rewards - torch.mean(rewards)) / torch.std(rewards)
        # return the rewards
        return rewards.tolist()

    # load the TRLConfig
    config = TRLConfig.load_yaml("ppo_config.yml")
    # run PPO training with trlx
    model = trlx.train(
        "finetuned_student_model/",
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=eval_prompts,
        config=config,
    )
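

# ---------------------------------------------------------------------------
# Illustrative appendix, not used by the script above: a minimal sketch of what
# the imported ppo_utils.elo_schedule helper is assumed to do, inferred only
# from its call site in reward_fn (elo_schedule(None, samples,
# match_function)[-1].values()). The pairing strategy, K-factor, base rating,
# number of rounds, and the convention that match_function returns 1.0 when
# player1 wins (0.0 when player2 wins, 0.5 for a draw) are assumptions here,
# not the repository's actual implementation.
# ---------------------------------------------------------------------------
import itertools
import random
from typing import Callable, Dict, Optional


def elo_schedule_sketch(
    prior: Optional[str],
    players: List[str],
    match_function: Callable[[Optional[str], str, str], float],
    k_factor: float = 32.0,
    rounds: int = 1,
) -> List[Dict[str, float]]:
    # every player (generated sample) starts from the same base rating
    ratings = {player: 1000.0 for player in players}
    history = [dict(ratings)]
    for _ in range(rounds):
        # play each pair of samples against each other in random order
        pairs = list(itertools.combinations(players, 2))
        random.shuffle(pairs)
        for player1, player2 in pairs:
            # assumed contract: 1.0 if player1 wins, 0.0 if player2 wins
            score1 = match_function(prior, player1, player2)
            # standard Elo expected score for player1
            expected1 = 1.0 / (1.0 + 10 ** ((ratings[player2] - ratings[player1]) / 400.0))
            ratings[player1] += k_factor * (score1 - expected1)
            ratings[player2] += k_factor * ((1.0 - score1) - (1.0 - expected1))
        history.append(dict(ratings))
    # reward_fn reads the final ratings via elo_schedule(...)[-1].values(),
    # then normalizes them to zero mean and unit variance before PPO. Using
    # relative Elo ratings rather than raw critic scores means the reward only
    # depends on how samples compare with each other within a batch.
    return history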