-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathstyle_transfer.py
177 lines (132 loc) · 6.37 KB
/
style_transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from __future__ import print_function
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import os
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import torchvision.models as models
import copy
import argparse
from loss import ContentLoss, GramMatrix, StyleLoss
from utils import image_loader, image_loader_gray, save_image
parser = argparse.ArgumentParser()
parser.add_argument('--content', '-c', type=str, required=True, help='The path to the Content image')
parser.add_argument('--style', '-s', type=str, required=True, help='The path to the style image')
parser.add_argument('--epoch', '-e', type=int, default=300, help='The number of epoch')
parser.add_argument('--content_weight', '-c_w', type=int, default=1, help='The weight of content loss')
parser.add_argument('--style_weight', '-s_w', type=int, default=500, help='The weight of style loss')
parser.add_argument('--initialize_noise', '-i_n', action='store_true', help='Initialize with white noise? elif initialize with content image')
parser.add_argument('--cuda', action='store_true', help='use cuda?')
args = parser.parse_args()
use_cuda = torch.cuda.is_available() and args.cuda
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
# desired size of the output image
imsize = 512 if use_cuda else 128 # use small size if no gpu
style_img = image_loader(args.style, imsize).type(dtype)
content_img = image_loader_gray(args.content, imsize).type(dtype)
if args.initialize_noise:
input_img = Variable(torch.randn(content_img.data.size())).type(dtype)
else:
input_img = image_loader_gray(args.content, imsize).type(dtype)
input_size = Image.open(args.content).size
assert style_img.size() == content_img.size(), \
"we need to import style and content images of the same size"
cnn = models.vgg19(pretrained=True).features
# move it to the GPU if possible:
if use_cuda:
cnn = cnn.cuda()
# desired depth layers to compute style/content losses :
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
def get_style_model_and_losses(cnn, style_img, content_img, style_weight=1000, content_weight=1, content_layers=content_layers_default, style_layers=style_layers_default):
cnn = copy.deepcopy(cnn)
# just in order to have an iterable access to or list of content/syle
# losses
content_losses = []
style_losses = []
model = nn.Sequential() # the new Sequential module network
gram = GramMatrix() # we need a gram module in order to compute style targets
# move these modules to the GPU if possible:
if use_cuda:
model = model.cuda()
gram = gram.cuda()
i = 1
for layer in list(cnn):
if isinstance(layer, nn.Conv2d):
name = "conv_" + str(i)
model.add_module(name, layer)
if name in content_layers:
# add content loss:
target = model(content_img).clone()
content_loss = ContentLoss(target, content_weight)
model.add_module("content_loss_" + str(i), content_loss)
content_losses.append(content_loss)
if name in style_layers:
# add style loss:
target_feature = model(style_img).clone()
target_feature_gram = gram(target_feature)
style_loss = StyleLoss(target_feature_gram, style_weight)
model.add_module("style_loss_" + str(i), style_loss)
style_losses.append(style_loss)
if isinstance(layer, nn.ReLU):
name = "relu_" + str(i)
model.add_module(name, layer)
if name in content_layers:
# add content loss:
target = model(content_img).clone()
content_loss = ContentLoss(target, content_weight)
model.add_module("content_loss_" + str(i), content_loss)
content_losses.append(content_loss)
if name in style_layers:
# add style loss:
target_feature = model(style_img).clone()
target_feature_gram = gram(target_feature)
style_loss = StyleLoss(target_feature_gram, style_weight)
model.add_module("style_loss_" + str(i), style_loss)
style_losses.append(style_loss)
i += 1
if isinstance(layer, nn.MaxPool2d):
name = "pool_" + str(i)
model.add_module(name, layer) # ***
return model, style_losses, content_losses
def get_input_param_optimizer(input_img):
# this line to show that input is a parameter that requires a gradient
input_param = nn.Parameter(input_img.data)
optimizer = optim.LBFGS([input_param])
return input_param, optimizer
def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=300, style_weight=1000, content_weight=1):
"""Run the style transfer."""
print('Building the style transfer model..')
model, style_losses, content_losses = get_style_model_and_losses(cnn, style_img, content_img, style_weight, content_weight)
input_param, optimizer = get_input_param_optimizer(input_img)
print('Optimizing..')
run = [0]
while run[0] <= num_steps:
def closure():
# correct the values of updated input image
input_param.data.clamp_(0, 1)
optimizer.zero_grad()
model(input_param)
style_score = 0
content_score = 0
for sl in style_losses:
style_score += sl.backward()
for cl in content_losses:
content_score += cl.backward()
run[0] += 1
if run[0] % 50 == 0:
print("run {}:".format(run))
print('Style Loss : {:4f} Content Loss: {:4f}'.format(style_score.data[0], content_score.data[0]))
return style_score + content_score
optimizer.step(closure)
# a last correction...
input_param.data.clamp_(0, 1)
return input_param.data
output = run_style_transfer(cnn, content_img, style_img, input_img, args.epoch, args.style_weight, args.content_weight)
name_content, ext = os.path.splitext(os.path.basename(args.content))
name_style, _ = os.path.splitext(os.path.basename(args.style))
fname = name_content+'-'+name_style+ext
save_image(output, size=input_img.data.size()[1:], input_size=input_size, fname=fname)