-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathgenerator.py
96 lines (69 loc) · 2.36 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
SEED = 3
# Pin every RNG source touched by this module so runs are reproducible:
# NumPy, CPU-side torch, and the current CUDA device.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # silently a no-op when CUDA is unavailable
class GeneratorNet(nn.Module):
    """Feature-hallucination head.

    Rebuilds a 1024-d feature vector from the triple (A, B1, B2) via
    AddInfo + Generator, L2-normalises it, and scores it with a bias-free
    linear classifier scaled by a learnable factor ``s``.
    """

    def __init__(self, num_classes=80, norm=True, scale=True):
        # NOTE(review): `norm` and `scale` are accepted but never read —
        # presumably kept for config compatibility; confirm with callers.
        super(GeneratorNet, self).__init__()
        self.add_info = AddInfo()
        self.generator = Generator()
        # Bias-free so each logit is a (scaled) dot product with a class weight.
        self.fc = nn.Linear(1024, num_classes, bias=False)
        # Learnable logit scale, initialised to 10.
        self.s = nn.Parameter(torch.FloatTensor([10]))

    def forward(self, B1=None, B2=None, A=None, classifier=False):
        # NOTE(review): `classifier` is unused; the (B1, B2, A) parameter
        # order must be preserved because callers pass positionally.
        fused = self.add_info(A, B1, B2)
        rebuilt = self.l2_norm(self.generator(fused))
        logits = self.fc(rebuilt * self.s)
        return rebuilt, logits

    def weight_norm(self):
        # Project each class weight row onto the unit sphere (in place).
        w = self.fc.weight.data
        row_norms = w.norm(p=2, dim=1, keepdim=True)
        self.fc.weight.data = w.div(row_norms.expand_as(w))

    def l2_norm(self, input):
        # Row-wise L2 normalisation; the epsilon keeps the division finite
        # for all-zero rows.
        original_shape = input.size()
        sq_sum = torch.sum(torch.pow(input, 2), 1).add_(1e-10)
        denom = torch.sqrt(sq_sum).view(-1, 1).expand_as(input)
        return torch.div(input, denom).view(original_shape)
class AddInfo(nn.Module):
    """Fuses feature A with the difference (B1 - B2).

    All three inputs pass through one shared 1024 -> 2048 projection
    followed by LeakyReLU; the combination A + (B1 - B2) is then dropped
    out and returned.
    """

    def __init__(self):
        super(AddInfo, self).__init__()
        # A single Linear shared by A, B1 and B2.
        self.fc = nn.Linear(1024, 2048)
        self.leakyRelu1 = nn.LeakyReLU(0.2, inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, A, B1, B2):
        # Project each input through the shared fc + activation.
        proj_a = self.leakyRelu1(self.fc(A))
        proj_b1 = self.leakyRelu1(self.fc(B1))
        proj_b2 = self.leakyRelu1(self.fc(B2))
        # Analogy-style combination: A shifted by the (B1 - B2) direction.
        combined = proj_a + (proj_b1 - proj_b2)
        return self.dropout(combined)
class Generator(nn.Module):
    """Maps the 2048-d fused code back down to a 1024-d feature vector."""

    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Linear(2048, 1024)
        self.leakyRelu = nn.LeakyReLU(0.2, inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Linear projection, activation, then dropout for regularisation.
        projected = self.fc(x)
        activated = self.leakyRelu(projected)
        return self.dropout(activated)
if __name__ == '__main__':
    # Smoke test: run one batch of 64 dummy 1024-d feature vectors through
    # the network and show the output shapes.
    batch_size = 64
    A = torch.ones((batch_size, 1024))
    B1 = torch.ones((batch_size, 1024))
    B2 = torch.ones((batch_size, 1024))
    net = GeneratorNet()
    rebuilt, logits = net(A, B1, B2)
    print(rebuilt.shape)
    print(logits.shape)