# utils.py
import os

import torch

class SaveBestModel:
    """
    Save the best model while training. If the current iteration's
    validation loss is lower than the best seen so far, save the model
    state.
    """

    def __init__(self, best_valid=float("inf")):
        self.best_valid = best_valid

    def __call__(
        self,
        current_valid,
        iteration,
        encoder,
        decoder,
        encoder_optimizer,
        decoder_optimizer,
        voc,
        embedding,
        train_loss_per_iteration,
        valid_loss_per_iteration,
        encoder_scheduler,
        decoder_scheduler,
        directory,
    ):
        if current_valid < self.best_valid:
            self.best_valid = current_valid
            torch.save(
                {
                    "iteration": iteration,
                    "en": encoder.state_dict(),
                    "de": decoder.state_dict(),
                    "en_opt": encoder_optimizer.state_dict(),
                    "de_opt": decoder_optimizer.state_dict(),
                    "voc_dict": voc.__dict__,
                    "embedding": embedding.state_dict(),
                    "train_loss_per_iteration": train_loss_per_iteration,
                    "valid_loss_per_iteration": valid_loss_per_iteration,
                    "encoder_scheduler": encoder_scheduler.state_dict(),
                    "decoder_scheduler": decoder_scheduler.state_dict(),
                },
                os.path.join(directory, "best_model_checkpoint.tar"),
            )
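
# Example usage (a sketch, not part of the original module): `encoder`,
# `decoder`, the optimizers, schedulers, `voc`, `embedding`, the loss
# lists, and `checkpoint_dir` are assumed to come from the training
# script; `train_step` and `evaluate` are hypothetical helpers.
#
#     save_best_model = SaveBestModel()
#     for iteration in range(1, n_iterations + 1):
#         train_loss_per_iteration.append(train_step(...))
#         valid_loss = evaluate(...)
#         valid_loss_per_iteration.append(valid_loss)
#         save_best_model(
#             valid_loss, iteration, encoder, decoder,
#             encoder_optimizer, decoder_optimizer, voc, embedding,
#             train_loss_per_iteration, valid_loss_per_iteration,
#             encoder_scheduler, decoder_scheduler, checkpoint_dir,
#         )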

def save_last_model(
    iteration,
    encoder,
    decoder,
    encoder_optimizer,
    decoder_optimizer,
    voc,
    embedding,
    train_loss_per_iteration,
    valid_loss_per_iteration,
    encoder_scheduler,
    decoder_scheduler,
    directory,
):
    """
    Unconditionally save the latest training state so that training can
    be resumed from the last completed iteration.
    """
    torch.save(
        {
            "iteration": iteration,
            "en": encoder.state_dict(),
            "de": decoder.state_dict(),
            "en_opt": encoder_optimizer.state_dict(),
            "de_opt": decoder_optimizer.state_dict(),
            "voc_dict": voc.__dict__,
            "embedding": embedding.state_dict(),
            "train_loss_per_iteration": train_loss_per_iteration,
            "valid_loss_per_iteration": valid_loss_per_iteration,
            "encoder_scheduler": encoder_scheduler.state_dict(),
            "decoder_scheduler": decoder_scheduler.state_dict(),
        },
        os.path.join(directory, "last_model_checkpoint.tar"),
    )
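
# Minimal smoke test (an addition, not part of the original training
# pipeline): tiny stand-in modules exercise both checkpoint helpers and
# show how a saved checkpoint can be restored.
if __name__ == "__main__":
    import tempfile

    import torch.nn as nn

    class _Voc:
        # Minimal stand-in for the project's vocabulary object.
        def __init__(self):
            self.word2index = {}

    encoder = nn.Linear(4, 4)
    decoder = nn.Linear(4, 4)
    embedding = nn.Embedding(10, 4)
    encoder_optimizer = torch.optim.Adam(encoder.parameters())
    decoder_optimizer = torch.optim.Adam(decoder.parameters())
    encoder_scheduler = torch.optim.lr_scheduler.StepLR(encoder_optimizer, step_size=1)
    decoder_scheduler = torch.optim.lr_scheduler.StepLR(decoder_optimizer, step_size=1)
    voc = _Voc()
    train_losses, valid_losses = [1.0], [0.5]

    save_best_model = SaveBestModel()
    with tempfile.TemporaryDirectory() as tmp:
        save_best_model(
            0.5, 1, encoder, decoder, encoder_optimizer, decoder_optimizer,
            voc, embedding, train_losses, valid_losses,
            encoder_scheduler, decoder_scheduler, tmp,
        )
        save_last_model(
            1, encoder, decoder, encoder_optimizer, decoder_optimizer,
            voc, embedding, train_losses, valid_losses,
            encoder_scheduler, decoder_scheduler, tmp,
        )
        # weights_only=False because the checkpoint stores plain Python
        # objects (voc.__dict__), not just tensors; the kwarg assumes a
        # recent PyTorch version.
        checkpoint = torch.load(
            os.path.join(tmp, "best_model_checkpoint.tar"),
            weights_only=False,
        )
        encoder.load_state_dict(checkpoint["en"])
        decoder.load_state_dict(checkpoint["de"])
        print("restored from iteration", checkpoint["iteration"])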