Skip to content

Commit

Permalink
[Update] Add SCR module
Browse files Browse the repository at this point in the history
  • Loading branch information
yeungchenwa committed Jan 27, 2024
1 parent b4ef070 commit 02cd756
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 0 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ gradio==4.8.0
pyyaml
pygame
opencv-python
info-nce-pytorch
kornia
96 changes: 96 additions & 0 deletions src/modules/SCR.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch

import torch.nn as nn
import src.modules.SCRModules as SCRModules

from info_nce import InfoNCE
import kornia.augmentation as K

class SCM(nn.Module):
    """Style contrastive module: style embeddings plus an InfoNCE loss.

    Images are passed through ``SCRModules.StyleExtractor`` and then
    ``SCRModules.Projector``; both return one entry per layer listed in
    ``nce_layers``.  ``calculate_nce_loss`` averages an InfoNCE loss over
    those layers.
    """

    def __init__(self,
                 temperature,
                 mode='training',
                 image_size=96):
        """Build the extractor, projector, loss and positive-crop sampler.

        Args:
            temperature: Temperature forwarded to ``InfoNCE``.
            mode: ``'training'`` leaves extractor/projector parameters
                trainable; any other value freezes them.
            image_size: Side length of the square crop produced by the
                positive-image sampler.
        """
        super().__init__()
        # NOTE(review): SCRModules.vgg is a module-level network, so the layer
        # objects (and their weights) are shared by every SCM instance created
        # in the same process -- confirm this is intended.
        style_vgg = nn.Sequential(*list(SCRModules.vgg.children()))
        self.StyleFeatExtractor = SCRModules.StyleExtractor(
            encoder=style_vgg)
        self.StyleFeatProjector = SCRModules.Projector()

        # Parameters receive gradients only in training mode.
        trainable = mode == 'training'
        self.StyleFeatExtractor.requires_grad_(trainable)
        self.StyleFeatProjector.requires_grad_(trainable)

        # NCE loss; 'paired' means every query comes with its own negatives.
        self.nce_loss = InfoNCE(
            temperature=temperature,
            negative_mode='paired',
        )

        # Random resize-and-crop applied to the positive images.
        self.patch_sampler = K.RandomResizedCrop(
            (image_size, image_size),
            scale=(0.8, 1.0),
            ratio=(0.75, 1.33))

    def _embed(self, imgs, nce_layers):
        """Return the per-layer projected style embeddings of ``imgs``."""
        feats = self.StyleFeatExtractor(imgs, nce_layers)
        return self.StyleFeatProjector(feats, nce_layers)

    def forward(self, sample_imgs, pos_imgs, neg_imgs, nce_layers='0,1,2,3,4,5'):
        """Embed generated, positive and negative images.

        Args:
            sample_imgs: Batch of generated images.
            pos_imgs: Batch of positive images; randomly resized and cropped
                before being embedded.
            neg_imgs: 5-D tensor whose dim 0 is the batch and dim 1 indexes
                the negatives of each sample.
            nce_layers: Comma-separated layer indices handed to the
                extractor and the projector.

        Returns:
            ``(sample, pos, neg)`` where ``sample`` and ``pos`` are the
            per-layer outputs of the projector and ``neg`` is a single tensor
            of shape ``(num_layers, batch, num_neg, dim)``.

        Raises:
            ValueError: If ``neg_imgs`` contains no negatives.
        """
        # Generated-image style embedding (each layer: N x C, C = 2048).
        sample_style_embeddings = self._embed(sample_imgs, nce_layers)

        # Positive images are augmented first, then embedded.
        pos_imgs = self.patch_sampler(pos_imgs)
        pos_style_embeddings = self._embed(pos_imgs, nce_layers)

        _, num_neg, _, _, _ = neg_imgs.shape
        if num_neg == 0:
            # Previously this surfaced as an UnboundLocalError at the return.
            raise ValueError(
                'neg_imgs must contain at least one negative per sample')

        # Embed one negative slot at a time.  Stacking once at the end gives
        # the same (layers, batch, num_neg, dim) tensor as growing it with
        # torch.cat on every iteration, without the repeated re-copying.
        per_negative = []
        for i in range(num_neg):
            layer_embeddings = self._embed(neg_imgs[:, i], nce_layers)
            per_negative.append(torch.stack(list(layer_embeddings), dim=0))
        neg_style_embeddings = torch.stack(per_negative, dim=2)

        return sample_style_embeddings, pos_style_embeddings, neg_style_embeddings

    def calculate_nce_loss(self, sample_s, pos_s, neg_s):
        """Average the InfoNCE loss over layers.

        Args:
            sample_s: Per-layer query embeddings.
            pos_s: Per-layer positive embeddings.
            neg_s: Tensor whose first dimension indexes layers; ``neg_s[i]``
                is passed to the loss as the negatives of layer ``i``.

        Returns:
            Sum of the per-layer losses divided by ``neg_s.shape[0]``.
        """
        num_layer = neg_s.shape[0]

        total_scm_loss = 0.
        # Iterating a tensor walks its first dimension, i.e. one layer a time.
        for sample, pos, neg in zip(sample_s, pos_s, neg_s):
            total_scm_loss = total_scm_loss + self.nce_loss(sample, pos, neg)

        return total_scm_loss / num_layer
125 changes: 125 additions & 0 deletions src/modules/SCRModules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import torch
import torch.nn as nn


