-
Notifications
You must be signed in to change notification settings - Fork 13
/
BigGAN.py
executable file
·371 lines (257 loc) · 12.1 KB
/
BigGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
import time
from tensorflow.contrib.opt import MovingAverageOptimizer
from ops import *
from utils import *
import logging
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class BigGAN(object):
    """BigGAN generator/discriminator plus Estimator and TPUEstimator model functions.

    The instance holds no configuration of its own: every hyper-parameter is
    read from the ``params`` mapping handed to the model functions.
    """

    def __init__(self, args):
        # ``args`` is currently unused; configuration arrives through ``params``.
        pass

    ##################################################################################
    # Generator
    ##################################################################################

    @staticmethod
    def _latent_split_sizes(z_dim):
        """Return the widths used to split a ``z_dim``-wide latent into six chunks.

        Five equal chunks are followed by one chunk holding everything left
        over, so the widths always sum to ``z_dim``. A 128-dim latent uses
        20-wide chunks (the last one 28 wide) instead of the generic
        ``z_dim // 6``.
        """
        split_dim = 20 if z_dim == 128 else z_dim // 6
        # The last chunk absorbs the remainder. Computing the remainder against
        # six chunks while emitting only five left the widths short of z_dim
        # whenever z_dim % 6 != 0, which made tf.split fail.
        return [split_dim] * 5 + [z_dim - split_dim * 5]

    def generator(self, params, z, labels, is_training=True, reuse=False, getter=None):
        """Build the generator graph.

        Args:
            params: hyper-parameter mapping. Keys read here: 'use_tpu', 'z_dim',
                'ch', 'sn', 'layers', 'use_label_cond', 'self_attn_res',
                'img_ch', 'img_size'.
            z: latent tensor, split along its last axis into six chunks.
            labels: conditioning tensor, concatenated onto each latent chunk
                when params['use_label_cond'] is true.
            is_training: forwarded to the conditional res-blocks and batch norm.
            reuse: reuse the existing variables of the "generator" scope.
            getter: optional custom variable getter (e.g. to read EMA weights).

        Returns:
            tanh-activated images of shape [batch, img_size, img_size, img_ch]
            (enforced by the assertions at the end).
        """
        logger.debug("generator")
        cross_device = params['use_tpu']

        with tf.variable_scope("generator", reuse=reuse, custom_getter=getter):
            z_split = tf.split(z, num_or_size_splits=self._latent_split_sizes(params['z_dim']), axis=-1)

            ch = 16 * params['ch']
            sn = params['sn']

            x = fully_connected(z_split[0], units=4 * 4 * ch, sn=sn, scope='dense')
            x = tf.reshape(x, shape=[-1, 4, 4, ch])

            # NOTE(review): z_split always has six entries, so params['layers'] > 6
            # raises IndexError below; chunk 0 also feeds the dense layer above.
            for i in range(params['layers']):
                x_size = x.shape[-2]
                if params['use_label_cond']:
                    cond = tf.concat([z_split[i], labels], axis=-1)
                else:
                    cond = z_split[i]

                x = resblock_up_condition(
                    x, cond, channels=ch, use_bias=False, is_training=is_training,
                    cross_device=cross_device, sn=sn,
                    scope=f"resblock_up_w{x_size}_ch{ch//params['ch']}")

                # Self-attention is inserted at the resolutions listed in params.
                x_size = x.shape[-2]
                if x_size in params['self_attn_res']:
                    x = self_attention_2(x, channels=ch, sn=sn, scope=f"self_attention_w{x_size}_ch{ch//params['ch']}")

                ch = ch // 2

            x = batch_norm(x, is_training, cross_device=cross_device)
            x = relu(x)
            x = conv(x, channels=params['img_ch'], kernel=3, stride=1, pad=1, use_bias=False, sn=sn, scope='G_logit')
            x = tanh(x)

            # Crop down to the expected size if the architecture produced spare
            # pixels. Axis 1 is height and axis 2 is width (channels-last).
            if x.shape[1] > params['img_size']:
                logger.warning(f"Cropping off {x.shape[1] - params['img_size']} pixels from height of generated images")
                x = x[:, :params['img_size'], :, :]
            if x.shape[2] > params['img_size']:
                logger.warning(f"Cropping off {x.shape[2] - params['img_size']} pixels from width of generated images")
                x = x[:, :, :params['img_size'], :]

            assert x.shape[1] == params['img_size'], "Generator architecture does not fit image size"
            assert x.shape[2] == params['img_size'], "Generator architecture does not fit image size"
            assert x.shape[3] == params['img_ch'], "Generator architecture does not fit image channels"

        logger.debug("--")
        return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, params, x, label, is_training=True, reuse=False):
        """Build the projection discriminator graph.

        Args:
            params: hyper-parameter mapping ('ch', 'sn', 'layers', 'self_attn_res').
            x: image batch.
            label: conditioning tensor projected against the pooled features.
            is_training: forwarded to the res-blocks.
            reuse: reuse the existing variables of the "discriminator" scope.

        Returns:
            One logit per example, shape [batch, 1] (assuming ``fully_connected``
            returns [batch, units]).
        """
        logger.debug("discriminator")
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = params['ch']
            sn = params['sn']

            for _ in range(params['layers']):
                x_size = x.shape[-2]
                x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=sn, scope=f"resblock_down_w{x_size}_ch{ch//params['ch']}")

                x_size = x.shape[-2]
                if x_size in params['self_attn_res']:
                    x = self_attention_2(x, channels=ch, sn=sn, scope=f"self_attention_w{x_size}_ch{ch//params['ch']}")

                ch = ch * 2

            # The loop doubled ch once past the last block's width; step back.
            ch = ch // 2

            x_size = x.shape[-2]
            x = resblock(x, channels=ch, use_bias=False, is_training=is_training, sn=sn, scope=f"resblock_w{x_size}_ch{ch//params['ch']}")
            x = relu(x)
            x = global_sum_pooling(x)

            # Projection term: per-example inner product of the label embedding
            # with the pooled features. keepdims=True keeps it [batch, 1] so it
            # adds element-wise to x_scalar instead of broadcasting the
            # [batch, 1] + [batch] sum into a [batch, batch] matrix.
            label_embed = fully_connected(label, units=x.shape[-1], sn=sn, scope='D_label_embed')
            label_proj = tf.reduce_sum(x * label_embed, axis=-1, keepdims=True)
            x_scalar = fully_connected(x, units=1, sn=sn, scope='D_scalar')
            output = x_scalar + label_proj

        logger.debug("--")
        return output

    def gradient_penalty(self, real, fake, params=None, labels=None):
        """Gradient penalty for WGAN-LP, WGAN-GP and DRAGAN.

        Args:
            real: batch of real images (4-D, channels-last).
            fake: batch of generated images; replaced by perturbed real images
                for DRAGAN variants.
            params: model hyper-parameters. Must supply 'gan_type' and 'ld'
                (the penalty weight), plus the keys ``discriminator`` reads.
                NOTE(review): confirm the params mapping carries 'ld'.
            labels: conditioning labels forwarded to the discriminator.

        Returns:
            Scalar penalty tensor, or 0 for GAN types without a penalty.

        Raises:
            ValueError: if ``params`` is not given. The instance stores no
                configuration to fall back on.
        """
        if params is None:
            raise ValueError("gradient_penalty requires the model params (gan_type, ld, discriminator settings)")

        gan_type = params['gan_type']

        if 'dragan' in gan_type:
            # DRAGAN penalises around perturbed real data; the noise magnitude
            # (scaled by the data std) decides the size of the local region.
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)
            fake = real + 0.5 * x_std * eps

        # One interpolation coefficient per example, broadcast over H, W, C.
        alpha = tf.random_uniform(shape=[tf.shape(real)[0], 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        logit = self.discriminator(params, interpolated, labels, reuse=True)
        grad = tf.gradients(logit, interpolated)[0]  # dD/d(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)   # per-example L2 norm

        if gan_type == 'wgan-lp':
            # One-sided penalty: only gradient norms above 1 are penalised.
            return params['ld'] * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
        if gan_type in ('wgan-gp', 'dragan'):
            return params['ld'] * tf.reduce_mean(tf.square(grad_norm - 1.))
        return 0

    ##################################################################################
    # Model
    ##################################################################################

    @staticmethod
    def _make_optimizer(params, learning_rate):
        """Adam optimizer, wrapped in CrossShardOptimizer when running on TPU."""
        optimizer = tf.train.AdamOptimizer(learning_rate, beta1=params.beta1, beta2=params.beta2)
        if params.use_tpu:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
        return optimizer

    def base_model_fn(self, features, labels, mode, params):
        '''
        All the model function heavy lifting is done here, agnostic of whether
        it'll be used in an Estimator or TPUEstimator.

        Returns:
            (loss, train_op, predictions, metric_fn, metric_fn_args). The
            wrappers gpu_model_fn / tpu_model_fn turn these into specs.
        '''
        params = EasyDict(**params)

        # --------------------------------------------------------------------------
        # Core GAN model
        # --------------------------------------------------------------------------

        # Because we cannot pass in labels in predict mode (despite them being useful
        # for GANs), the labels are passed in as the (otherwise unneeded) features.
        # It's a bit of a hack, sorry.
        if mode == tf.estimator.ModeKeys.PREDICT:
            labels = features

        # Latent input to generate images
        if mode == tf.estimator.ModeKeys.TRAIN:
            z = tf.random.normal(shape=[params.batch_size, params.z_dim], name='random_z')
        else:
            # The "truncated normal" trick to make generated predictions nicer looking
            z = tf.random.truncated_normal(shape=[params.batch_size, params.z_dim], name='random_z')

        # Generate and critique fake images
        fake_images = self.generator(params, z, labels)
        fake_logits = self.discriminator(params, fake_images, labels)
        g_loss = generator_loss(params.gan_type, fake=fake_logits)

        # Discriminator loss (needs real images, so not available in PREDICT)
        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            real_logits = self.discriminator(params, features, labels, reuse=True)
            if 'wgan' in params.gan_type or params.gan_type == 'dragan':
                GP = self.gradient_penalty(real=features, fake=fake_images, params=params, labels=labels)
            else:
                GP = 0
            d_loss = discriminator_loss(params.gan_type, real=real_logits, fake=fake_logits) + GP
        else:
            d_loss = 0

        # --------------------------------------------------------------------------
        # Vars for training and evaluation
        # --------------------------------------------------------------------------

        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # --------------------------------------------------------------------------
        # Averaging var values can help with eval/prediction
        # http://ruishu.io/2017/11/22/ema/
        # --------------------------------------------------------------------------

        ema = tf.train.ExponentialMovingAverage(decay=params.moving_decay)

        def ema_getter(getter, name, *args, **kwargs):
            # Substitute the moving-average shadow variable when one exists.
            var = getter(name, *args, **kwargs)
            ema_var = ema.average(var)
            return ema_var if ema_var is not None else var

        # --------------------------------------------------------------------------
        # Loss (reported value only; each optimizer minimises its own loss)
        # --------------------------------------------------------------------------

        if mode != tf.estimator.ModeKeys.PREDICT:
            loss = g_loss + params.n_critic * d_loss
        else:
            loss = 0

        # --------------------------------------------------------------------------
        # Training op
        # --------------------------------------------------------------------------

        if mode == tf.estimator.ModeKeys.TRAIN:
            # D is built before G, as before, so optimizer variable names are unchanged.
            d_optimizer = self._make_optimizer(params, params.d_lr)
            d_train_op = d_optimizer.minimize(d_loss, var_list=d_vars)

            # Only the G update advances global_step. Passing it to both
            # minimize() calls advanced it by 2 per training step.
            g_optimizer = self._make_optimizer(params, params.g_lr)
            g_train_op = g_optimizer.minimize(g_loss, var_list=g_vars, global_step=tf.train.get_global_step())

            # NOTE(review): grouping the same D op n_critic times runs it only
            # once per step, so D and G are each updated once per step here;
            # n_critic currently only scales the reported loss.
            train_op = tf.group(g_train_op, d_train_op)

            with tf.control_dependencies([train_op]):
                # Create the shadow variables and an op that refreshes the
                # moving averages after each training step; it replaces the
                # usual training op.
                train_op = ema.apply(g_vars)
        else:
            # Build the shadow variables in EVAL/PREDICT too, so they are
            # restored from the checkpoint and ema_getter can find them.
            # Without this, ema.average() returns None and the raw weights
            # would be used. The returned update op is deliberately not run.
            ema.apply(g_vars)
            train_op = None

        # --------------------------------------------------------------------------
        # Predictions
        # --------------------------------------------------------------------------

        # NOTE(review): is_training defaults to True here, so batch norm runs in
        # training mode for predictions; confirm this is intended.
        predict_fake_images = self.generator(params, z, labels, reuse=True, getter=ema_getter)

        predictions = {
            "fake_image": predict_fake_images,
            "labels": labels,
        }

        # --------------------------------------------------------------------------
        # Eval metrics
        # --------------------------------------------------------------------------

        if mode == tf.estimator.ModeKeys.EVAL:
            # Hack to get scalars out of a fixed-batch-size TPU: tile to [batch_size].
            d_loss_batched = tf.tile(tf.expand_dims(d_loss, 0), [params.batch_size])
            g_loss_batched = tf.tile(tf.expand_dims(g_loss, 0), [params.batch_size])

            # Flatten every gradient into one vector; variables with no
            # gradient (None) are skipped rather than crashing tf.reshape.
            d_grad_joined = tf.concat(
                [tf.reshape(g, [-1]) for g in tf.gradients(d_loss, d_vars) if g is not None], axis=-1)
            g_grad_joined = tf.concat(
                [tf.reshape(g, [-1]) for g in tf.gradients(g_loss, g_vars) if g is not None], axis=-1)

            def metric_fn(d_loss, g_loss, fake_logits, d_grad, g_grad):
                return {
                    "d_loss": tf.metrics.mean(d_loss),
                    "g_loss": tf.metrics.mean(g_loss),
                    "fake_logits": tf.metrics.mean(fake_logits),
                    "d_grad": tf.metrics.mean(d_grad),
                    "g_grad": tf.metrics.mean(g_grad),
                }

            metric_fn_args = [d_loss_batched, g_loss_batched, fake_logits, d_grad_joined, g_grad_joined]
        else:
            metric_fn = None
            metric_fn_args = None

        # --------------------------------------------------------------------------
        # Alright, all built!
        # --------------------------------------------------------------------------

        return loss, train_op, predictions, metric_fn, metric_fn_args

    def gpu_model_fn(self, features, labels, mode, params):
        """Estimator model_fn: wraps base_model_fn's outputs in an EstimatorSpec."""
        loss, train_op, predictions, metric_fn, metric_fn_args = self.base_model_fn(features, labels, mode, params)

        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)

        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metric_ops=metric_fn(*metric_fn_args),
            )

        if mode == tf.estimator.ModeKeys.TRAIN:
            return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    def tpu_model_fn(self, features, labels, mode, params):
        """TPUEstimator model_fn: wraps base_model_fn's outputs in a TPUEstimatorSpec."""
        loss, train_op, predictions, metric_fn, metric_fn_args = self.base_model_fn(features, labels, mode, params)

        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)

        if mode == tf.estimator.ModeKeys.EVAL:
            # TPUEstimator takes (metric_fn, tensors) and runs metric_fn on the host.
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(metric_fn, metric_fn_args),
            )

        if mode == tf.estimator.ModeKeys.TRAIN:
            return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)