-
Notifications
You must be signed in to change notification settings - Fork 0
/
discriminator.py
99 lines (91 loc) · 2.34 KB
/
discriminator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch.nn as nn
from cam import CAMAttention
class ConvBlock(nn.Sequential):
    """ReflectionPad2d -> spectrally-normalized Conv2d -> activation.

    Spectral normalization constrains the conv weight's spectral norm,
    a standard stabilizer for GAN discriminators.

    Args:
        in_channels: input feature channels.
        out_channels: output feature channels.
        reflection_padding: amount of reflection padding applied before the conv.
        kernel_size: conv kernel size.
        stride: conv stride.
        padding: zero padding inside the conv itself.
        bias: whether the conv has a bias term.
        act: activation module appended last; defaults to LeakyReLU(0.2, inplace).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 reflection_padding: int,
                 kernel_size: int,
                 stride: int,
                 padding: int,
                 bias: bool,
                 act: nn.Module = None
                 ):
        # Fix mutable default argument: the original default was a single
        # nn.LeakyReLU instance created at def-time and shared as a child
        # module by every ConvBlock built with the default. Build a fresh
        # activation per instance instead; default behavior is unchanged.
        if act is None:
            act = nn.LeakyReLU(0.2, True)
        super().__init__(
            nn.ReflectionPad2d(reflection_padding),
            nn.utils.spectral_norm(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias
                )
            ),
            act
        )
class Discriminator(nn.Module):
    """PatchGAN-style discriminator with a CAM attention stage.

    A stack of stride-2 ConvBlocks downsamples the input, a final stride-1
    block widens the features, CAM attention reweights them, and a 1-channel
    conv head produces per-patch logits.

    Args:
        in_channels: channels of the input image.
        ndf: base channel width of the first encoder block.
        num_layers: total depth; controls how many downsampling blocks are built.

    Forward returns:
        (patch_logits, cam_logit, heatmap) — the head output plus the CAM
        module's classification logit and attention heatmap.
    """

    def __init__(self, in_channels, ndf=64, num_layers=5):
        super().__init__()
        # Encoder: stride-2 stem, then (num_layers - 3) stride-2 blocks that
        # double the width each time, then one stride-1 widening block.
        blocks = [
            ConvBlock(
                in_channels=in_channels,
                out_channels=ndf,
                reflection_padding=1,
                kernel_size=4,
                stride=2,
                padding=0,
                bias=True
            )
        ]
        for depth in range(1, num_layers - 2):
            width = ndf * 2 ** (depth - 1)
            blocks.append(
                ConvBlock(
                    in_channels=width,
                    out_channels=width * 2,
                    reflection_padding=1,
                    kernel_size=4,
                    stride=2,
                    padding=0,
                    bias=True
                )
            )
        width = ndf * 2 ** (num_layers - 3)
        blocks.append(
            ConvBlock(
                in_channels=width,
                out_channels=width * 2,
                reflection_padding=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=True
            )
        )
        self.enc = nn.Sequential(*blocks)

        # Class-activation-map attention over the encoder features.
        feat_channels = ndf * 2 ** (num_layers - 2)
        self.cam = CAMAttention(
            channels=feat_channels,
            act=nn.LeakyReLU(0.2, True),
            spectral_norm=True
        )

        # Head: single-channel patch logits, no activation.
        self.out = ConvBlock(
            in_channels=feat_channels,
            reflection_padding=1,
            out_channels=1,
            kernel_size=4,
            stride=1,
            padding=0,
            bias=False,
            act=nn.Identity()
        )

    def forward(self, x):
        features = self.enc(x)
        features, cam_logit, heatmap = self.cam(features)
        return self.out(features), cam_logit, heatmap