-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathbatch_norm_layer.py
92 lines (79 loc) · 3.13 KB
/
batch_norm_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn as nn
class BatchNormLayer(nn.Module):
    """Invertible batch-normalization layer for normalizing flows.

    While training, each call normalizes with the statistics of the current
    batch.  In eval mode the statistics are frozen from the first batch seen
    (cached by ``set_batch_stats_func``) so repeated eval passes are
    deterministic.  ``forward`` and ``backward`` both return the transformed
    tensor together with the log-determinant of the Jacobian.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # gamma is a log-scale: exp(gamma) == 1 at init, so the layer starts
        # out as a plain (x - mean) / std with no extra scaling or shift.
        self.gamma = nn.Parameter(torch.zeros(1, dim))
        self.beta = nn.Parameter(torch.zeros(1, dim))
        # Cached eval-mode statistics; None means "not captured yet".
        self.batch_mean = None
        self.batch_var = None

    def forward(self, x):
        """Map data -> normalized space; returns (x_hat, log_det)."""
        if self.training:
            mu = x.mean(dim=0)
            var = x.var(dim=0) + self.eps
            # Invalidate the cached eval stats so the next eval pass
            # re-captures them from its first batch.
            self.batch_mean = None
        else:
            if self.batch_mean is None:
                self.set_batch_stats_func(x)
            mu = self.batch_mean.clone()
            var = self.batch_var.clone()
        out = (x - mu) / var.sqrt()
        out = out * self.gamma.exp() + self.beta
        log_det = (self.gamma - 0.5 * var.log()).sum()
        return out, log_det

    def backward(self, x):
        """Inverse map: undo the affine part, then de-normalize."""
        if self.training:
            mu = x.mean(dim=0)
            var = x.var(dim=0) + self.eps
            self.batch_mean = None
        else:
            if self.batch_mean is None:
                self.set_batch_stats_func(x)
            mu = self.batch_mean
            var = self.batch_var
        out = (x - self.beta) * (-self.gamma).exp() * var.sqrt() + mu
        log_det = (-self.gamma + 0.5 * var.log()).sum()
        return out, log_det

    def set_batch_stats_func(self, x):
        """Capture mean/var of *x* to serve as the frozen eval statistics."""
        print("setting batch stats for validation")
        self.batch_mean = x.mean(dim=0)
        self.batch_var = x.var(dim=0) + self.eps
class BatchNorm_running(nn.Module):
    """Invertible batch-normalization with running statistics.

    Training normalizes with per-batch mean/var and maintains exponential
    running averages; eval mode normalizes with the running averages.
    ``forward`` and ``backward`` return (transformed_x, log_det_jacobian).

    Fixes vs. the previous version:
      * ``running_mean``/``running_var`` are registered buffers, so they
        follow ``.to(device)`` / ``.cuda()`` and are saved in ``state_dict``
        (plain tensors did neither).
      * Running-stat updates happen under ``torch.no_grad()`` on detached
        batch statistics; previously the in-place updates dragged the
        buffers into the autograd graph and raised on later ``backward()``
        calls.
      * ``momentum`` is now a constructor argument (default 0.01, the old
        hard-coded value), keeping the interface backward-compatible.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.01):
        super().__init__()
        self.eps = eps
        self.momentum = momentum  # EMA weight given to the newest batch
        self.gamma = nn.Parameter(torch.zeros(1, dim))  # log-scale: exp(0)==1
        self.beta = nn.Parameter(torch.zeros(1, dim))
        self.register_buffer("running_mean", torch.zeros(1, dim))
        self.register_buffer("running_var", torch.ones(1, dim))

    def _update_running_stats(self, m, v):
        # Detached, no-grad update: the running buffers must stay outside
        # the autograd graph so repeated training steps stay valid.
        with torch.no_grad():
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * m.detach())
            self.running_var.mul_(1 - self.momentum).add_(self.momentum * v.detach())

    def forward(self, x):
        """Map data -> normalized space; returns (x_hat, log_det)."""
        if self.training:
            m = x.mean(dim=0)
            v = x.var(dim=0) + self.eps
            self._update_running_stats(m, v)
        else:
            m = self.running_mean
            v = self.running_var
        x_hat = (x - m) / torch.sqrt(v)
        x_hat = x_hat * torch.exp(self.gamma) + self.beta
        log_det = torch.sum(self.gamma) - 0.5 * torch.sum(torch.log(v))
        return x_hat, log_det

    def backward(self, x):
        """Inverse map: undo the affine part, then de-normalize."""
        if self.training:
            m = x.mean(dim=0)
            v = x.var(dim=0) + self.eps
            self._update_running_stats(m, v)
        else:
            m = self.running_mean
            v = self.running_var
        x_hat = (x - self.beta) * torch.exp(-self.gamma) * torch.sqrt(v) + m
        log_det = torch.sum(-self.gamma + 0.5 * torch.log(v))
        return x_hat, log_det