-
Notifications
You must be signed in to change notification settings - Fork 108
/
Copy pathfcgan-tutorial.py
141 lines (97 loc) · 4.04 KB
/
fcgan-tutorial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import cifar10, mnist
from keras.layers import (BatchNormalization, Conv2D, Conv2DTranspose, Dense,
Dropout, Flatten, Input, Reshape, UpSampling2D,
ZeroPadding2D)
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model, Sequential
from keras.optimizers import Adam
from PIL import Image, ImageDraw
# Seed NumPy's global RNG so the noise vectors and the real-batch sampling
# below repeat between runs.
# NOTE(review): only NumPy is seeded here; the Keras backend's own RNG (e.g.
# weight initialisation) is presumably not covered -- confirm if fully
# reproducible runs are required.
np.random.seed(10)

# The dimension of z (the noise vector fed to the generator)
noise_dim = 100

batch_size = 16
steps_per_epoch = 3750  # 60000 / 16
epochs = 10

# Directory that receives the per-epoch image grids
save_path = 'fcgan-images'

img_rows, img_cols, channels = 28, 28, 1

# Single Adam instance (0.0002, 0.5) referenced by every compile() call in
# this script.
optimizer = Adam(0.0002, 0.5)

# Create path for saving images. makedirs(exist_ok=True) avoids the
# check-then-create race and also works when save_path is nested.
os.makedirs(save_path, exist_ok=True)

# Load and pre-process data
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize to between -1 and 1, then flatten every image to one vector of
# img_rows*img_cols*channels values.
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = x_train.reshape(-1, img_rows*img_cols*channels)
def create_generator():
    """Build and compile the fully connected generator.

    The network takes a vector of length ``noise_dim`` and passes it through
    Dense layers of 256, 512 and 1024 units, each followed by
    ``LeakyReLU(0.2)``, ending in a tanh-activated Dense layer with
    ``img_rows*img_cols*channels`` outputs.

    Returns:
        The Sequential model, compiled with binary cross-entropy loss and the
        module-level ``optimizer``.
    """
    # Input layer first (it carries input_dim), then the remaining hidden
    # layers, each paired with its activation.
    stack = [Dense(256, input_dim=noise_dim), LeakyReLU(0.2)]
    for units in (512, 1024):
        stack.extend((Dense(units), LeakyReLU(0.2)))
    stack.append(Dense(img_rows*img_cols*channels, activation='tanh'))

    model = Sequential()
    for layer in stack:
        model.add(layer)

    model.compile(loss='binary_crossentropy', optimizer=optimizer)
    return model
def create_descriminator():
    """Build and compile the fully connected discriminator.

    The network takes a flat vector of ``img_rows*img_cols*channels`` values
    and passes it through Dense layers of 1024, 512 and 256 units, each
    followed by ``LeakyReLU(0.2)``, ending in a single sigmoid unit.

    Returns:
        The Sequential model, compiled with binary cross-entropy loss and the
        module-level ``optimizer``.
    """
    # Input layer first (it carries input_dim), then the remaining hidden
    # layers, each paired with its activation.
    stack = [Dense(1024, input_dim=img_rows*img_cols*channels), LeakyReLU(0.2)]
    for units in (512, 256):
        stack.extend((Dense(units), LeakyReLU(0.2)))
    stack.append(Dense(1, activation='sigmoid'))

    model = Sequential()
    for layer in stack:
        model.add(layer)

    model.compile(loss='binary_crossentropy', optimizer=optimizer)
    return model
# Build both networks. Each factory compiles its model before returning it,
# so the discriminator is already compiled at this point.
discriminator = create_descriminator()
generator = create_generator()

# Make the discriminator untrainable when we are training the generator.
# This flag is set after the discriminator's own compile() call and before
# the combined model below is compiled; statement order matters here.
# NOTE(review): this relies on Keras fixing the trainable state at compile
# time, so that it doesn't affect the discriminator when trained by itself --
# confirm for the Keras version in use.
discriminator.trainable = False

# Link the two models to create the GAN: noise -> generator -> discriminator
gan_input = Input(shape=(noise_dim,))
fake_image = generator(gan_input)
gan_output = discriminator(fake_image)

