-
Notifications
You must be signed in to change notification settings - Fork 20
/
utils.py
59 lines (45 loc) · 1.64 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.autograd import Variable
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.transforms import Compose, ToTensor
# Shared preprocessing pipeline: a Compose wrapping a single ToTensor step.
transform_config = Compose(transforms=[ToTensor()])
def mse_loss(input, target):
    """Return the mean squared error between ``input`` and ``target``.

    The element-wise squared differences are summed and the total is divided
    by the number of elements in ``input``.
    """
    residual = input - target
    squared_total = (residual ** 2).sum()
    return squared_total / input.numel()
def l1_loss(input, target):
    """Return the mean absolute error between ``input`` and ``target``.

    The element-wise absolute differences are summed and the total is divided
    by the number of elements in ``input``.
    """
    abs_total = (input - target).abs().sum()
    return abs_total / input.numel()
def reparameterize(training, mu, logvar):
    """Sample around ``mu`` when training, otherwise return ``mu`` itself.

    When ``training`` is truthy the result is ``mu + eps * std`` where
    ``std = exp(0.5 * logvar)`` and ``eps`` is standard-normal noise with the
    same size, dtype and device as ``std``. When ``training`` is falsy the
    ``mu`` object is returned unchanged and no random numbers are drawn.

    NOTE(review): ``logvar`` is presumably the log-variance of a diagonal
    Gaussian (VAE-style reparameterization) -- confirm against callers.
    """
    if not training:
        return mu

    sigma = torch.exp(logvar * 0.5)
    # Fresh noise tensor, detached from autograd, filled in place with N(0, 1).
    noise = sigma.new_empty(sigma.size()).normal_()
    scaled = noise * sigma
    # In-place add on the temporary product; neither mu nor logvar is modified.
    return scaled.add_(mu)
def weights_init(layer):
    """Initialise the parameters of a single layer in place.

    * ``nn.Conv2d`` and ``nn.Linear``: weights drawn from N(0.0, 0.05),
      bias set to zero.
    * ``nn.BatchNorm2d``: weights drawn from N(1.0, 0.02), bias set to zero.
    * Any other layer type is left untouched.

    Parameters that are absent (``bias=False`` on a conv/linear layer, or
    ``affine=False`` on a batch-norm layer) are skipped instead of raising
    ``AttributeError`` on ``None``.

    The one-argument signature lets this be passed to ``Module.apply``.
    """
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        mean, std = 0.0, 0.05
    elif isinstance(layer, nn.BatchNorm2d):
        mean, std = 1.0, 0.02
    else:
        return

    # Weight first, then bias, matching the original initialisation order.
    if layer.weight is not None:
        nn.init.normal_(layer.weight, mean, std)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
def imshow_grid(images, shape=(2, 8), name='default', save=False):
    """
    Plot images in a grid of a given shape.
    Initial code from: https://github.com/pumpikano/tf-dann/blob/master/utils.py

    Args:
        images: indexable, sized collection of images accepted by
            ``Axes.imshow``. If it holds fewer images than the grid has
            cells, the remaining cells are left blank.
        shape: ``(rows, cols)`` of the grid; a list such as ``[2, 8]`` is
            still accepted.
        name: stem of the output file, used only when ``save`` is true.
        save: when true, write ``reconstructed_images/<name>.png`` (creating
            the directory if needed) and clear the figure; otherwise show
            the figure with ``plt.show()``.
    """
    # Figure number 1 is reused across calls rather than opening new windows.
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=tuple(shape), axes_pad=0.05)

    n_cells = shape[0] * shape[1]
    n_images = len(images)
    for idx in range(n_cells):
        axis = grid[idx]  # The ImageGrid object works as a list of axes.
        axis.axis('off')
        # Guard against a batch smaller than the grid instead of raising
        # IndexError part-way through drawing.
        if idx < n_images:
            axis.imshow(images[idx])

    if save:
        import os  # local import: only needed on the save path

        out_path = 'reconstructed_images/' + str(name) + '.png'
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path)
        plt.clf()
    else:
        plt.show()