-
Notifications
You must be signed in to change notification settings - Fork 1
/
optim.py
176 lines (134 loc) · 6.88 KB
/
optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import copy
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np

from util import normal_like_tree
class TrainState(NamedTuple):
    """Collect all the state required for neural-network training.

    The optimizer ``step`` functions in this module return an updated copy
    of this tuple via ``_replace`` rather than a brand-new type.
    """

    # Optimizer state dict. The builders in this module store 'w' (weights),
    # 'gm' (momentum buffer), 'alpha' (learning rate) and, for bSAM,
    # 's' (per-parameter precision).
    optstate: dict
    # Non-trainable network state threaded through ``lossgrad``
    # (presumably e.g. batch-norm statistics -- confirm against the model).
    netstate: Any
    # JAX PRNG key; consumed and re-split by the bSAM optimizer, passed
    # through unchanged by the deterministic ones.
    rngkey: Any
def build_sgd_optimizer(lossgrad,
                        learningrate: float,
                        momentum: float,
                        wdecay: float):
    """Build an SGD optimizer with heavy-ball momentum and L2 weight decay.

    Args:
        lossgrad: callable ``lossgrad(params, netstate, minibatch,
            is_training=True)`` returning ``((loss, new_netstate), grad)``,
            where ``grad`` has the same pytree structure as ``params``.
        learningrate: base step size; multiplied by ``lrfactor`` each step.
        momentum: coefficient applied to the previous momentum buffer.
        wdecay: L2 weight-decay coefficient added to the gradient.

    Returns:
        ``(init, step)`` where
        ``init(weightinit, netstate, rngkey) -> TrainState`` and
        ``step(trainstate, minibatch, lrfactor) -> (TrainState, loss)``.
    """

    def init(weightinit, netstate, rngkey):
        """Create the initial TrainState with a zero momentum buffer."""
        optstate = {
            'w': copy.deepcopy(weightinit),
            # zeros_like keeps each buffer in the dtype of its parameter.
            'gm': jax.tree_util.tree_map(jnp.zeros_like, weightinit),
            # Bookkeeping only: ``step`` reads ``learningrate`` directly.
            'alpha': learningrate,
        }
        return TrainState(optstate=optstate,
                          netstate=netstate,
                          rngkey=rngkey)

    def step(trainstate, minibatch, lrfactor):
        """Apply one SGD step and return ``(new_trainstate, loss)``.

        ``loss`` is the value returned by ``lossgrad`` at the current
        weights. The incoming ``trainstate`` is not modified.
        """
        # Shallow copy so the caller's optstate dict is not mutated in place.
        optstate = dict(trainstate.optstate)
        (loss, netstate), grad = lossgrad(optstate['w'], trainstate.netstate,
                                          minibatch, is_training=True)
        # momentum: gm <- momentum * gm + grad + wdecay * w
        optstate['gm'] = jax.tree_util.tree_map(
            lambda gm, g, w: momentum * gm + g + wdecay * w,
            optstate['gm'], grad, optstate['w'])
        # weight update: w <- w - learningrate * lrfactor * gm
        optstate['w'] = jax.tree_util.tree_map(
            lambda p, gm: p - learningrate * lrfactor * gm,
            optstate['w'], optstate['gm'])
        newtrainstate = trainstate._replace(optstate=optstate,
                                            netstate=netstate)
        return newtrainstate, loss

    return init, step
