-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlr_scheduler.py
111 lines (102 loc) · 3.52 KB
/
lr_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import tensorflow as tf
import numpy as np
import os
class EarlyStopping:
    """Stop training when a monitored quantity has stopped improving.

    The caller passes the current metric value to :meth:`on_epoch_end` after
    each epoch and stops its training loop once that method returns True.

    Arguments:
        model: Object exposing ``get_weights()`` / ``set_weights()``
            (presumably a Keras model -- confirm with callers). Only used
            when ``restore_best_weights`` is True; ignored if None.
        min_delta: Minimum change in the monitored quantity to qualify as
            an improvement, i.e. an absolute change of less than
            ``min_delta`` will count as no improvement.
        patience: Number of consecutive epochs with no improvement after
            which training will be stopped.
        verbose: Verbosity mode; when > 0 a message is printed when the
            best weights are restored.
        mode: One of ``{"auto", "min", "max"}``. In ``min`` mode training
            stops when the quantity has stopped decreasing; in ``max`` mode
            it stops when the quantity has stopped increasing. Any other
            value (e.g. ``"auto"``) infers the direction from ``monitor``:
            names containing ``"acc"`` are maximized, everything else
            (including ``monitor=None``) is minimized.
        baseline: Initial best value. Training will stop if the model
            doesn't show improvement over the baseline.
        restore_best_weights: Whether to restore model weights from the
            epoch with the best value of the monitored quantity when
            training is stopped by this object. If False, the model keeps
            the weights of the last epoch.
        args: Namespace-like object; ``save_model`` reads its ``load`` and
            ``path`` attributes.
        root: Object whose ``save(path)`` method is called by
            ``save_model``.
        monitor: Optional name of the monitored quantity, used only to
            infer the direction when ``mode`` is not "min"/"max".
    """

    def __init__(self,
                 model=None,
                 min_delta=0,
                 patience=0,
                 verbose=0,
                 mode='min',
                 baseline=None,
                 restore_best_weights=True,
                 args=None,
                 root=None,
                 monitor=None):
        self.model = model
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.baseline = baseline
        self.min_delta = abs(min_delta)
        self.wait = 0
        self.stopped_epoch = 0
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        self.stop_training = False
        self.args = args
        self.root = root

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        elif monitor is not None and 'acc' in monitor:
            self.monitor_op = np.greater
        else:
            # No (or non-accuracy) metric name: default to minimizing instead
            # of crashing on a missing attribute.
            self.monitor_op = np.less

        # Give min_delta the sign that makes ``current - min_delta`` beat
        # ``best`` only for an improvement of at least |min_delta|.
        if self.monitor_op is np.less:
            self.min_delta *= -1

        self.on_train_begin()

    def on_train_begin(self, logs=None):
        """Reset the tracking state so the instance can be re-used.

        ``logs`` is accepted for callback-API compatibility and ignored.
        """
        self.wait = 0
        self.stopped_epoch = 0
        # Clear the stop flag and saved weights from any previous run;
        # otherwise a re-used instance would keep returning True from
        # on_epoch_end and could restore weights from an earlier run.
        self.stop_training = False
        self.best_weights = None
        if self.baseline is not None:
            self.best = self.baseline
        else:
            # np.Inf was removed in NumPy 2.0; np.inf works on all versions.
            self.best = np.inf if self.monitor_op is np.less else -np.inf

    def on_epoch_end(self, epoch, monitor, logs=None):
        """Record the metric value for ``epoch`` and decide whether to stop.

        Arguments:
            epoch: Index of the epoch that just finished; stored in
                ``stopped_epoch`` when training is stopped.
            monitor: Current value of the monitored quantity. ``None`` is
                ignored (no state change).
            logs: Unused; kept for callback-API compatibility.

        Returns:
            ``self.stop_training``: True once ``patience`` consecutive epochs
            without improvement have been observed.
        """
        current = monitor
        if current is None:
            return self.stop_training

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights and self.model is not None:
                self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.stop_training = True
                # best_weights stays None when no epoch ever beat the initial
                # ``best`` (e.g. an unreached baseline) or no model was given;
                # never hand None to set_weights.
                if (self.restore_best_weights
                        and self.model is not None
                        and self.best_weights is not None):
                    if self.verbose > 0:
                        print('Restoring model weights from the end of the best epoch.')
                    self.model.set_weights(self.best_weights)
        return self.stop_training

    def save_model(self):
        """Save ``self.root`` to ``<args.load or args.path>/checkpoint.pt``.

        Does nothing when ``root`` is falsy; otherwise ``args`` must provide
        ``load`` and ``path`` attributes. This is not called automatically
        by ``on_epoch_end``; invoke it explicitly.
        """
        if self.root:
            self.root.save(os.path.join(self.args.load or self.args.path, 'checkpoint.pt'))