-
Notifications
You must be signed in to change notification settings - Fork 0
/
gan.py
197 lines (154 loc) · 7.21 KB
/
gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import os
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm # this is used for progress bars
import numpy as np
import torch.nn as nn
import torch
import argparse
import imageio
import time
# Command-line configuration for the GAN run.  NOTE: parse_args() executes at
# import time, so importing this module as a library would consume sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", default=100, type=int, help="Number of epochs to run the network for")
parser.add_argument("--latent_dim", default=64, type=int, help="The size of the noise vector to pass in the Generator")
parser.add_argument("--batch_size", default=100, type=int, help="Size of mini batch that is fed into one step of network training")
parser.add_argument("--img_size", default=28, type=int, help="Size of generated images")
parser.add_argument("--digit", default=0, type=int, help="MNIST digit to be generated")
# store_false flags: gif/samples default to True; passing --no_gif / --no_samples disables them.
parser.add_argument("--no_gif", default=True, action='store_false', dest='gif', help="Whether to make a gif over number of epochs or not")
parser.add_argument("--no_samples", default=True, action='store_false', dest='samples', help="Whether to generate a sample from each batch from a training epoch")
options = parser.parse_args()
print("Selected arguments to run: ")
print(options)
# Prefer the first CUDA device when available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: ", device)
class Generator(nn.Module):
    """
    Generator network of the GAN: maps a latent noise vector to an image.

    Architecture: a fully-connected stack
    (latent_dim -> 128 -> 256 -> 512 -> img_size*img_size) with Sigmoid
    activations; the final Sigmoid keeps pixel values in [0, 1].
    Output shape is (N, img_size, img_size) for a batch of N noise vectors.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(options.latent_dim, 128),
            nn.Sigmoid(),
            nn.Linear(128, 256),
            nn.Sigmoid(),
            nn.Linear(256, 512),
            nn.Sigmoid(),
            nn.Linear(512, options.img_size * options.img_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Generate images from noise ``x`` of shape (N, latent_dim)."""
        img = self.model(x)
        # Bug fix: reshape by the input's actual batch dimension instead of
        # the hard-coded options.batch_size, so the generator also works for
        # batches of any other size (e.g. a single sample at inference time).
        img = torch.reshape(img, (x.shape[0], options.img_size, options.img_size))
        return img