def build_sam_optimizer(lossgrad,
                        learningrate: float,
                        momentum: float,
                        wdecay: float,
                        rho: float,
                        msharpness: int):
    """Build a sharpness-aware-minimization (SAM) optimizer with momentum.

    Each step splits the minibatch into ``msharpness`` sub-batches. For
    every sub-batch it takes the gradient at ``w``, moves to
    ``w + rho * g / ||g||`` and takes the gradient there. The perturbed
    gradients are averaged and fed into the same momentum/weight-decay
    update as ``build_sgd_optimizer``.

    Args:
        lossgrad: callable ``lossgrad(params, netstate, (X, y),
            is_training=True)`` returning ``((loss, new_netstate), grad)``.
        learningrate: base step size; multiplied by ``lrfactor`` each step.
        momentum: coefficient applied to the previous momentum buffer.
        wdecay: L2 weight-decay coefficient added to the gradient.
        rho: radius of the ascent (perturbation) step.
        msharpness: number of sub-batches; the minibatch's leading
            dimension must be divisible by it (required by the reshape).

    Returns:
        ``(init, step)`` where
        ``init(weightinit, netstate, rngkey) -> TrainState`` and
        ``step(trainstate, (X, y), lrfactor) -> (TrainState, loss)``.
    """

    def init(weightinit, netstate, rngkey):
        """Create the initial TrainState with a zero momentum buffer."""
        optstate = {
            'w': copy.deepcopy(weightinit),
            # zeros_like keeps each buffer in the dtype of its parameter.
            'gm': jax.tree_util.tree_map(jnp.zeros_like, weightinit),
            # Bookkeeping only: ``step`` reads ``learningrate`` directly.
            'alpha': learningrate,
        }
        return TrainState(optstate=optstate,
                          netstate=netstate,
                          rngkey=rngkey)

    def _sam_gradient(trainstate, X_subbatch, y_subbatch):
        """Return ``(perturbed_grad, netstate, loss)`` for one sub-batch."""
        (_, netstate), grad = lossgrad(trainstate.optstate['w'],
                                       trainstate.netstate,
                                       (X_subbatch, y_subbatch),
                                       is_training=True)
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(e))
                                 for e in jax.tree_util.tree_leaves(grad)))
        # An all-zero gradient would give 0/0 = NaN and poison the weights;
        # the clamp is a no-op for any non-negligible norm.
        grad_norm = jnp.maximum(grad_norm, 1e-12)
        perturbed_params = jax.tree_util.tree_map(
            lambda p, g: p + rho * g / grad_norm,
            trainstate.optstate['w'], grad)
        # The net state produced by the first pass feeds the second pass;
        # the returned loss is evaluated at the perturbed weights.
        (loss, netstate), perturbed_grad = lossgrad(perturbed_params, netstate,
                                                    (X_subbatch, y_subbatch),
                                                    is_training=True)
        return perturbed_grad, netstate, loss

    def step(trainstate, minibatch, lrfactor):
        """Apply one SAM step and return ``(new_trainstate, mean_loss)``.

        The incoming ``trainstate`` is not modified.
        """
        # Shallow copy so the caller's optstate dict is not mutated in place.
        optstate = dict(trainstate.optstate)
        # split batch to simulate m-sharpness on one GPU
        X_batch = minibatch[0].reshape(msharpness, -1, *minibatch[0].shape[1:])
        y_batch = minibatch[1].reshape(msharpness, -1, *minibatch[1].shape[1:])
        grad, netstate, loss = jax.vmap(_sam_gradient, in_axes=(None, 0, 0))(
            trainstate, X_batch, y_batch)
        # Average gradients and losses over sub-batches; keep only the net
        # state produced by the first sub-batch.
        grad = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), grad)
        netstate = jax.tree_util.tree_map(lambda p: p[0], netstate)
        loss = jnp.mean(loss)
        # momentum: gm <- momentum * gm + grad + wdecay * w
        optstate['gm'] = jax.tree_util.tree_map(
            lambda gm, g, w: momentum * gm + g + wdecay * w,
            optstate['gm'], grad, optstate['w'])
        # weight update: w <- w - learningrate * lrfactor * gm
        optstate['w'] = jax.tree_util.tree_map(
            lambda p, gm: p - learningrate * lrfactor * gm,
            optstate['w'], optstate['gm'])
        newtrainstate = trainstate._replace(optstate=optstate,
                                            netstate=netstate)
        return newtrainstate, loss

    return init, step
