-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
Copy pathoptimizer.py
65 lines (55 loc) · 2.49 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import mxnet as mx
import mxnet.optimizer as optimizer
from mxnet.ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as
NDabs)
#from mxnet.ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
# mp_sgd_update, mp_sgd_mom_update, square, ftrl_update)
class ONadam(optimizer.Optimizer):
    """Nadam-style optimizer: Adam moments combined with a Nesterov-like
    look-ahead and a warming momentum schedule.

    Parameters
    ----------
    learning_rate : float
        Step size, forwarded to the base ``Optimizer``.
    beta1 : float
        Decay rate of the first-moment (mean) estimate; also the base value
        of the scheduled momentum.
    beta2 : float
        Decay rate of the second-moment (variance) estimate.
    epsilon : float
        Constant added to the denominator to avoid division by zero.
    schedule_decay : float
        Exponent scale of the momentum schedule
        ``beta1 * (1 - 0.5 * 0.96 ** (t * schedule_decay))``.
    **kwargs
        Forwarded unchanged to the base ``Optimizer`` constructor.
    """

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-8,
                 schedule_decay=0.004,
                 **kwargs):
        super(ONadam, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.schedule_decay = schedule_decay
        # Running product of the scheduled momentum values. It is a single
        # scalar on the optimizer, so every update() call advances it,
        # whichever parameter index is being updated.
        self.m_schedule = 1.

    def create_state(self, index, weight):
        """Return the ``(mean, variance)`` state pair for one parameter.

        Both arrays are zero-initialised with the shape, context and dtype
        of ``weight``.
        """
        mean = zeros(weight.shape, weight.context, dtype=weight.dtype)
        variance = zeros(weight.shape, weight.context, dtype=weight.dtype)
        return (mean, variance)

    def update(self, index, weight, grad, state):
        """Apply one update step to ``weight`` in place.

        Parameters
        ----------
        index : hashable
            Parameter identifier used for the update count, learning rate
            and weight-decay lookups.
        weight : NDArray
            Parameter to update; modified in place.
        grad : NDArray
            Gradient of the loss with respect to ``weight``. It is not
            modified; the preprocessed gradient is a new array.
        state : tuple of NDArray
            The ``(mean, variance)`` pair from ``create_state``; both are
            updated in place.
        """
        assert (isinstance(weight, NDArray))
        assert (isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        t = self._index_update_count[index]

        # Preprocess grad: rescale, then add the weight-decay term.
        # The previous `grad *= self.rescale_grad + wd * weight` evaluated as
        # grad * (rescale_grad + wd * weight), which multiplied the gradient
        # element-wise by the weights and overwrote the caller's buffer.
        grad = grad * self.rescale_grad + wd * weight
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)

        # Warming momentum schedule for the current and the next step.
        momentum_t = self.beta1 * (
            1. - 0.5 * (pow(0.96, t * self.schedule_decay)))
        momentum_t_1 = self.beta1 * (
            1. - 0.5 * (pow(0.96, (t + 1) * self.schedule_decay)))
        self.m_schedule = self.m_schedule * momentum_t
        m_schedule_next = self.m_schedule * momentum_t_1

        # Update the first and second moment estimates in place.
        m_t, v_t = state
        m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
        v_t[:] = self.beta2 * v_t + (1. - self.beta2) * grad * grad

        # Bias-corrected terms, then blend the current gradient with the
        # look-ahead momentum.
        grad_prime = grad / (1. - self.m_schedule)
        m_t_prime = m_t / (1. - m_schedule_next)
        v_t_prime = v_t / (1. - pow(self.beta2, t))
        m_t_bar = (1. - momentum_t) * grad_prime + momentum_t_1 * m_t_prime

        # Update weight.
        weight[:] -= lr * m_t_bar / (sqrt(v_t_prime) + self.epsilon)