-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_img.py
150 lines (113 loc) · 4.77 KB
/
train_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import torch
from torch import optim, nn
from model import Siren,MLP,DinerMLP,DinerSiren
from dataio import ImageData
import time
import utils
from tqdm.autonotebook import tqdm
from opt import HyperParameters
class Logger:
    """Tiny append-only text logger that writes to a single file.

    Assign the destination path to ``Logger.filename`` before calling
    either method; each call opens that file in append mode and adds
    one line.
    """

    # Destination path of the log file; must be set before the first write.
    filename = None

    @staticmethod
    def write(text):
        """Append ``text`` plus a trailing newline to ``Logger.filename``."""
        with open(Logger.filename, 'a') as handle:
            handle.write(text + '\n')

    @staticmethod
    def write_file(text):
        """Append ``text`` plus a trailing newline to ``Logger.filename``.

        Same effect as :meth:`write`; kept as a separate entry point.
        """
        Logger.write(text)
def _build_model(model_type, input_dim, hidden_features, hidden_layers,
                 out_features, hash_table_length, first_omega_0,
                 hidden_omega_0, device):
    """Instantiate the network selected by ``model_type`` on ``device``.

    Raises:
        NotImplementedError: if ``model_type`` is not one of ``'Siren'``,
            ``'MLP'``, ``'DinerSiren'`` or ``'DinerMLP'``.
    """
    if model_type == 'Siren':
        model = Siren(in_features=input_dim,
                      hidden_features=hidden_features,
                      hidden_layers=hidden_layers,
                      out_features=out_features)
    elif model_type == 'MLP':
        model = MLP(in_features=input_dim,
                    out_features=out_features,
                    hidden_layers=hidden_layers,
                    hidden_features=hidden_features)
    elif model_type == 'DinerSiren':
        model = DinerSiren(hash_table_length=hash_table_length,
                           in_features=input_dim,
                           hidden_features=hidden_features,
                           hidden_layers=hidden_layers,
                           out_features=out_features,
                           outermost_linear=True,
                           first_omega_0=first_omega_0,
                           hidden_omega_0=hidden_omega_0)
    elif model_type == 'DinerMLP':
        model = DinerMLP(hash_table_length=hash_table_length,
                         in_features=input_dim,
                         hidden_features=hidden_features,
                         hidden_layers=hidden_layers,
                         out_features=out_features)
    else:
        raise NotImplementedError("Model type not supported!")
    return model.to(device=device)


def train_img(opt):
    """Fit a network to a single image and write out its reconstruction.

    Reads the following attributes from ``opt``: ``img_path``, ``steps``,
    ``lr``, ``hidden_layers``, ``hidden_features``, ``sidelength``,
    ``grayscale``, ``w0``, ``model_type``, ``steps_til_summary``,
    ``input_dim``, ``epochs``, ``remain_raw_resolution`` and
    ``experiment_name``.

    Side effects:
        * creates ``log/<experiment_name>/`` via ``utils.cond_mkdir``;
        * appends a ``[TRAIN]`` line to ``log/<experiment_name>/log.txt``
          every ``steps_til_summary`` epochs;
        * renders ``log/<experiment_name>/recon.png`` and prints the PSNR
          that ``utils.calculate_psnr`` reports against ``img_path``.

    Raises:
        ValueError: if ``steps`` is not a multiple of ``steps_til_summary``.
        NotImplementedError: if ``opt.model_type`` is not supported.
    """
    img_path = opt.img_path
    steps = opt.steps
    steps_til_summary = opt.steps_til_summary
    epochs = opt.epochs
    experiment_name = opt.experiment_name

    # make directory
    log_dir = "log"
    run_dir = os.path.join(log_dir, experiment_name)
    utils.cond_mkdir(run_dir)

    # check parameters
    # NOTE(review): the check is on `steps`, while the loop below runs for
    # `epochs` iterations -- confirm which of the two is intended.
    if steps % steps_til_summary:
        raise ValueError("steps must be divisible by steps_til_summary!")

    # logger
    Logger.filename = os.path.join(run_dir, 'log.txt')

    # Prefer the GPU, but do not crash at startup on CPU-only machines.
    # NOTE(review): helpers in `utils` may still assume CUDA -- confirm.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criteon = nn.MSELoss()
    out_features = 3

    dataset = ImageData(image_path=img_path,
                        sidelength=opt.sidelength,
                        grayscale=opt.grayscale,
                        remain_raw_resolution=opt.remain_raw_resolution)
    model_input, gt = dataset[0]
    model_input = model_input.to(device)
    gt = gt.to(device)

    model = _build_model(
        model_type=opt.model_type,
        input_dim=opt.input_dim,
        hidden_features=opt.hidden_features,
        hidden_layers=opt.hidden_layers,
        out_features=out_features,
        # The Diner* models take the first dimension of the input as
        # their hash table length.
        hash_table_length=model_input.shape[0],
        first_omega_0=opt.w0,
        hidden_omega_0=opt.w0,
        device=device,
    )

    optimizer = optim.Adam(lr=opt.lr, params=model.parameters())

    # training process
    with tqdm(total=epochs) as pbar:
        time_cost = 0
        for epoch in range(epochs):
            time_start = time.time()

            model_output = model(model_input)
            loss_mse = criteon(model_output, gt)
            optimizer.zero_grad()
            loss_mse.backward()
            optimizer.step()

            # CUDA work is asynchronous; wait for it so the wall-clock
            # measurement covers the whole optimisation step.
            if device.type == 'cuda':
                torch.cuda.synchronize()
            time_cost += time.time() - time_start

            if (epoch + 1) % steps_til_summary == 0:
                # PSNR is only needed for the log line, so it is computed
                # here (outside the timed section) rather than every epoch.
                cur_psnr = utils.loss2psnr(loss_mse)
                log_str = f"[TRAIN] Epoch: {epoch+1} Loss: {loss_mse.item()} PSNR: {cur_psnr} Time: {round(time_cost, 2)}"
                Logger.write(log_str)

            pbar.update(1)

    recon_path = os.path.join(run_dir, 'recon.png')
    utils.render_raw_image(model, recon_path, [1200, 1200], linear=False)
    recon_psnr = utils.calculate_psnr(recon_path, img_path)
    print(f"Reconstruction PSNR: {recon_psnr:.2f}")
if __name__ == "__main__":
    # Script entry point: build the hyper-parameter object and train with it.
    train_img(HyperParameters())