Commit 02cd756 (1 parent: b4ef070)
Showing 3 changed files with 223 additions and 0 deletions.
File 1 of 3 — the Python requirements file (+2 lines):
@@ -5,3 +5,5 @@ gradio==4.8.0
 pyyaml
 pygame
 opencv-python
+info-nce-pytorch
+kornia
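The two added dependencies correspond directly to the new imports in the SCM module added below. A minimal sanity check after installing, assuming both packages come from PyPI:

# Verify the two new dependencies resolve to the modules SCM imports
from info_nce import InfoNCE     # provided by the info-nce-pytorch package
import kornia.augmentation as K  # provided by the kornia package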
File 2 of 3 — new file (+96 lines), defining the SCM module:
import torch
import torch.nn as nn
import src.modules.SCRModules as SCRModules

from info_nce import InfoNCE
import kornia.augmentation as K


class SCM(nn.Module):

    def __init__(self,
                 temperature,
                 mode='training',
                 image_size=96):
        super().__init__()
        style_vgg = SCRModules.vgg
        style_vgg = nn.Sequential(*list(style_vgg.children()))
        self.StyleFeatExtractor = SCRModules.StyleExtractor(
            encoder=style_vgg)
        self.StyleFeatProjector = SCRModules.Projector()

        if mode == 'training':
            self.StyleFeatExtractor.requires_grad_(True)
            self.StyleFeatProjector.requires_grad_(True)
        else:
            self.StyleFeatExtractor.requires_grad_(False)
            self.StyleFeatProjector.requires_grad_(False)

        # NCE loss with paired negatives: each query gets its own negative set
        self.nce_loss = InfoNCE(
            temperature=temperature,
            negative_mode='paired',
        )

        # Positive images get a random resize-and-crop augmentation
        self.patch_sampler = K.RandomResizedCrop(
            (image_size, image_size),
            scale=(0.8, 1.0),
            ratio=(0.75, 1.33))

    def forward(self, sample_imgs, pos_imgs, neg_imgs, nce_layers='0,1,2,3,4,5'):

        # Get generated-image style embeddings
        sample_style_embeddings = self.StyleFeatProjector(
            self.StyleFeatExtractor(
                sample_imgs,
                nce_layers),
            nce_layers)  # out: list of N x C(2048) tensors, one per layer

        # Random resize and crop for positive images
        pos_imgs = self.patch_sampler(pos_imgs)
        # Get positive-image style embeddings
        pos_style_embeddings = self.StyleFeatProjector(
            self.StyleFeatExtractor(
                pos_imgs,
                nce_layers),
            nce_layers)

        # Get negative-image style embeddings; neg_imgs: N x num_neg x C x H x W
        _, num_neg, _, _, _ = neg_imgs.shape
        for i in range(num_neg):
            neg_imgs_once = neg_imgs[:, i, :, :]
            neg_style_embeddings_once = self.StyleFeatProjector(
                self.StyleFeatExtractor(
                    neg_imgs_once,
                    nce_layers),
                nce_layers)
            # Stack the per-layer embeddings: num_layers x N x 2048
            for j, layer_out in enumerate(neg_style_embeddings_once):
                if j == 0:
                    neg_style_embeddings_mid = layer_out[None, :, :]
                else:
                    neg_style_embeddings_mid = torch.cat(
                        [neg_style_embeddings_mid, layer_out[None, :, :]],
                        dim=0)
            # Collect across negatives: num_layers x N x num_neg x 2048
            if i == 0:
                neg_style_embeddings = neg_style_embeddings_mid[:, :, None, :]
            else:
                neg_style_embeddings = torch.cat(
                    [neg_style_embeddings, neg_style_embeddings_mid[:, :, None, :]],
                    dim=2)

        return sample_style_embeddings, pos_style_embeddings, neg_style_embeddings

    def calculate_nce_loss(self, sample_s, pos_s, neg_s):

        num_layer = neg_s.shape[0]
        neg_s_list = []
        for i in range(num_layer):
            neg_s_list.append(neg_s[i])

        # Average the InfoNCE loss over the feature layers
        total_scm_loss = 0.
        for layer, (sample, pos, neg) in enumerate(zip(sample_s, pos_s, neg_s_list)):
            scm_loss = self.nce_loss(sample, pos, neg)
            total_scm_loss += scm_loss

        return total_scm_loss / num_layer
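For context, a minimal usage sketch of the module above. The commit view does not show the new file's path, so the import path here is hypothetical, and the temperature, batch size, and negative count are illustration values only:

# A usage sketch, not part of the commit
import torch
from src.modules.SCM import SCM     # hypothetical module path

scm = SCM(temperature=0.07)         # temperature value is an assumption
sample = torch.randn(4, 3, 96, 96)  # generated images
pos = torch.randn(4, 3, 96, 96)     # same-style positive images
neg = torch.randn(4, 8, 3, 96, 96)  # 8 negatives per sample

s_emb, p_emb, n_emb = scm(sample, pos, neg)
loss = scm.calculate_nce_loss(s_emb, p_emb, n_emb)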
File 3 of 3 — new file (+125 lines), defining the SCR building blocks (StyleExtractor, Projector, and the VGG backbone):
import torch
import torch.nn as nn


class StyleExtractor(nn.Module):

    def __init__(self, encoder):

        super(StyleExtractor, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:6])     # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[6:13])   # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[13:20])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[20:33])  # relu3_1 -> relu4_1
        self.enc_5 = nn.Sequential(*enc_layers[33:46])  # relu4_1 -> relu5_1
        self.enc_6 = nn.Sequential(*enc_layers[46:69])  # relu5_1 -> relu

        # 1x1 convs halve the channels of the concatenated avg+max pooled code
        self.conv1x1_0 = nn.Conv2d(128, 64, kernel_size=1, stride=1, bias=True)
        self.conv1x1_1 = nn.Conv2d(256, 128, kernel_size=1, stride=1, bias=True)
        self.conv1x1_2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, bias=True)
        self.conv1x1_3 = nn.Conv2d(1024, 512, kernel_size=1, stride=1, bias=True)
        self.conv1x1_4 = nn.Conv2d(1024, 512, kernel_size=1, stride=1, bias=True)
        self.conv1x1_5 = nn.Conv2d(1024, 512, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

    def encode_with_intermediate(self, input):
        results = [input]
        for i in range(6):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def forward(self, input, index):

        feats = self.encode_with_intermediate(input)
        codes = []
        for x in index.split(','):
            code = feats[int(x)].clone()
            gap = torch.nn.functional.adaptive_avg_pool2d(code, (1, 1))
            gmp = torch.nn.functional.adaptive_max_pool2d(code, (1, 1))
            conv1x1 = getattr(self, 'conv1x1_{:d}'.format(int(x)))
            code = torch.cat([gap, gmp], 1)
            code = self.relu(conv1x1(code))
            codes.append(code)
        return codes


class Projector(nn.Module):
    def __init__(self,):
        super(Projector, self).__init__()
        # One MLP per feature layer; input widths match the conv1x1 outputs above
        self.projector0 = nn.Sequential(
            nn.Linear(64, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 2048),
        )
        self.projector1 = nn.Sequential(
            nn.Linear(128, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 2048),
        )
        self.projector2 = nn.Sequential(
            nn.Linear(256, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 2048),
        )
        self.projector3 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 2048),
        )
        self.projector4 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 2048),
        )
        self.projector5 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 2048),
        )

    def forward(self, input, index):

        num = 0
        projections = []
        for x in index.split(','):
            projector = getattr(self, 'projector{:d}'.format(int(x)))
            code = input[num].view(input[num].size(0), -1)
            projection = projector(code).view(code.size(0), -1)
            projection = nn.functional.normalize(projection)
            projections.append(projection)
            num += 1
        return projections


def make_layers(cfg, batch_norm=True):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


vgg = make_layers([3, 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
                   512, 512, 512, 512, 'M', 512, 512, 'M', 512, 512, 'M'])
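For reference, a shape-check sketch of the two building blocks in isolation. The module path is taken from the import in the SCM file above; the batch size and image size are illustration values:

# A shape-check sketch, not part of the commit: each of the six layer codes
# projects to an L2-normalized 2048-dim embedding.
import torch
import torch.nn as nn
from src.modules.SCRModules import StyleExtractor, Projector, vgg

extractor = StyleExtractor(encoder=nn.Sequential(*list(vgg.children())))
projector = Projector()

imgs = torch.randn(2, 3, 96, 96)
codes = extractor(imgs, '0,1,2,3,4,5')        # six pooled feature codes
embeddings = projector(codes, '0,1,2,3,4,5')  # six (2, 2048) tensors
print([e.shape for e in embeddings])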