def build_bsam_optimizer(lossgrad,
                         learningrate: float,
                         beta1: float,
                         beta2: float,
                         wdecay: float,
                         rho: float,
                         msharpness: int,
                         Ndata: int,
                         s_init: float,
                         damping: float):
    """Build a bSAM-style optimizer that keeps a per-parameter precision ``s``.

    (Presumably Bayesian SAM as in Moellenhoff & Khan, 2023 -- confirm.)
    For each of ``msharpness`` sub-batches, a step does the following:

    1. Sample noisy weights ``w + sqrt(1 / (Ndata * s)) * noise``.
    2. Take the gradient ``g`` at the noisy weights.
    3. Perturb ``w`` by ``rho * g / s``.
    4. Take the gradient at the perturbed point.

    The averaged perturbed gradient drives a momentum update and an
    ``s``-preconditioned weight update. The precision is then refreshed
    from ``sqrt(s * g**2)``, ``damping`` and ``wdecay``.

    Args:
        lossgrad: callable ``lossgrad(params, netstate, (X, y),
            is_training=True)`` returning ``((loss, new_netstate), grad)``.
        learningrate: base step size; multiplied by ``lrfactor`` each step.
        beta1: momentum decay for the gradient buffer ``gm``.
        beta2: decay for the precision ``s``.
        wdecay: weight-decay coefficient (added to the gradient and to ``s``).
        rho: scale of the ascent (perturbation) step.
        msharpness: number of sub-batches; the minibatch's leading
            dimension must be divisible by it (required by the reshape).
        Ndata: dataset size, used to scale the sampling noise.
        s_init: initial value of every precision entry.
        damping: constant added to the precision update.

    Returns:
        ``(init, step)`` where
        ``init(weightinit, netstate, rngkey) -> TrainState`` and
        ``step(trainstate, (X, y), lrfactor) -> (TrainState, loss)``.
    """

    def init(weightinit, netstate, rngkey):
        """Create the initial TrainState: zero momentum, constant precision."""
        optstate = {
            'w': copy.deepcopy(weightinit),
            # *_like keeps each buffer in the dtype of its parameter.
            'gm': jax.tree_util.tree_map(jnp.zeros_like, weightinit),
            # Bookkeeping only: ``step`` reads ``learningrate`` directly.
            'alpha': learningrate,
            's': jax.tree_util.tree_map(lambda p: jnp.full_like(p, s_init),
                                        weightinit),
        }
        return TrainState(optstate=optstate,
                          netstate=netstate,
                          rngkey=rngkey)

    def _bsam_gradient(trainstate, X_subbatch, y_subbatch, rngkey):
        """Return ``(gs, perturbed_grad, netstate, loss)`` for one sub-batch."""
        optstate = trainstate.optstate
        # noisy sample; normal_like_tree is presumed to return a pytree of
        # standard-normal draws shaped like 'w' plus a key (discarded here).
        noise, _ = normal_like_tree(optstate['w'], rngkey)
        noisy_param = jax.tree_util.tree_map(
            lambda n, mu, s: mu + jnp.sqrt(1.0 / (Ndata * s)) * n,
            noise, optstate['w'], optstate['s'])
        # gradient at noisy sample
        (_, netstate), grad = lossgrad(noisy_param, trainstate.netstate,
                                       (X_subbatch, y_subbatch),
                                       is_training=True)
        # Ascent step scaled by the precision rather than the gradient norm.
        perturbed_params = jax.tree_util.tree_map(
            lambda p, g, s: p + rho * g / s,
            optstate['w'], grad, optstate['s'])
        (loss, netstate), perturbed_grad = lossgrad(perturbed_params, netstate,
                                                    (X_subbatch, y_subbatch),
                                                    is_training=True)
        # Precision-update statistic sqrt(s * g^2), using the noisy-sample grad.
        gs = jax.tree_util.tree_map(lambda g, s: jnp.sqrt(s * (g ** 2.0)),
                                    grad, optstate['s'])
        return gs, perturbed_grad, netstate, loss

    def step(trainstate, minibatch, lrfactor):
        """Apply one bSAM step and return ``(new_trainstate, mean_loss)``.

        The incoming ``trainstate`` is not modified; the returned state
        carries a fresh ``rngkey``.
        """
        # Shallow copy so the caller's optstate dict is not mutated in place.
        optstate = dict(trainstate.optstate)
        # One key per sub-batch, plus one carried forward as the next rngkey.
        rngkeys = jax.random.split(trainstate.rngkey, msharpness + 1)
        # split batch to simulate m-sharpness on one GPU
        X_batch = minibatch[0].reshape(msharpness, -1, *minibatch[0].shape[1:])
        y_batch = minibatch[1].reshape(msharpness, -1, *minibatch[1].shape[1:])
        gs, grad, netstate, loss = jax.vmap(
            _bsam_gradient, in_axes=(None, 0, 0, 0))(
                trainstate, X_batch, y_batch, rngkeys[0:msharpness])
        # Average over sub-batches; keep only the first sub-batch's net state.
        gs = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), gs)
        grad = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), grad)
        netstate = jax.tree_util.tree_map(lambda p: p[0], netstate)
        loss = jnp.mean(loss)
        # momentum: gm <- beta1 * gm + (1 - beta1) * (grad + wdecay * w)
        optstate['gm'] = jax.tree_util.tree_map(
            lambda gm, g, w: beta1 * gm + (1 - beta1) * (g + wdecay * w),
            optstate['gm'], grad, optstate['w'])
        # weight update, preconditioned by the precision *before* its update
        optstate['w'] = jax.tree_util.tree_map(
            lambda p, gm, s: p - learningrate * lrfactor * gm / s,
            optstate['w'], optstate['gm'], optstate['s'])
        # update precision
        optstate['s'] = jax.tree_util.tree_map(
            lambda s, g: beta2 * s + (1 - beta2) * (g + damping + wdecay),
            optstate['s'], gs)
        newtrainstate = trainstate._replace(optstate=optstate,
                                            netstate=netstate,
                                            rngkey=rngkeys[-1])
        return newtrainstate, loss

    return init, step