class StyleExtractor(nn.Module):
    """Split an encoder into six stages and pool a style code per stage.

    For every requested stage the feature map is reduced by global average
    pooling and global max pooling, the two results are concatenated along
    the channel axis, and a stage-specific 1x1 convolution followed by ReLU
    halves the channel count again.
    """

    # Child-index boundaries of the six encoder stages (enc_1 .. enc_6).
    # With the batch-norm ``vgg`` defined in this file the stage outputs have
    # 64, 128, 256, 512, 512 and 512 channels respectively.
    _STAGE_BOUNDS = (0, 6, 13, 20, 33, 46, 69)

    # (in_channels, out_channels) of conv1x1_0 .. conv1x1_5.  The input is
    # twice the stage width because avg- and max-pooled maps are concatenated.
    _CONV1X1_CHANNELS = (
        (128, 64),
        (256, 128),
        (512, 256),
        (1024, 512),
        (1024, 512),
        (1024, 512),
    )

    def __init__(self, encoder):
        """Slice ``encoder``'s children into stages and build the 1x1 convs.

        Args:
            encoder: Module whose ``children()`` are sliced by position into
                the six stages.
        """
        super().__init__()
        enc_layers = list(encoder.children())
        stage_spans = zip(self._STAGE_BOUNDS[:-1], self._STAGE_BOUNDS[1:])
        for i, (start, stop) in enumerate(stage_spans):
            setattr(self, 'enc_{:d}'.format(i + 1),
                    nn.Sequential(*enc_layers[start:stop]))

        for i, (in_ch, out_ch) in enumerate(self._CONV1X1_CHANNELS):
            setattr(self, 'conv1x1_{:d}'.format(i),
                    nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=True))
        self.relu = nn.ReLU(True)

    def encode_with_intermediate(self, input):
        """Run the six stages in order and return each stage's output."""
        results = [input]
        for i in range(6):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def forward(self, input, index):
        """Return one pooled style code per stage listed in ``index``.

        Args:
            input: Image batch fed to the first encoder stage.
            index: Comma-separated stage indices (``0`` .. ``5``), e.g.
                ``'0,2,5'``.

        Returns:
            List of tensors, one per entry of ``index`` and in the same
            order; each has spatial size 1x1.
        """
        feats = self.encode_with_intermediate(input)
        codes = []
        for x in index.split(','):
            layer = int(x)
            # The pooling ops do not modify their input, so the feature map
            # is used directly rather than cloned first.
            feat = feats[layer]
            gap = nn.functional.adaptive_avg_pool2d(feat, (1, 1))
            gmp = nn.functional.adaptive_max_pool2d(feat, (1, 1))
            conv1x1 = getattr(self, 'conv1x1_{:d}'.format(layer))
            code = self.relu(conv1x1(torch.cat([gap, gmp], 1)))
            codes.append(code)
        return codes


class Projector(nn.Module):
    """Per-layer MLP heads that map style codes to unit-norm embeddings.

    Six heads, ``projector0`` .. ``projector5``, share the same shape
    (input -> 1024 -> 2048 -> 2048 with ReLU between the linear layers) and
    differ only in their input width.
    """

    # Input width of projector0 .. projector5.
    _IN_DIMS = (64, 128, 256, 512, 512, 512)
    _HIDDEN_DIM = 1024
    _OUT_DIM = 2048

    def __init__(self,):
        super().__init__()
        # Heads are created in index order so attribute names, state_dict
        # keys and parameter-initialisation order match the unrolled version.
        for layer, in_dim in enumerate(self._IN_DIMS):
            setattr(self, 'projector{:d}'.format(layer), self._make_mlp(in_dim))

    @classmethod
    def _make_mlp(cls, in_dim):
        """Build one projection head for codes of width ``in_dim``."""
        return nn.Sequential(
            nn.Linear(in_dim, cls._HIDDEN_DIM),
            nn.ReLU(True),
            nn.Linear(cls._HIDDEN_DIM, cls._OUT_DIM),
            nn.ReLU(True),
            nn.Linear(cls._OUT_DIM, cls._OUT_DIM),
        )

    def forward(self, input, index):
        """Project each style code with the head of its layer.

        Args:
            input: Sequence of style codes; ``input[k]`` belongs to the
                ``k``-th entry of ``index``.
            index: Comma-separated layer indices (``0`` .. ``5``) selecting
                which head handles each code.

        Returns:
            List of 2-D tensors ``(batch, 2048)``, L2-normalised along
            dim 1, in the order given by ``index``.
        """
        projections = []
        for position, x in enumerate(index.split(',')):
            projector = getattr(self, 'projector{:d}'.format(int(x)))
            # Flatten everything after the batch dimension.
            code = input[position].view(input[position].size(0), -1)
            # The head already returns (batch, 2048); no reshape is needed.
            projections.append(nn.functional.normalize(projector(code)))
        return projections


def make_layers(cfg, batch_norm=True, in_channels=3):
    """Build a VGG-style stack of layers from a compact configuration.

    Args:
        cfg: Iterable whose entries are either the string ``'M'`` (append a
            2x2, stride-2 max pool) or an integer (append a 3x3, padding-1
            convolution with that many output channels, followed by ReLU).
        batch_norm: When true, insert ``BatchNorm2d`` between each
            convolution and its ReLU.
        in_channels: Channel count of the network input.  Defaults to 3,
            the value that used to be hard-coded.

    Returns:
        ``nn.Sequential`` holding the layers in configuration order.
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        layers.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        in_channels = v
    return nn.Sequential(*layers)


# Module-level backbone, built once at import time with make_layers' default
# batch_norm=True.  Each integer therefore contributes three layers
# (conv, batch norm, ReLU) and each 'M' one max pool; StyleExtractor slices
# the resulting children by position, so the order below matters.
vgg = make_layers([
    3, 64, 64, 'M',
    128, 128, 'M',
    256, 256, 256, 256, 'M',
    512, 512, 512, 512, 'M',
    512, 512, 512, 512, 'M',
    512, 512, 'M',
    512, 512, 'M',
])

0 comments on commit 02cd756

Please sign in to comment.