-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
159 lines (139 loc) · 5.65 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
""" Various neural network models in haiku
Inspired by:
https://github.com/izmailovpavel/neurips_bdl_starter_kit/blob/main/jax_models.py
"""
import haiku as hk
import jax
import jax.numpy as jnp
from haiku.initializers import Constant
import functools
# He (Kaiming) normal initializer: variance scaling with scale 2.0 over the
# fan-in, sampled from a truncated normal. Used below as the weight
# initializer of the ResNet convolutions and its final dense layer.
he_normal = hk.initializers.VarianceScaling(
    scale=2.0, mode='fan_in', distribution='truncated_normal')

# Keyword arguments for a batch-norm layer (the keys match hk.BatchNorm's
# constructor arguments). NOTE(review): not referenced by any model builder
# in this module section — confirm whether it is still needed.
_DEFAULT_BN_CONFIG = dict(
    decay_rate=0.9,
    eps=1e-5,
    create_scale=True,
    create_offset=True,
)
class FilterResponseNorm(hk.Module):
    """Filter Response Normalization followed by a Thresholded Linear Unit.

    Each channel is divided by the root-mean-square of its values over axes
    1 and 2, then an affine transform (``gamma``, ``beta``) is applied and the
    result is clamped from below by the learned threshold ``tau``::

        out = max(gamma * x / sqrt(mean(x**2, axes=(1, 2)) + eps) + beta, tau)

    The parameters have shape ``(1, 1, 1, C)`` with ``C = x.shape[-1]`` and the
    input dtype, so the layer is written for rank-4 channels-last input
    (presumably NHWC — confirm at call sites). ``tau`` and ``beta`` start at
    zero and ``gamma`` at one.

    Extra keyword arguments (such as ``is_training``) are accepted and
    discarded, so this module can be called like a normalization layer that
    takes a training flag.
    """

    def __init__(self, eps=1e-6, name='frn'):
        super().__init__(name=name)
        # Small constant added to the mean square to avoid division by zero.
        self.eps = eps

    def __call__(self, x, **unused_kwargs):
        del unused_kwargs
        param_shape = (1, 1, 1, x.shape[-1])
        dtype = x.dtype
        tau = hk.get_parameter('tau', shape=param_shape, dtype=dtype,
                               init=jnp.zeros)
        beta = hk.get_parameter('beta', shape=param_shape, dtype=dtype,
                                init=jnp.zeros)
        gamma = hk.get_parameter('gamma', shape=param_shape, dtype=dtype,
                                 init=jnp.ones)
        mean_square = jnp.mean(jnp.square(x), axis=(1, 2), keepdims=True)
        normalized = x * jax.lax.rsqrt(mean_square + self.eps)
        # The TLU: a learned per-channel lower bound instead of a fixed 0.
        return jnp.maximum(gamma * normalized + beta, tau)
def _resnet_layer(
    inputs, num_filters, normalization_layer, kernel_size=3, strides=1,
    activation=lambda x: x, use_bias=True, is_training=True
):
    """Apply one conv -> normalization -> activation unit.

    Args:
      inputs: feature map fed to the convolution.
      num_filters: number of output channels of the convolution.
      normalization_layer: zero-argument callable returning a module that is
        invoked as ``layer(x, is_training=is_training)``.
      kernel_size: spatial size of the convolution kernel.
      strides: convolution stride.
      activation: function applied last; the identity by default.
      use_bias: whether the convolution has a bias term.
      is_training: forwarded unchanged to the normalization layer.

    Returns:
      The activated, normalized convolution output. The convolution uses
      'same' padding and He-normal weight initialization.
    """
    conv = hk.Conv2D(
        output_channels=num_filters,
        kernel_shape=kernel_size,
        stride=strides,
        padding='same',
        w_init=he_normal,
        with_bias=use_bias,
    )
    # Construct the norm module after the conv, as the original ordering
    # determines the Haiku module/parameter names.
    norm = normalization_layer()
    return activation(norm(conv(inputs), is_training=is_training))
def make_resnet_fn(
    num_classes: int,
    depth: int,
    normalization_layer,
    width: int = 16,
    use_bias: bool = True,
    activation=jax.nn.relu,
):
    """Build the forward function of a CIFAR-style ResNet (v1) in Haiku.

    Architecture:
      * a stem conv/norm/activation layer at ``width`` filters;
      * three stacks of ``(depth - 2) // 6`` residual blocks each. A block is
        two 3x3 conv/norm layers; only the first is followed by
        ``activation``. ``activation`` is applied again after the residual sum.
        The first block of stacks 2 and 3 uses stride 2 and a 1x1 projection
        on the shortcut. The filter count doubles after every stack;
      * an 8x8 average pool, flatten and a linear layer producing the logits.

    Args:
      num_classes: number of output logits.
      depth: total depth; must be of the form ``6n + 2`` (e.g. 20, 32, 44).
      normalization_layer: zero-argument callable returning a module that is
        called as ``layer(x, is_training=...)`` (e.g. ``FilterResponseNorm``).
      width: number of filters in the first stack.
      use_bias: whether the convolutions have a bias term.
      activation: nonlinearity used after the stem, after the first conv of
        each block, and after each residual sum.

    Returns:
      ``forward(x, is_training=True)``, intended to be wrapped by
      ``hk.transform``. ``is_training`` is forwarded to every normalization
      layer; its default keeps the previous always-training behaviour.

    Raises:
      ValueError: if ``depth`` is not of the form ``6n + 2``.
    """
    if (depth - 2) % 6 != 0:
        raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')
    num_res_blocks = (depth - 2) // 6

    def forward(x, is_training=True):
        num_filters = width
        x = _resnet_layer(
            x, num_filters=num_filters, activation=activation,
            use_bias=use_bias, is_training=is_training,
            normalization_layer=normalization_layer)
        for stack in range(3):
            for res_block in range(num_res_blocks):
                # The first block of every stack after the first halves the
                # spatial size and changes the channel count, so its shortcut
                # needs a strided 1x1 projection to match dimensions.
                downsample = stack > 0 and res_block == 0
                strides = 2 if downsample else 1
                y = _resnet_layer(
                    x, num_filters=num_filters, strides=strides,
                    activation=activation, use_bias=use_bias,
                    is_training=is_training,
                    normalization_layer=normalization_layer)
                # No activation here: it is applied after the residual sum.
                y = _resnet_layer(
                    y, num_filters=num_filters, use_bias=use_bias,
                    is_training=is_training,
                    normalization_layer=normalization_layer)
                if downsample:
                    x = _resnet_layer(
                        x, num_filters=num_filters, kernel_size=1,
                        strides=strides, use_bias=use_bias,
                        is_training=is_training,
                        normalization_layer=normalization_layer)
                x = activation(x + y)
            num_filters *= 2
        # NOTE: the fixed 8x8 window presumably assumes 32x32 inputs, which
        # two stride-2 'same' downsamplings reduce to 8x8. Smaller inputs give
        # a final map smaller than the window.
        x = hk.AvgPool((8, 8, 1), 8, 'VALID')(x)
        x = hk.Flatten()(x)
        return hk.Linear(num_classes, w_init=he_normal)(x)

    return forward
