# ppo_config.yml
model:
  model_path: "finetuned_student_model/" # Path of the HF model to load
  tokenizer_path: "finetuned_student_model/" # Path of the HF tokenizer to load
model_type: "AcceleratePPOModel" # Name of accelerate model type to load
  num_layers_unfrozen: 2 # Number of topmost layers left unfrozen during training; the rest are frozen
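# Loading and training with this config (a minimal sketch, assuming an early
# trlx release that exposes TRLConfig.load_yaml and trlx.train; prompts and
# reward_fn are placeholders supplied by the experiment script):
#   import trlx
#   from trlx.data.configs import TRLConfig
#   config = TRLConfig.load_yaml("ppo_config.yml")
#   trlx.train(reward_fn=reward_fn, prompts=prompts, config=config)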
train:
seq_length: 200 # Size of LM context
epochs: 100 # Train for max(epochs, total_steps)
  total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 48 # batch size
lr_init: 1.412e-4 # init learning rate
lr_target: 1.412e-4 # target final learning rate
opt_betas: [0.9, 0.95] # adam betas
opt_eps: 1.0e-8 # adam eps
weight_decay: 1.0e-6 # weight decay param
checkpoint_interval: 10000 # checkpoint interval
eval_interval: 16 # eval interval
pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load
  project_name: "0Shot-Critics" # wandb project name
  entity_name: "louiscastricato" # wandb entity name
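# Schedule note (an assumption, based on early trlx building the optimizer
# from opt_betas/opt_eps/weight_decay and cosine-annealing the learning rate
# from lr_init down to lr_target over total_steps):
#   lr(t) = lr_target + 0.5 * (lr_init - lr_target) * (1 + cos(pi * t / total_steps))
# With lr_init == lr_target the annealing is a no-op and the rate stays at 1.412e-4.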
method:
name: 'ppoconfig' # Name of RL method config
num_rollouts: 128 # Number of rollouts to collect per epoch
chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
ppo_epochs: 4 # Number of ppo epochs
init_kl_coef: 0.2 # init kl coefficient
  target: 6 # target KL value for adaptive KL control; set to null for a fixed KL coefficient
horizon: 10000 # PPO horizon
gamma: 1 # PPO discount
lam: 0.95 # PPO lambda
  cliprange: 0.2 # policy ratio clip range
  cliprange_value: 0.2 # value function clip range
vf_coef: 2.3 # value term weight
  scale_reward: "running" # False | "ref" | "running": estimate against which to scale rewards
  ref_mean: null # rescale rewards with this mean
  ref_std: null # rescale rewards with this deviation
  cliprange_reward: 10 # clip scaled rewards to [-10, 10]
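  # Adaptive KL control (a sketch of the Ziegler et al. update that
  # init_kl_coef, target, and horizon drive; variable names are
  # illustrative, not necessarily trlx internals):
  #   error = clip(measured_kl / target - 1, -0.2, 0.2)
  #   kl_coef *= 1 + error * n_steps / horizon
  # With target set to null, the coefficient stays fixed at init_kl_coef.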
gen_kwargs:
max_new_tokens: 100 # LM max new tokens per step
min_length: 75 # LM min tokens per step
top_k: 50 # top k
top_p: 0.9 # top p
repetition_penalty: 1.2 # repetition penalty
temperature: 1.0 # temperature
do_sample: True # sample
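# During rollouts the gen_kwargs above are forwarded to the policy's HF
# generate call (a minimal sketch, assuming a transformers-style API;
# input_ids stands in for the tokenized prompts):
#   samples = model.generate(input_ids, max_new_tokens=100, min_length=75,
#                            top_k=50, top_p=0.9, repetition_penalty=1.2,
#                            temperature=1.0, do_sample=True)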