class Discriminator(nn.Module):
    """
    Binary classifier that scores images as real (1) or fake (0).

    Input images are flattened, then passed through a fully-connected stack
    (img_size*img_size -> 512 -> 256 -> 128 -> 1) with a Sigmoid after every
    linear layer; the final Sigmoid yields a probability in [0, 1].
    Output is (N, 1) for a batch of N images.
    """

    def __init__(self):
        super().__init__()
        widths = [options.img_size * options.img_size, 512, 256, 128, 1]
        layers = [nn.Flatten()]
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the per-image probability (N x 1) that ``img`` is real."""
        return self.model(img)
def dataLoader(images):
    """
    Load all MNIST images for the selected digit into ``images`` (in place).

    Walks data/mnist/trainingSet/trainingSet/{options.digit} and appends each
    file, read as a grayscale (single-channel) array, to the list.
    """
    for path, dirs, filenames in os.walk(f"data/mnist/trainingSet/trainingSet/{options.digit}"):
        for filename in tqdm(filenames):
            # Bug fix: interpolate the actual file name into the path; the
            # original read a literal, non-existent filename, so every
            # cv2.imread returned None.
            img = cv2.imread(f"{path}/{filename}", 0)
            images.append(img)
def gifMaker():
    """
    Stitch the per-epoch sample images into gifs/{options.digit}.gif at 24 fps.

    Only runs the stitching when per-epoch sampling was enabled, since the
    frames are read from samples/{options.digit}/sample_epoch_{epoch}.jpg.
    """
    if not os.path.isdir("gifs"):
        os.mkdir("gifs")
    if options.samples:
        gif = []
        path = f"samples/{options.digit}"
        for epoch in range(options.epochs):
            # Bug fix: interpolate the epoch number into the frame filename;
            # the original looked for a literal, non-existent file name.
            gif.append(imageio.imread(f"{path}/sample_epoch_{epoch}.jpg"))
        imageio.mimsave(f"gifs/{options.digit}.gif", gif, fps=24)
def writeImages(batch, epoch):
    """
    Persist one generated sample for this epoch as a visual sanity check.

    Scales ``batch`` from [0, 1] to [0, 255] and writes its first image to
    samples/{options.digit}/sample_epoch_{epoch}.jpg when sampling is enabled.
    """
    if not options.samples:
        return
    pixels = batch.cpu().detach().numpy() * 255
    # writing this for sanity check
    cv2.imwrite(f"./samples/{options.digit}/sample_epoch_{epoch}.jpg", pixels[0])
def train(generator, discriminator, images):
    """
    Run the adversarial training loop for options.epochs epochs.

    Each epoch draws one mini-batch of real images (sampled with replacement
    from ``images``) plus one batch of noise, takes a discriminator step on
    the real/fake pair, then takes a generator step against fresh noise.

    Args:
        generator: Generator module (already moved to ``device``).
        discriminator: Discriminator module (already moved to ``device``).
        images: numpy array of real training images, first axis is samples.
    """
    # Optimizers (fixed typo: "generater" -> "generator").
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
    # Binary cross-entropy: real -> 1, fake -> 0.
    adv_loss = nn.BCELoss()
    # Total number of real data points available for sampling.
    N = images.shape[0]
    for epoch in range(options.epochs):
        # Noise input for the generator, size (batch_size x latent_dim).
        z_mini_batch = torch.Tensor(np.random.normal(0, 1, size=(options.batch_size, options.latent_dim))).to(device)
        # Sample a real mini-batch with replacement (fancy indexing replaces
        # the original per-element list comprehension; same result, faster).
        random_choice = np.random.choice(np.arange(N), size=options.batch_size)
        real_images = torch.Tensor(images[random_choice]).to(device)
        # Target labels, each of size (batch_size x 1).
        real_data_labels = torch.ones((options.batch_size, 1), device=device)
        fake_data_labels = torch.zeros((options.batch_size, 1), device=device)

        # ---- Discriminator step ----
        discriminator_optimizer.zero_grad()
        fake_images = generator(z_mini_batch)
        real_loss = adv_loss(discriminator(real_images), real_data_labels)
        # Fix: detach the fake images so the discriminator's backward pass
        # does not needlessly compute generator gradients (the original
        # computed them and then discarded them via zero_grad below).
        fake_loss = adv_loss(discriminator(fake_images.detach()), fake_data_labels)
        discriminator_loss = (real_loss + fake_loss) / 2
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # ---- Generator step ----
        generator_optimizer.zero_grad()
        z_mini_batch = torch.Tensor(np.random.normal(0, 1, size=(options.batch_size, options.latent_dim))).to(device)
        fake_images = generator(z_mini_batch)
        # Real labels here on purpose: the generator is rewarded when the
        # discriminator classifies its output as real.
        generator_loss = adv_loss(discriminator(fake_images), real_data_labels)
        generator_loss.backward()
        generator_optimizer.step()

        # .item() prints the scalar loss values instead of tensor reprs.
        print(f"Epoch: {epoch + 1} / {options.epochs}, Generator Loss: {generator_loss.item()} , Discriminator Loss: {discriminator_loss.item()}")
        writeImages(fake_images, epoch)
def main():
    """Entry point: load data, build both networks, train, optionally gif."""
    images = []
    # Read the training images for the chosen digit into the list in place.
    dataLoader(images)
    images = np.array(images)
    print("Images loaded", images.shape)
    generator = Generator()
    discriminator = Discriminator()
    # let's go faster
    generator.to(device)
    discriminator.to(device)
    # Make sure the output directories exist before training writes samples.
    if options.samples:
        for directory in ("samples", f"samples/{options.digit}"):
            if not os.path.isdir(directory):
                os.mkdir(directory)
    train(generator, discriminator, images)
    if options.gif:
        gifMaker()


if __name__ == "__main__":
    main()