def make_resnet20_frn_fn(num_classes, activation=jax.nn.relu):
    """Return the forward function of a depth-20 ResNet using FRN layers.

    Thin wrapper over ``make_resnet_fn`` with ``depth=20`` and
    ``FilterResponseNorm`` as the normalization layer. The remaining
    ``make_resnet_fn`` arguments keep their defaults.
    """
    resnet_options = dict(
        depth=20,
        normalization_layer=FilterResponseNorm,
        activation=activation,
    )
    return make_resnet_fn(num_classes, **resnet_options)
def make_mlp_fn(output_dim, layer_dims, nonlinearity=jax.nn.elu):
    """Return the forward function of a fully connected network.

    The input is flattened with ``hk.Flatten``. One hidden ``hk.Linear``
    layer per entry of ``layer_dims`` follows, each followed by
    ``nonlinearity``. A final linear layer of size ``output_dim`` has no
    nonlinearity. Every linear layer starts its bias at the constant 0.05.
    """
    bias_init = Constant(0.05)

    def forward(inp):
        hidden = hk.Flatten()(inp)
        for units in layer_dims:
            hidden = nonlinearity(hk.Linear(units, b_init=bias_init)(hidden))
        return hk.Linear(output_dim, b_init=bias_init)(hidden)

    return forward
def make_cnn_fn(output_dim, width=4, nonlinearity=jax.nn.elu):
    """Return the forward function of a small convolutional classifier.

    Architecture, built in this order:
      * three stages of 5x5 'SAME' convolution -> ``nonlinearity`` ->
        3x3 max-pool (stride 2, 'VALID'), with 32, 64 and 128 channels, each
        multiplied by ``width``;
      * flatten;
      * two dense layers of 128 and 64 units (times ``width``), each followed
        by ``nonlinearity``;
      * a final dense layer of size ``output_dim``.

    Dense layers start their bias at the constant 0.05. Convolutions use
    Haiku's default initializers.
    """
    def forward(x):
        bias_init = Constant(0.05)
        stack = []
        for base_channels in (32, 64, 128):
            stack += [
                hk.Conv2D(output_channels=base_channels * width,
                          kernel_shape=5, padding="SAME"),
                nonlinearity,
                hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
            ]
        stack.append(hk.Flatten())
        for base_units in (128, 64):
            stack += [hk.Linear(base_units * width, b_init=bias_init),
                      nonlinearity]
        stack.append(hk.Linear(output_dim, b_init=bias_init))
        return hk.Sequential(stack)(x)

    return forward
def get_model(model_name, num_classes, **kwargs):
    """Build a named model and return its Haiku ``(apply, init)`` pair.

    Args:
      model_name: one of ``"resnet20"``, ``"cnn"``, ``"mlp"``, ``"mlptiny"``.
      num_classes: size of the output layer. It is passed as the first
        positional argument of the selected factory.
      **kwargs: forwarded to the selected factory. They override the preset
        keyword arguments below (``functools.partial`` semantics).

    Returns:
      ``(net.apply, net.init)`` for ``net = hk.transform(forward)``. Note the
      order: apply first, init second.

    Raises:
      NameError: if ``model_name`` is unknown. The type is kept as
        ``NameError`` for compatibility with existing callers.
    """
    model_fns = {
        # FilterResponseNorm ends in max(y, tau), which presumably supplies
        # the nonlinearity; the ResNet's own activation is the identity.
        "resnet20": functools.partial(
            make_resnet20_frn_fn, activation=lambda x: x),
        "cnn": functools.partial(
            make_cnn_fn, width=4, nonlinearity=jax.nn.elu),
        "mlp": functools.partial(
            make_mlp_fn, layer_dims=(1024, 512, 256, 256, 256),
            nonlinearity=jax.nn.tanh),
        "mlptiny": functools.partial(
            make_mlp_fn, layer_dims=(128, 64, 16, 16, 16),
            nonlinearity=jax.nn.tanh),
    }
    if model_name not in model_fns:
        raise NameError(
            f'Unknown model {model_name!r}. '
            f'Available keys: {sorted(model_fns)}')
    net_fn = model_fns[model_name](num_classes, **kwargs)
    net = hk.transform(net_fn)
    return net.apply, net.init