# discriminator.py (from a fork of pabloswfly/genomcmcgan)

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Symmetric is this repository's permutation-invariant pooling layer: from its
# use below, Symmetric(op, dim) appears to reduce tensor dimension `dim` with
# the named operation ("sum" here).
from symmetric import Symmetric


class Discriminator(nn.Module):
    def __init__(self):
        """Build the discriminator architecture."""
        super(Discriminator, self).__init__()
        # 1 because the input has a single channel in an (N, C, H, W) tensor
        self.batch1 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=(1, 5),
            stride=(1, 2),
            padding=(0, 2),
            bias=False,
        )
        self.batch2 = nn.BatchNorm2d(64)
        # Sum over dim 2 (the rows), making the network invariant to row order
        self.symm1 = Symmetric("sum", 2)
        self.conv2 = nn.Conv2d(
            in_channels=64,
            out_channels=128,
            kernel_size=(1, 5),
            stride=(1, 2),
            padding=(0, 2),
            bias=False,
        )
        # self.dropout1 = nn.Dropout2d(0.25)
        self.batch3 = nn.BatchNorm2d(128)
        # Sum over dim 3 (the columns), collapsing the site axis
        self.symm2 = Symmetric("sum", 3)
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
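
    # Assuming Symmetric reduces its axis as the name suggests, the two
    # poolings make the network invariant to the ordering of rows in the
    # genotype matrix, and the final 128-feature vector has a fixed size
    # regardless of the input width, so inputs with different numbers of
    # columns can share the same fully-connected head.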

    def forward(self, x):
        """Mark the flow of data through the network."""
        x = self.batch1(x)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.batch2(x)
        x = self.symm1(x)
        x = self.conv2(x)
        x = F.relu(x)
        # x = self.dropout1(x)
        x = self.batch3(x)
        x = self.symm2(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        # Squash the logit into a probability that the input is real
        output = torch.sigmoid(x)
        return output
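
    # Shape trace for an input of (N, 1, H, W), assuming Symmetric("sum", d)
    # collapses dimension d to size 1 (widths follow from the conv params):
    #   batch1/conv1: (N, 1, H, W)   -> (N, 64, H, ceil(W / 2))
    #   symm1:        (N, 64, H, .)  -> (N, 64, 1, .)
    #   conv2:        (N, 64, 1, .)  -> (N, 128, 1, ceil(W / 4))
    #   symm2:        (N, 128, 1, .) -> (N, 128, 1, 1)
    #   flatten/fc:   (N, 128) -> (N, 64) -> (N, 32) -> (N, 1)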

    def weights_init(self, m):
        """Reset parameters and initialize with random weight values"""
        if isinstance(m, nn.Conv2d):
            # The DCGAN paper initializes all weights from N(0, 0.02)
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)
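
    # Intended usage (nn.Module.apply calls the function on every submodule):
    #   model = Discriminator()
    #   model.apply(model.weights_init)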

    def get_accuracy(self, y_true, y_prob):
        """Compute model accuracy over labelled data"""
        y_true = y_true.squeeze()
        y_prob = y_prob.squeeze()
        # Threshold the probabilities at 0.5 to get hard 0/1 predictions
        y_pred = y_prob > 0.5
        return (y_true == y_pred).sum().item() / y_true.size(0)
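
    # For example, with y_true = [1, 0, 1] and y_prob = [0.9, 0.2, 0.4], the
    # thresholded predictions are [1, 0, 0] and the accuracy is 2/3.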

    def fit(self, *, trainflow, valflow, epochs, lr, device, model_selection=False):
        """Train the discriminator model with the binary cross-entropy loss.

        trainflow: PyTorch data loader for the training dataset
        valflow: PyTorch data loader for the validation dataset
        epochs: Number of iterations through the training dataset
        lr: Learning rate for gradient descent with Adam
        device: Device to run the computations on
        model_selection: If True, reload the weights with the lowest
            validation loss once training finishes
        """
        print("Training discriminator")
        optimizer = torch.optim.Adam(self.parameters(), lr)
        lossf = nn.BCELoss()
        # BCE loss is unbounded above, so start from infinity; this also
        # guarantees that the best_* variables are set on the first epoch
        best_val_loss = float("inf")
        # Loop over the dataset multiple times
        for epoch in range(epochs):
            self.train()
            train_loss, val_loss, acc_train, acc_val = 0.0, 0.0, 0.0, 0.0
            # For each batch of training data
            for i, (inputs, labels) in enumerate(trainflow, 1):
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Zero the parameter gradients
                optimizer.zero_grad()
                # Compute model predictions, compute loss and perform back-prop
                out = self(inputs)
                # BCELoss already reduces to a scalar mean by default
                loss = lossf(out, labels)
                loss.backward()
                optimizer.step()
                # Print statistics
                train_loss += loss.item()
                acc_train += self.get_accuracy(labels, out)
                if i % 20 == 0:  # print every 20 mini-batches
                    print(
                        "[%d | %d] TRAINING: loss: %.3f | acc: %.3f"
                        % (
                            epoch + 1,
                            i,
                            train_loss / i,
                            acc_train / i,
                        ),
                        end="\r",
                    )
            print("")
            # Calculate stats on validation data with no gradient updates;
            # eval mode also makes BatchNorm use its running statistics
            train_acc = acc_train / len(trainflow)
            self.eval()
            with torch.no_grad():
                # For each batch of validation data
                for j, (genmats, labels) in enumerate(valflow, 1):
                    genmats = genmats.to(device)
                    labels = labels.to(device)
                    # Compute model predictions, compute loss and stats
                    preds = self(genmats)
                    val_loss += lossf(preds, labels).item()
                    acc_val += self.get_accuracy(labels, preds)
                    print(
                        " VALIDATION: loss: %.3f - acc: %.3f"
                        % (
                            val_loss / j,
                            acc_val / j,
                        ),
                        end="\r",
                    )
            # Save the model weights with the lowest validation error
            if (val_loss / j) < best_val_loss:
                best_val_loss = val_loss / j
                best_train_acc = train_acc
                best_epoch = epoch + 1
                best_model = copy.deepcopy(self.state_dict())
            print("")
        # Load the model with the lowest validation error
        if model_selection:
            self.load_state_dict(best_model)
            print(
                f"Best model has validation loss {best_val_loss:.3f} "
                f"from epoch {best_epoch}"
            )
            train_acc = best_train_acc
        return train_acc

    def predict(self, inputs):
        """Compute model predictions over inputs"""
        self.eval()
        with torch.no_grad():
            preds = self(inputs)
        return preds
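

# A minimal smoke test: a sketch assuming the local `symmetric` module is on
# the path and that Symmetric("sum", dim) collapses the given axis. The shapes
# and hyperparameters below are illustrative, not taken from the original
# training pipeline.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Discriminator().to(device)
    model.apply(model.weights_init)

    # Fake genotype matrices: 32 samples, 1 channel, 16 haplotypes, 64 sites
    X = torch.randint(0, 2, (32, 1, 16, 64)).float()
    y = torch.randint(0, 2, (32, 1)).float()
    flow = DataLoader(TensorDataset(X, y), batch_size=8)

    train_acc = model.fit(
        trainflow=flow, valflow=flow, epochs=1, lr=1e-4, device=device
    )
    print(f"train accuracy: {train_acc:.3f}")
    print(model.predict(X[:4].to(device)))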