# The combined model maps a noise vector to the discriminator's verdict on
# the image generated from it.
gan = Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
# Display images, and save them if the epoch number is specified
def show_images(noise, epoch=None):
    """Plot the generator's output for ``noise`` as a grid of images.

    Args:
        noise: Array of noise vectors passed to ``generator.predict``; one
            image is drawn per generated sample.
        epoch: When not None, the figure is also written to
            ``{save_path}/gan-images_epoch-{epoch}.png``.
    """
    generated_images = generator.predict(noise)

    # Keep the 10x10 layout for up to 100 images, but grow the grid rather
    # than letting plt.subplot fail when more samples are supplied.
    grid = max(10, int(np.ceil(np.sqrt(len(generated_images)))))

    plt.figure(figsize=(10, 10))

    for i, image in enumerate(generated_images):
        plt.subplot(grid, grid, i+1)
        if channels == 1:
            # 2-D data: imshow maps the value range onto the gray colormap.
            plt.imshow(image.reshape((img_rows, img_cols)), cmap='gray')
        else:
            # The generator ends in tanh, so values lie in [-1, 1]; imshow
            # expects float colour images in [0, 1], so rescale first.
            rescaled = (image.reshape((img_rows, img_cols, channels)) + 1) / 2
            plt.imshow(np.clip(rescaled, 0, 1))
        plt.axis('off')

    plt.tight_layout()

    # Identity comparison with None; epoch 0 is a valid value and is saved.
    if epoch is not None:
        plt.savefig(f'{save_path}/gan-images_epoch-{epoch}.png')
# Constant noise for viewing how the GAN progresses
static_noise = np.random.normal(0, 1, size=(100, noise_dim))

for epoch in range(epochs):
    for batch in range(steps_per_epoch):
        # Fake half of the discriminator batch: images from fresh noise.
        noise = np.random.normal(0, 1, size=(batch_size, noise_dim))
        fake_x = generator.predict(noise)

        # Real half: batch_size training images drawn at random indices
        # (with replacement).
        real_x = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

        x = np.concatenate((real_x, fake_x))

        # Targets follow the concatenation order: the real half gets 0.9
        # rather than 1.0, the fake half gets 0.
        disc_y = np.zeros(2*batch_size)
        disc_y[:batch_size] = 0.9

        d_loss = discriminator.train_on_batch(x, disc_y)

        # Generator step through the combined model: the same noise is
        # labelled 1 so the generator is pushed towards "real" verdicts.
        y_gen = np.ones(batch_size)
        g_loss = gan.train_on_batch(noise, y_gen)

    # Losses of the last batch of the epoch
    print(f'Epoch: {epoch} \t Discriminator Loss: {d_loss} \t\t Generator Loss: {g_loss}')

    # NOTE(review): with epochs = 10 only epoch 0 satisfies this condition,
    # so a single snapshot is written -- confirm the interval is intended.
    if epoch % 10 == 0:
        show_images(static_noise, epoch)

# Persist the trained models first so a problem while assembling the GIF
# cannot cost the training results.
discriminator.save('fcdiscriminator.hdf5')
generator.save('fcgenerator.hdf5')

# Build the progress GIF from the saved snapshots. os.listdir gives no
# ordering guarantee and may return unrelated files, so keep only snapshot
# files and order them by their epoch number.
frame_prefix, frame_suffix = 'gan-images_epoch-', '.png'


def _frame_epoch(filename):
    """Return the epoch number encoded in a snapshot filename, else None."""
    if not (filename.startswith(frame_prefix) and filename.endswith(frame_suffix)):
        return None
    digits = filename[len(frame_prefix):-len(frame_suffix)]
    return int(digits) if digits.isdecimal() else None


frame_files = sorted(
    (name for name in os.listdir(save_path) if _frame_epoch(name) is not None),
    key=_frame_epoch,
)

frames = []
try:
    for name in frame_files:
        frames.append(Image.open(os.path.join(save_path, name)))

    if frames:
        frames[0].save('gan_training.gif', format='GIF', append_images=frames[1:], save_all=True, duration=500, loop=0)
    else:
        # Nothing to animate; avoid an IndexError on frames[0].
        print(f'No snapshot images found in {save_path}; skipping gan_training.gif')
finally:
    # Release the file handles held by the opened images.
    for frame in frames:
        frame.close()