-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_lr_schedules.py
67 lines (47 loc) · 1.8 KB
/
test_lr_schedules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
from copy import deepcopy
from lr_schedules import LRConstantSchedule, LRExponentialDecaySchedule, LRCyclingSchedule
def test_lr_constant_schedule():
    """Verify that LRConstantSchedule never moves away from its initial rate.

    The schedule is applied 100 times; after every application the value
    returned by ``get_lr()`` must equal the initial learning rate exactly.
    """
    base_lr = 0.1
    current_lr = deepcopy(base_lr)

    schedule = LRConstantSchedule(base_lr)

    n_steps = 100
    for _ in range(n_steps):
        schedule.apply_schedule()
        current_lr = schedule.get_lr()
        # A constant schedule has no arithmetic to round, so exact equality
        # is the right check here.
        np.testing.assert_array_equal(current_lr, base_lr)

    print("test_lr_constant_schedule passed")
def test_lr_exponential_decay_schedule():
    """Check LRExponentialDecaySchedule against the closed-form decay curve.

    The schedule is applied 100 times. After the call with zero-based index
    ``i`` the rate returned by ``get_lr()`` is compared with
    ``lr_initial * decay_rate ** (i / decay_steps)``.
    """
    lr_initial = 0.1
    decay_steps = 10
    decay_rate = 0.9

    lr_exponential_decay_schedule = LRExponentialDecaySchedule(
        lr_initial, decay_steps, decay_rate
    )

    iter_n = 100
    for i in range(iter_n):
        lr_exponential_decay_schedule.apply_schedule()
        lr = lr_exponential_decay_schedule.get_lr()
        lr_true = lr_initial * decay_rate ** (i / decay_steps)
        # Compare with a tight relative tolerance instead of bit-exact
        # equality: the expected value comes from floating-point
        # exponentiation, so an equivalent implementation that orders the
        # operations differently may differ in the last bits.
        np.testing.assert_allclose(lr, lr_true, rtol=1e-12, atol=0.0)

    print("test_lr_exponential_decay_schedule passed")
def test_lr_cycling_schedule():
    """Check LRCyclingSchedule against the triangular cyclical-rate formula.

    The schedule is applied 101 times. After the call with zero-based index
    ``i`` the rate returned by ``get_lr()`` is compared with::

        cycle = floor(1 + i / (2 * step_size))
        x     = |i / step_size - 2 * cycle + 1|
        lr    = lr_initial + (lr_max - lr_initial) * max(0, 1 - x)
    """
    lr_initial = 0.1
    lr_max = 0.5
    step_size = 10

    # Histories are kept only for optional visual inspection of the schedule
    # (e.g. plotting lr_list against lr_true_list); no assertion reads them.
    lr_list = []
    lr_true_list = []

    lr_cycling_schedule = LRCyclingSchedule(lr_initial, lr_max, step_size)

    iter_n = 101
    for i in range(iter_n):
        lr_cycling_schedule.apply_schedule()
        lr = lr_cycling_schedule.get_lr()

        cycle = np.floor(1 + i / (2 * step_size))
        x = np.abs(i / step_size - 2 * cycle + 1)
        lr_true = lr_initial + (lr_max - lr_initial) * np.maximum(0, (1 - x))

        # Compare with a tight relative tolerance instead of bit-exact
        # equality: the reference value is the result of several
        # floating-point operations, so an equivalent implementation may
        # differ in the last bits.
        np.testing.assert_allclose(lr, lr_true, rtol=1e-12, atol=0.0)

        lr_list.append(lr)
        lr_true_list.append(lr_true)

    print("test_lr_cycling_schedule passed")