From f584c36fd6a324fca13ee7d18a3d3bfa52c284bb Mon Sep 17 00:00:00 2001
From: Shawn Presser
Date: Fri, 2 Oct 2020 11:36:57 -0700
Subject: [PATCH] Fix BigGAN-Deep generator; Support 512x512

---
 BigGANdeep.py | 40 +++++++++++++++++++++++++---------------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/BigGANdeep.py b/BigGANdeep.py
index 95763c34..ffdfc4d8 100644
--- a/BigGANdeep.py
+++ b/BigGANdeep.py
@@ -22,7 +22,7 @@
 # Channel ratio is the ratio of
 class GBlock(nn.Module):
   def __init__(self, in_channels, out_channels,
-               which_conv=nn.Conv2d, which_bn=layers.bn, activation=None,
+               which_conv=layers.SNConv2d, which_bn=layers.bn, activation=None,
                upsample=None, channel_ratio=4):
     super(GBlock, self).__init__()

@@ -50,22 +50,29 @@ def forward(self, x, y):
     h = self.conv1(self.activation(self.bn1(x, y)))
     # Apply next BN-ReLU
     h = self.activation(self.bn2(h, y))
-    # Drop channels in x if necessary
-    if self.in_channels != self.out_channels:
-      x = x[:, :self.out_channels]
-    # Upsample both h and x at this point
+    # Upsample h
     if self.upsample:
       h = self.upsample(h)
-      x = self.upsample(x)
     # 3x3 convs
     h = self.conv2(h)
     h = self.conv3(self.activation(self.bn3(h, y)))
     # Final 1x1 conv
     h = self.conv4(self.activation(self.bn4(h, y)))
+    # Drop channels in x if necessary
+    if self.in_channels != self.out_channels:
+      x = x[:, :self.out_channels]
+    # Upsample x
+    if self.upsample:
+      x = self.upsample(x)
     return h + x

 def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
   arch = {}
+  arch[512] = {'in_channels' :  [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
+               'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
+               'upsample' : [True] * 7,
+               'resolution' : [8, 16, 32, 64, 128, 256, 512],
+               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) for i in range(3,10)}}
   arch[256] = {'in_channels' :  [ch * item for item in [16, 16, 8, 8, 4, 2]],
                'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1]],
                'upsample' : [True] * 6,
@@ -73,13 +80,13 @@ def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
                'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                               for i in range(3,9)}}
   arch[128] = {'in_channels' :  [ch * item for item in [16, 16, 8, 4, 2]],
-               'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
+               'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
                'upsample' : [True] * 5,
                'resolution' : [8, 16, 32, 64, 128],
                'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                               for i in range(3,8)}}
   arch[64]  = {'in_channels' :  [ch * item for item in [16, 16, 8, 4]],
-               'out_channels' : [ch * item for item in [16, 8, 4, 2]],
+               'out_channels' : [ch * item for item in [16, 8, 4, 2]],
                'upsample' : [True] * 4,
                'resolution' : [8, 16, 32, 64],
                'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
@@ -94,10 +101,10 @@ def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
   return arch

 class Generator(nn.Module):
-  def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
+  def __init__(self, G_ch=128, G_depth=2, dim_z=128, bottom_width=4, resolution=512,
                G_kernel_size=3, G_attn='64', n_classes=1000,
                num_G_SVs=1, num_G_SV_itrs=1,
-               G_shared=True, shared_dim=0, hier=False,
+               G_shared=True, shared_dim=0, hier=True,
                cross_replica=False, mybn=False,
                G_activation=nn.ReLU(inplace=False),
                G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
@@ -213,7 +220,7 @@ def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128
                                                 cross_replica=self.cross_replica,
                                                 mybn=self.mybn),
                                     self.activation,
-                                    self.which_conv(self.arch['out_channels'][-1], 3))
+                                    self.which_conv(self.arch['out_channels'][-1], 128))

     # Initialize weights. Optionally skip init for testing.
     if not skip_init:
@@ -265,7 +272,7 @@ def init_weights(self):
   def forward(self, z, y):
     # If hierarchical, concatenate zs and ys
     if self.hier:
-      z = torch.cat([y, z], 1)
+      z = torch.cat([z, y], 1)
       y = z
     # First linear layer
     h = self.linear(z)
@@ -276,9 +283,12 @@ def forward(self, z, y):
       # Second inner loop in case block has multiple layers
       for block in blocklist:
         h = block(h, y)
-
-    # Apply batchnorm-relu-conv-tanh at output
-    return torch.tanh(self.output_layer(h))
+    # Apply batchnorm-relu-conv
+    h = self.output_layer(h)
+    # Take the rgb channels
+    h = h[:, :3, :, :]
+    # Apply final tanh at output
+    return torch.tanh(h)

 class DBlock(nn.Module):
   def __init__(self, in_channels, out_channels, which_conv=layers.SNConv2d, wide=True,
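Note (outside the patch proper): a minimal standalone sketch of the channel/resolution schedule that the new arch[512] entry in G_arch() defines, assuming the patch's new default G_ch=128. It needs no imports from the repo and only mirrors the values visible in the diff above; the printed progression is illustrative, not part of BigGANdeep.py.

# Sketch: per-block channel plan for the 512x512 generator architecture
# added in this patch, evaluated at the new default G_ch=128.
ch = 128  # new default G_ch introduced by this patch
arch_512 = {'in_channels':  [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
            'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
            'upsample':     [True] * 7,
            'resolution':   [8, 16, 32, 64, 128, 256, 512]}

for res, cin, cout in zip(arch_512['resolution'],
                          arch_512['in_channels'],
                          arch_512['out_channels']):
    # Each stage upsamples to `res` x `res` and maps cin -> cout channels.
    print(f"{res:>3}x{res:<3}  in={cin:>5}  out={cout:>5}")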