-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflasksample.py
144 lines (110 loc) · 4.31 KB
/
flasksample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import functools
import hashlib
from pathlib import Path

import numpy as np
import torch as th
import yaml
from flask import Flask, render_template, request

import data_processing.DataLoader as dl
# The Flask application object; the route handlers below register on it.
app = Flask(import_name=__name__)
# Module-level caption placeholder.  NOTE(review): never read anywhere in this
# file — homepage_result binds its own local `caption` — kept for compatibility.
caption = ""
@app.route('/')
def homepage():
    """Render the landing page for GET / with no generated result."""
    page = render_template("main.html")
    return page
# Locations of the trained checkpoints and their configuration.  Built with
# pathlib so they resolve on any OS (backslash-joined strings only work on
# Windows).
_CONFIG_FILE = Path("configs") / "11.conf"
_MODEL_DIR = Path("training_runs") / "11" / "saved_models"

# Device holding the models and every tensor fed to them.
_DEVICE = th.device("cuda" if th.cuda.is_available() else "cpu")

# Generation settings applied to every request.
_CURRENT_DEPTH = 4               # generator stage the images are sampled at
# Fade-in alpha passed to the generator (~1/136).
# NOTE(review): magic value carried over unchanged — confirm it is intended.
_ALPHA = 0.007352941176470588
_NUM_SAMPLES = 16                # images generated per caption (saved as one grid)
_CAPTION_LENGTH = 100            # token ids are zero-padded up to this length


def _get_config(conf_file):
    """
    Parse and load the provided YAML configuration.

    :param conf_file: path to the configuration file
    :return: the parsed configuration as an EasyDict (attribute access)
    """
    from easydict import EasyDict as edict

    with open(conf_file, "r", encoding="utf-8") as file_descriptor:
        # safe_load: yaml.load without an explicit Loader is unsafe and is a
        # TypeError on PyYAML >= 6.
        data = yaml.safe_load(file_descriptor)
    return edict(data)


@functools.lru_cache(maxsize=1)
def _load_models():
    """
    Build the networks and load their checkpoints once; later calls reuse them.

    Previously every POST rebuilt the dataset and re-read all three
    checkpoints from disk.

    :return: tuple (dataset, text_encoder, condition_augmenter, c_pro_gan)
    """
    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.C_PRO_GAN import ProGAN

    config = _get_config(_CONFIG_FILE)

    # Generator.
    c_pro_gan = ProGAN(
        embedding_size=config.hidden_size,
        depth=config.depth,
        latent_size=config.latent_size,
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        eps=config.eps,
        drift=config.drift,
        n_critic=config.n_critic,
        device=_DEVICE,
    )
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    c_pro_gan.gen.load_state_dict(
        th.load(_MODEL_DIR / "GAN_GEN_3_20.pth", map_location=_DEVICE))

    # The dataset is used here only for its vocabulary (vocab_size, rev_vocab).
    dataset = dl.Face2TextDataset(
        pro_pick_file=config.processed_text_file,
        img_dir=config.images_dir,
        img_transform=dl.get_transform(config.img_dims),
        captions_len=config.captions_length,
    )

    # Text embedding network.
    text_encoder = Encoder(
        embedding_size=config.embedding_size,
        vocab_size=dataset.vocab_size,
        hidden_size=config.hidden_size,
        num_layers=config.num_layers,
        device=_DEVICE,
    )
    text_encoder.load_state_dict(
        th.load(_MODEL_DIR / "Encoder_3_20.pth", map_location=_DEVICE))

    # Conditioning augmentation network.
    condition_augmenter = ConditionAugmentor(
        input_size=config.hidden_size,
        latent_size=config.ca_out_size,
        device=_DEVICE,
    )
    condition_augmenter.load_state_dict(
        th.load(_MODEL_DIR / "Condition_Augmentor_3_20.pth",
                map_location=_DEVICE))

    # NOTE(review): the networks are never switched to .eval(); confirm none of
    # them use dropout/batch-norm whose training-mode behaviour matters here.
    return dataset, text_encoder, condition_augmenter, c_pro_gan


def _encode_caption(caption, rev_vocab, length=_CAPTION_LENGTH):
    """
    Map each whitespace-separated word of ``caption`` to its vocabulary id.

    The id list is right-padded with 0 up to ``length``; captions with more
    words than ``length`` are returned unpadded and untruncated.

    :raises KeyError: if a word is not present in ``rev_vocab``
    """
    ids = [rev_vocab[word] for word in caption.split()]
    return ids + [0] * (length - len(ids))


@app.route('/', methods=['POST'])
def homepage_result():
    """
    Generate face images for the submitted description and display them.

    Reads the ``des`` form field, converts it to vocabulary ids, samples
    ``_NUM_SAMPLES`` images from the generator, saves them as one grid PNG in
    the app's static folder and re-renders ``main.html`` with the image URL
    and the caption.  Responds 400 when a word of the description is not in
    the vocabulary (previously an unhandled KeyError, i.e. a 500).
    """
    from flask import abort, url_for
    from torch.nn.functional import interpolate
    from torchvision.utils import save_image

    caption = request.form["des"]
    dataset, text_encoder, condition_augmenter, c_pro_gan = _load_models()

    try:
        token_ids = _encode_caption(caption, dataset.rev_vocab)
    except KeyError as err:
        abort(400, description="Unknown word in description: {!r}".format(
            err.args[0]))
    print('\nInput : ', caption)

    # One copy of the encoded caption per sample, on the same device as the
    # models (the original hard-coded .cuda(), which fails without a GPU).
    seq = th.tensor(token_ids, dtype=th.long, device=_DEVICE)
    list_seq = seq.unsqueeze(0).repeat(_NUM_SAMPLES, 1)

    # Inference only: no_grad avoids building an autograd graph per request.
    with th.no_grad():
        embeddings = text_encoder(list_seq)
        c_not_hats, _, _ = condition_augmenter(embeddings)
        z = th.randn(list_seq.shape[0],
                     c_pro_gan.latent_size - c_not_hats.shape[-1],
                     device=_DEVICE)
        gan_input = th.cat((c_not_hats, z), dim=-1)
        samples = c_pro_gan.gen(gan_input, _CURRENT_DEPTH, _ALPHA)

    # Rescale generator output (presumably in [-1, 1]) to [0, 1] for saving.
    samples = (samples / 2) + 0.5
    # Enlarge lower-stage output by the same power-of-two factor the condition
    # tests (the original tested this factor but scaled by current_depth).
    scale_factor = int(np.power(2, c_pro_gan.depth - _CURRENT_DEPTH - 1))
    if scale_factor > 1:
        samples = interpolate(samples, scale_factor=scale_factor)

    # Name the file by a hash of the caption so user text never becomes part of
    # a filesystem path, and write it where Flask actually serves static files.
    file_name = hashlib.sha1(caption.encode("utf-8")).hexdigest() + ".png"
    static_dir = Path(app.static_folder)
    static_dir.mkdir(parents=True, exist_ok=True)
    save_image(samples, str(static_dir / file_name),
               nrow=int(np.sqrt(len(samples))))

    # Proper URL for the saved image (the original built "\\static\\..." by hand).
    result = url_for('static', filename=file_name)
    return render_template("main.html", result_img=result, result_caption=caption)
def main():
    """Start Flask's built-in development server with debug mode on."""
    app.run(debug=True)


if __name__ == "__main__":
    main()