diff --git a/requirements.txt b/requirements.txt index d45b877..080908b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ gradio==4.8.0 pyyaml pygame opencv-python +info-nce-pytorch +kornia diff --git a/src/modules/SCR.py b/src/modules/SCR.py new file mode 100644 index 0000000..9ee6881 --- /dev/null +++ b/src/modules/SCR.py @@ -0,0 +1,96 @@ +import torch + +import torch.nn as nn +import src.modules.SCRModules as SCRModules + +from info_nce import InfoNCE +import kornia.augmentation as K + +class SCM(nn.Module): + + def __init__(self, + temperature, + mode='training', + image_size=96): + super().__init__() + style_vgg = SCRModules.vgg + style_vgg = nn.Sequential(*list(style_vgg.children())) + self.StyleFeatExtractor = SCRModules.StyleExtractor( + encoder=style_vgg) + self.StyleFeatProjector = SCRModules.Projector() + + if mode == 'training': + self.StyleFeatExtractor.requires_grad_(True) + self.StyleFeatProjector.requires_grad_(True) + else: + self.StyleFeatExtractor.requires_grad_(False) + self.StyleFeatProjector.requires_grad_(False) + + # NCE Loss + self.nce_loss = InfoNCE( + temperature=temperature, + negative_mode='paired', + ) + + # Pos Image random resize and crop + self.patch_sampler = K.RandomResizedCrop( + (image_size, image_size), + scale=(0.8,1.0), + ratio=(0.75,1.33)) + + def forward(self, sample_imgs, pos_imgs, neg_imgs, nce_layers='0,1,2,3,4,5'): + + # Get generated image style embedding + sample_style_embeddings = self.StyleFeatProjector( + self.StyleFeatExtractor( + sample_imgs, + nce_layers), + nce_layers) # out: N * C(2048) + + # Random resize and crop for positive images + pos_imgs = self.patch_sampler(pos_imgs) + # Get positive image style embedding + pos_style_embeddings = self.StyleFeatProjector( + self.StyleFeatExtractor( + pos_imgs, + nce_layers), + nce_layers) + + # Get negative image style embedding + _, num_neg, _, _, _ = neg_imgs.shape + for i in range(num_neg): + neg_imgs_once = neg_imgs[:, i, :, :] + neg_style_embeddings_once = self.StyleFeatProjector( + self.StyleFeatExtractor( + neg_imgs_once, + nce_layers), + nce_layers) + for j, layer_out in enumerate(neg_style_embeddings_once): + if j == 0: + neg_style_embeddings_mid = layer_out[None, :, :] + else: + neg_style_embeddings_mid = torch.cat( + [neg_style_embeddings_mid, layer_out[None, :, :]], + dim=0) + if i == 0: + neg_style_embeddings = neg_style_embeddings_mid[:, :, None, :] + else: + neg_style_embeddings = torch.cat( + [neg_style_embeddings, neg_style_embeddings_mid[:, :, None, :]], + dim=2) + + return sample_style_embeddings, pos_style_embeddings, neg_style_embeddings + + def calculate_nce_loss(self, sample_s, pos_s, neg_s): + + num_layer = neg_s.shape[0] + neg_s_list = [] + for i in range(num_layer): + neg_s_list.append(neg_s[i]) + + total_scm_loss = 0. + for layer, (sample, pos, neg) in enumerate(zip(sample_s, pos_s, neg_s_list)): + scm_loss = self.nce_loss(sample, pos, neg) + total_scm_loss += scm_loss + + return total_scm_loss / num_layer diff --git a/src/modules/SCRModules.py b/src/modules/SCRModules.py new file mode 100644 index 0000000..6157b11 --- /dev/null +++ b/src/modules/SCRModules.py @@ -0,0 +1,125 @@ +import torch +import torch.nn as nn + + +class StyleExtractor(nn.Module): + + def __init__(self, encoder): + + super(StyleExtractor, self).__init__() + enc_layers = list(encoder.children()) + self.enc_1 = nn.Sequential(*enc_layers[:6]) # input -> relu1_1 + self.enc_2 = nn.Sequential(*enc_layers[6:13]) # relu1_1 -> relu2_1 + self.enc_3 = nn.Sequential(*enc_layers[13:20]) # relu2_1 -> relu3_1 + self.enc_4 = nn.Sequential(*enc_layers[20:33]) # relu3_1 -> relu4_1 + self.enc_5 = nn.Sequential(*enc_layers[33:46]) # relu4_1 -> relu5_1 + self.enc_6 = nn.Sequential(*enc_layers[46:69]) # relu5_1 -> relu + + self.conv1x1_0 = nn.Conv2d(128, 64, kernel_size=1, stride=1, bias=True) + self.conv1x1_1 = nn.Conv2d(256, 128, kernel_size=1, stride=1, bias=True) + self.conv1x1_2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, bias=True) + self.conv1x1_3 = nn.Conv2d(1024, 512, kernel_size=1, stride=1, bias=True) + self.conv1x1_4 = nn.Conv2d(1024, 512, kernel_size=1, stride=1, bias=True) + self.conv1x1_5 = nn.Conv2d(1024, 512, kernel_size=1, stride=1, bias=True) + self.relu = nn.ReLU(True) + + def encode_with_intermediate(self, input): + results = [input] + for i in range(6): + func = getattr(self, 'enc_{:d}'.format(i + 1)) + results.append(func(results[-1])) + return results[1:] + + def forward(self, input, index): + + feats = self.encode_with_intermediate(input) + codes = [] + for x in index.split(','): + code = feats[int(x)].clone() + gap = torch.nn.functional.adaptive_avg_pool2d(code, (1,1)) + gmp = torch.nn.functional.adaptive_max_pool2d(code, (1,1)) + conv1x1 = getattr(self, 'conv1x1_{:d}'.format(int(x))) + code = torch.cat([gap, gmp], 1) + code = self.relu(conv1x1(code)) + codes.append(code) + return codes + + +class Projector(nn.Module): + def __init__(self,): + super(Projector, self).__init__() + self.projector0 = nn.Sequential( + nn.Linear(64, 1024), + nn.ReLU(True), + nn.Linear(1024, 2048), + nn.ReLU(True), + nn.Linear(2048, 2048), + ) + self.projector1 = nn.Sequential( + nn.Linear(128, 1024), + nn.ReLU(True), + nn.Linear(1024, 2048), + nn.ReLU(True), + nn.Linear(2048, 2048), + ) + self.projector2 = nn.Sequential( + nn.Linear(256,1024), + nn.ReLU(True), + nn.Linear(1024, 2048), + nn.ReLU(True), + nn.Linear(2048, 2048), + ) + self.projector3 = nn.Sequential( + nn.Linear(512, 1024), + nn.ReLU(True), + nn.Linear(1024, 2048), + nn.ReLU(True), + nn.Linear(2048, 2048), + ) + self.projector4 = nn.Sequential( + nn.Linear(512, 1024), + nn.ReLU(True), + nn.Linear(1024, 2048), + nn.ReLU(True), + nn.Linear(2048, 2048), + ) + self.projector5 = nn.Sequential( + nn.Linear(512, 1024), + nn.ReLU(True), + nn.Linear(1024, 2048), + nn.ReLU(True), + nn.Linear(2048, 2048), + ) + + def forward(self, input, index): + + num = 0 + projections = [] + for x in index.split(','): + projector = getattr(self, 'projector{:d}'.format(int(x))) + code = input[num].view(input[num].size(0), -1) + projection = projector(code).view(code.size(0), -1) + projection = nn.functional.normalize(projection) + projections.append(projection) + num += 1 + return projections + + +def make_layers(cfg, batch_norm=True): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +vgg = make_layers([3, 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', + 512, 512, 512, 512, 'M', 512, 512, 'M', 512, 512, 'M'])