model.py
from accelerate import Accelerator


class StepRunner:
    def __init__(self, net, loss_fn, accelerator=None, stage="train",
                 metrics_dict=None, optimizer=None, lr_scheduler=None):
        self.net, self.loss_fn, self.metrics_dict, self.stage = net, loss_fn, metrics_dict, stage
        self.optimizer, self.lr_scheduler = optimizer, lr_scheduler
        self.accelerator = accelerator if accelerator is not None else Accelerator()
        if self.stage == 'train':
            self.net.train()
        else:
            self.net.eval()
    def __call__(self, batch):
        # forward pass: the model computes the loss directly from input_ids and labels
        with self.accelerator.autocast():
            loss = self.net(input_ids=batch["input_ids"], labels=batch["labels"]).loss

        # backward pass, gradient clipping and optimizer step (training stage only)
        if self.optimizer is not None and self.stage == "train":
            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        # aggregate the loss across all processes
        all_loss = self.accelerator.gather(loss).sum()

        # losses (or plain metrics that can be averaged)
        step_losses = {self.stage + "_loss": all_loss.item()}

        # metrics (stateful metrics)
        step_metrics = {}
        if self.stage == "train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses, step_metrics
    # Save only the trainable LoRA parameters: for a PEFT model,
    # save_pretrained writes just the adapter weights, not the base model.
    def save_ckpt(self, ckpt_path='checkpoint.pt', accelerator=None):
        accelerator = accelerator if accelerator is not None else self.accelerator
        unwrap_net = accelerator.unwrap_model(self.net)
        unwrap_net.save_pretrained(ckpt_path)

    def load_ckpt(self, ckpt_path='checkpoint.pt'):
        # Reload the saved LoRA adapter on top of the current model;
        # from_pretrained is a classmethod, so it can be called via the instance.
        self.net = self.net.from_pretrained(self.net, ckpt_path)
        self.from_scratch = False
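

# --------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file):
# drives one training step of StepRunner with a tiny stand-in model. In real
# use `net` would be a transformers/PEFT causal LM whose forward returns an
# object exposing `.loss`; DummyLM below only mimics that interface.
if __name__ == "__main__":
    import torch
    from torch import nn
    from types import SimpleNamespace

    class DummyLM(nn.Module):
        """Minimal model mimicking a causal LM's (input_ids, labels) -> .loss API."""
        def __init__(self, vocab_size=100, dim=16):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, input_ids, labels):
            logits = self.head(self.emb(input_ids))
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.view(-1))
            return SimpleNamespace(loss=loss)

    accelerator = Accelerator()
    net = DummyLM()
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
    net, optimizer = accelerator.prepare(net, optimizer)

    runner = StepRunner(net, loss_fn=None, accelerator=accelerator,
                        stage="train", optimizer=optimizer)
    batch = {"input_ids": torch.randint(0, 100, (2, 8)),
             "labels": torch.randint(0, 100, (2, 8))}
    batch = {k: v.to(accelerator.device) for k, v in batch.items()}
    step_losses, step_metrics = runner(batch)
    print(step_losses, step_metrics)  # e.g. {'train_loss': ...} {'lr': 0.001}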