# trainer.py (forked from nianticlabs/monodepth2)
# Original listing: 630 lines (489 loc), 25.2 KB.
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
from __future__ import absolute_import, division, print_function
import numpy as np
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import json
from utils import *
from kitti_utils import *
from layers import *
import datasets
import networks
from IPython import embed
class Trainer:
    """Monodepth2 training driver.

    Builds the depth network (plus, depending on the options, a pose network
    and/or a predictive-mask decoder), the Adam optimizer with a StepLR
    scheduler, the train/val data loaders and TensorBoard writers, and then
    runs the optimisation loop with periodic logging, single-batch validation
    and checkpointing.
    """

    def __init__(self, options):
        """Set up every component needed for training from ``options``.

        Args:
            options: options object exposing the attributes read throughout this
                class (``height``, ``width``, ``frame_ids``, ``scales``, ...);
                presumably the parsed command-line options — confirm against the
                caller. ``options.frame_ids`` is mutated in place: ``"s"`` is
                appended when ``use_stereo`` is set.

        Raises:
            AssertionError: if ``height``/``width`` are not multiples of 32, if
                ``frame_ids`` does not start with 0, or if ``predictive_mask`` is
                requested without ``disable_automasking``.
        """
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

        # checking height and width are multiples of 32
        assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
        assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"

        self.models = {}
        self.parameters_to_train = []

        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        self.num_scales = len(self.opt.scales)
        self.num_input_frames = len(self.opt.frame_ids)
        self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames

        assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"

        # No pose network is needed when training from stereo pairs only.
        self.use_pose_net = not (self.opt.use_stereo and self.opt.frame_ids == [0])

        if self.opt.use_stereo:
            # "s" marks the other image of the stereo pair. It is appended after
            # num_input_frames / use_pose_net were computed, so neither counts it.
            self.opt.frame_ids.append("s")

        self._build_models()

        self.model_optimizer = optim.Adam(self.parameters_to_train, self.opt.learning_rate)
        # Multiply the learning rate by 0.1 every `scheduler_step_size` scheduler
        # steps; run_epoch() steps it once per epoch.
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, self.opt.scheduler_step_size, 0.1)

        if self.opt.load_weights_folder is not None:
            self.load_model()

        print("Training model named:\n ", self.opt.model_name)
        print("Models and tensorboard events files are saved to:\n ", self.opt.log_dir)
        print("Training is using:\n ", self.device)

        train_dataset, val_dataset = self._build_data_loaders()

        self.writers = {}
        for mode in ["train", "val"]:
            self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))

        if not self.opt.no_ssim:
            self.ssim = SSIM()
            self.ssim.to(self.device)

        self._build_projection_layers()

        self.depth_metric_names = [
            "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"]

        print("Using split:\n ", self.opt.split)
        print("There are {:d} training items and {:d} validation items\n".format(
            len(train_dataset), len(val_dataset)))

        self.save_opts()

    def _build_models(self):
        """Create the networks, move them to ``self.device`` and collect their parameters."""
        self.models["encoder"] = networks.ResnetEncoder(
            self.opt.num_layers, self.opt.weights_init == "pretrained")
        self.models["encoder"].to(self.device)
        self.parameters_to_train += list(self.models["encoder"].parameters())

        self.models["depth"] = networks.DepthDecoder(
            self.models["encoder"].num_ch_enc, self.opt.scales)
        self.models["depth"].to(self.device)
        self.parameters_to_train += list(self.models["depth"].parameters())

        if self.use_pose_net:
            if self.opt.pose_model_type == "separate_resnet":
                # Dedicated encoder; predict_poses() feeds it the pose frames
                # concatenated along dim 1.
                self.models["pose_encoder"] = networks.ResnetEncoder(
                    self.opt.num_layers,
                    self.opt.weights_init == "pretrained",
                    num_input_images=self.num_pose_frames)

                self.models["pose_encoder"].to(self.device)
                self.parameters_to_train += list(self.models["pose_encoder"].parameters())

                self.models["pose"] = networks.PoseDecoder(
                    self.models["pose_encoder"].num_ch_enc,
                    num_input_features=1,
                    num_frames_to_predict_for=2)

            elif self.opt.pose_model_type == "shared":
                # The pose decoder reuses the depth encoder's features.
                self.models["pose"] = networks.PoseDecoder(
                    self.models["encoder"].num_ch_enc, self.num_pose_frames)

            elif self.opt.pose_model_type == "posecnn":
                self.models["pose"] = networks.PoseCNN(
                    self.num_input_frames if self.opt.pose_model_input == "all" else 2)

            self.models["pose"].to(self.device)
            self.parameters_to_train += list(self.models["pose"].parameters())

        if self.opt.predictive_mask:
            assert self.opt.disable_automasking, \
                "When using predictive_mask, please disable automasking with --disable_automasking"

            # Our implementation of the predictive masking baseline has the same architecture
            # as our depth decoder. We predict a separate mask for each source frame.
            self.models["predictive_mask"] = networks.DepthDecoder(
                self.models["encoder"].num_ch_enc, self.opt.scales,
                num_output_channels=(len(self.opt.frame_ids) - 1))
            self.models["predictive_mask"].to(self.device)
            self.parameters_to_train += list(self.models["predictive_mask"].parameters())

    def _build_data_loaders(self):
        """Create the train/val datasets and loaders; return ``(train_dataset, val_dataset)``."""
        datasets_dict = {"kitti": datasets.KITTIRAWDataset,
                         "kitti_odom": datasets.KITTIOdomDataset}
        self.dataset = datasets_dict[self.opt.dataset]

        fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")

        train_filenames = readlines(fpath.format("train"))
        val_filenames = readlines(fpath.format("val"))
        img_ext = '.png' if self.opt.png else '.jpg'

        num_train_samples = len(train_filenames)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        # NOTE(review): the literal 4 is presumably the dataset's number of scales —
        # confirm against the dataset constructor.
        # drop_last=True keeps every batch at exactly batch_size; the projection
        # layers are constructed with that fixed batch size.
        train_dataset = self.dataset(
            self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
        self.train_loader = DataLoader(
            train_dataset, self.opt.batch_size, True,
            num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
        val_dataset = self.dataset(
            self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
        self.val_loader = DataLoader(
            val_dataset, self.opt.batch_size, True,
            num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
        # Persistent iterator: val() pulls one batch per call and restarts it when exhausted.
        self.val_iter = iter(self.val_loader)

        return train_dataset, val_dataset

    def _build_projection_layers(self):
        """Create per-scale back-projection / projection layers on ``self.device``."""
        self.backproject_depth = {}
        self.project_3d = {}
        for scale in self.opt.scales:
            h = self.opt.height // (2 ** scale)
            w = self.opt.width // (2 ** scale)

            self.backproject_depth[scale] = BackprojectDepth(self.opt.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
            self.project_3d[scale].to(self.device)

    def set_train(self):
        """Convert all models to training mode
        """
        for m in self.models.values():
            m.train()

    def set_eval(self):
        """Convert all models to testing/evaluation mode
        """
        for m in self.models.values():
            m.eval()

    def train(self):
        """Run the entire training pipeline

        Runs ``opt.num_epochs`` epochs and saves a checkpoint every
        ``opt.save_frequency`` epochs.
        """
        self.epoch = 0
        self.step = 0
        self.start_time = time.time()
        for self.epoch in range(self.opt.num_epochs):
            self.run_epoch()
            if (self.epoch + 1) % self.opt.save_frequency == 0:
                self.save_model()

    def run_epoch(self):
        """Run a single epoch of training and validation

        The learning-rate scheduler is stepped once, after this epoch's optimizer
        updates. This is the ordering PyTorch >= 1.1 expects; stepping it first
        would skip the initial learning rate and decay one epoch early.
        """
        print("Training")
        self.set_train()

        for batch_idx, inputs in enumerate(self.train_loader):

            before_op_time = time.time()

            outputs, losses = self.process_batch(inputs)

            self.model_optimizer.zero_grad()
            losses["loss"].backward()
            self.model_optimizer.step()

            duration = time.time() - before_op_time

            # log less frequently after the first 2000 steps to save time & disk space
            early_phase = batch_idx % self.opt.log_frequency == 0 and self.step < 2000
            late_phase = self.step % 2000 == 0

            if early_phase or late_phase:
                self.log_time(batch_idx, duration, losses["loss"].item())

                if "depth_gt" in inputs:
                    self.compute_depth_losses(inputs, outputs, losses)

                self.log("train", inputs, outputs, losses)
                self.val()

            self.step += 1

        self.model_lr_scheduler.step()

    def process_batch(self, inputs):
        """Pass a minibatch through the network and generate images and losses

        Every tensor in ``inputs`` is moved to ``self.device`` in place.

        Returns:
            ``(outputs, losses)`` dictionaries.
        """
        for key, ipt in inputs.items():
            inputs[key] = ipt.to(self.device)

        if self.opt.pose_model_type == "shared":
            # If we are using a shared encoder for both depth and pose (as advocated
            # in monodepthv1), then all images are fed separately through the depth encoder.
            all_color_aug = torch.cat([inputs[("color_aug", i, 0)] for i in self.opt.frame_ids])
            all_features = self.models["encoder"](all_color_aug)
            # Split every feature level back into per-frame chunks of batch_size.
            all_features = [torch.split(f, self.opt.batch_size) for f in all_features]

            features = {k: [f[i] for f in all_features]
                        for i, k in enumerate(self.opt.frame_ids)}

            outputs = self.models["depth"](features[0])
        else:
            # Otherwise, we only feed the image with frame_id 0 through the depth encoder
            features = self.models["encoder"](inputs["color_aug", 0, 0])
            outputs = self.models["depth"](features)

        if self.opt.predictive_mask:
            # The mask network is a DepthDecoder, so it needs the same per-frame
            # feature list the depth decoder received (features[0] in shared mode).
            mask_features = features[0] if self.opt.pose_model_type == "shared" else features
            outputs["predictive_mask"] = self.models["predictive_mask"](mask_features)

        if self.use_pose_net:
            outputs.update(self.predict_poses(inputs, features))

        self.generate_images_pred(inputs, outputs)
        losses = self.compute_losses(inputs, outputs)

        return outputs, losses

    def predict_poses(self, inputs, features):
        """Predict poses between input frames for monocular sequences.

        Returns:
            dict with ("axisangle", 0, f_i), ("translation", 0, f_i) and
            ("cam_T_cam", 0, f_i) entries for every source frame f_i other than "s".
        """
        outputs = {}
        if self.num_pose_frames == 2:
            # In this setting, we compute the pose to each source frame via a
            # separate forward pass through the pose network.

            # select what features the pose network takes as input
            if self.opt.pose_model_type == "shared":
                pose_feats = {f_i: features[f_i] for f_i in self.opt.frame_ids}
            else:
                pose_feats = {f_i: inputs["color_aug", f_i, 0] for f_i in self.opt.frame_ids}

            for f_i in self.opt.frame_ids[1:]:
                if f_i == "s":
                    continue

                # To maintain ordering we always pass frames in temporal order
                if f_i < 0:
                    pose_inputs = [pose_feats[f_i], pose_feats[0]]
                else:
                    pose_inputs = [pose_feats[0], pose_feats[f_i]]

                if self.opt.pose_model_type == "separate_resnet":
                    pose_inputs = [self.models["pose_encoder"](torch.cat(pose_inputs, 1))]
                elif self.opt.pose_model_type == "posecnn":
                    pose_inputs = torch.cat(pose_inputs, 1)

                axisangle, translation = self.models["pose"](pose_inputs)
                outputs[("axisangle", 0, f_i)] = axisangle
                outputs[("translation", 0, f_i)] = translation

                # Invert the matrix if the frame id is negative
                outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(
                    axisangle[:, 0], translation[:, 0], invert=(f_i < 0))

        else:
            # Here we input all frames to the pose net (and predict all poses) together
            if self.opt.pose_model_type in ["separate_resnet", "posecnn"]:
                pose_inputs = torch.cat(
                    [inputs[("color_aug", i, 0)] for i in self.opt.frame_ids if i != "s"], 1)

                if self.opt.pose_model_type == "separate_resnet":
                    pose_inputs = [self.models["pose_encoder"](pose_inputs)]

            elif self.opt.pose_model_type == "shared":
                pose_inputs = [features[i] for i in self.opt.frame_ids if i != "s"]

            axisangle, translation = self.models["pose"](pose_inputs)

            for i, f_i in enumerate(self.opt.frame_ids[1:]):
                if f_i != "s":
                    outputs[("axisangle", 0, f_i)] = axisangle
                    outputs[("translation", 0, f_i)] = translation
                    outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(
                        axisangle[:, i], translation[:, i])

        return outputs

    def val(self):
        """Validate the model on a single minibatch

        Pulls the next batch from the persistent validation iterator, restarting
        it when exhausted, and logs its losses under the "val" writer.
        """
        self.set_eval()
        try:
            inputs = next(self.val_iter)
        except StopIteration:
            self.val_iter = iter(self.val_loader)
            inputs = next(self.val_iter)

        with torch.no_grad():
            outputs, losses = self.process_batch(inputs)

            if "depth_gt" in inputs:
                self.compute_depth_losses(inputs, outputs, losses)

            self.log("val", inputs, outputs, losses)
            del inputs, outputs, losses

        self.set_train()

    def generate_images_pred(self, inputs, outputs):
        """Generate the warped (reprojected) color images for a minibatch.

        For every scale, the predicted disparity is converted to depth. Each
        source frame is then warped into the frame-0 view. Results are saved
        into ``outputs`` under:

        - ("depth", 0, scale)
        - ("sample", frame_id, scale)
        - ("color", frame_id, scale)
        - ("color_identity", frame_id, scale), when automasking is enabled.
        """
        for scale in self.opt.scales:
            disp = outputs[("disp", scale)]
            if self.opt.v1_multiscale:
                source_scale = scale
            else:
                # Upsample the disparity and warp the full-resolution images.
                disp = F.interpolate(
                    disp, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False)
                source_scale = 0

            _, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth)

            outputs[("depth", 0, scale)] = depth

            for frame_id in self.opt.frame_ids[1:]:

                if frame_id == "s":
                    T = inputs["stereo_T"]
                else:
                    T = outputs[("cam_T_cam", 0, frame_id)]

                # from the authors of https://arxiv.org/abs/1712.00175
                # Only temporal frames have predicted poses; the stereo frame keeps stereo_T.
                if self.opt.pose_model_type == "posecnn" and frame_id != "s":

                    axisangle = outputs[("axisangle", 0, frame_id)]
                    translation = outputs[("translation", 0, frame_id)]

                    # Scale the predicted translation by the mean inverse depth.
                    inv_depth = 1 / depth
                    mean_inv_depth = inv_depth.mean(3, True).mean(2, True)

                    T = transformation_from_parameters(
                        axisangle[:, 0], translation[:, 0] * mean_inv_depth[:, 0], frame_id < 0)

                cam_points = self.backproject_depth[source_scale](
                    depth, inputs[("inv_K", source_scale)])
                pix_coords = self.project_3d[source_scale](
                    cam_points, inputs[("K", source_scale)], T)

                outputs[("sample", frame_id, scale)] = pix_coords

                # NOTE(review): align_corners is not passed, so PyTorch's default
                # applies (False since 1.3). Confirm this matches how Project3D
                # normalises pix_coords.
                outputs[("color", frame_id, scale)] = F.grid_sample(
                    inputs[("color", frame_id, source_scale)],
                    outputs[("sample", frame_id, scale)],
                    padding_mode="border")

                if not self.opt.disable_automasking:
                    outputs[("color_identity", frame_id, scale)] = \
                        inputs[("color", frame_id, source_scale)]

    def compute_reprojection_loss(self, pred, target):
        """Computes reprojection loss between a batch of predicted and target images

        Returns the per-pixel photometric error averaged over dim 1 (kept as a
        singleton dim). Without SSIM this is the L1 error; otherwise it is
        ``0.85 * SSIM + 0.15 * L1``.
        """
        abs_diff = torch.abs(target - pred)
        l1_loss = abs_diff.mean(1, True)

        if self.opt.no_ssim:
            reprojection_loss = l1_loss
        else:
            ssim_loss = self.ssim(pred, target).mean(1, True)
            reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

        return reprojection_loss

    def compute_losses(self, inputs, outputs):
        """Compute the reprojection and smoothness losses for a minibatch

        Returns:
            dict with each scale's loss under "loss/<scale>" and their mean over
            scales under "loss". With automasking enabled, a per-pixel
            "identity_selection/<scale>" map is also written into ``outputs``.
        """
        losses = {}
        total_loss = 0

        for scale in self.opt.scales:
            loss = 0
            reprojection_losses = []

            if self.opt.v1_multiscale:
                source_scale = scale
            else:
                source_scale = 0

            disp = outputs[("disp", scale)]
            color = inputs[("color", 0, scale)]
            target = inputs[("color", 0, source_scale)]

            for frame_id in self.opt.frame_ids[1:]:
                pred = outputs[("color", frame_id, scale)]
                reprojection_losses.append(self.compute_reprojection_loss(pred, target))

            reprojection_losses = torch.cat(reprojection_losses, 1)

            if not self.opt.disable_automasking:
                # Loss of the *unwarped* source frames; the min below picks it
                # wherever it beats the warped loss.
                identity_reprojection_losses = []
                for frame_id in self.opt.frame_ids[1:]:
                    pred = inputs[("color", frame_id, source_scale)]
                    identity_reprojection_losses.append(
                        self.compute_reprojection_loss(pred, target))

                identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1)

                if self.opt.avg_reprojection:
                    identity_reprojection_loss = identity_reprojection_losses.mean(1, keepdim=True)
                else:
                    # save both images, and do min all at once below
                    identity_reprojection_loss = identity_reprojection_losses

            elif self.opt.predictive_mask:
                # use the predicted mask
                mask = outputs["predictive_mask"]["disp", scale]
                if not self.opt.v1_multiscale:
                    mask = F.interpolate(
                        mask, [self.opt.height, self.opt.width],
                        mode="bilinear", align_corners=False)

                reprojection_losses = reprojection_losses * mask

                # add a loss pushing mask to 1: binary cross-entropy against an
                # all-ones target created on the mask's own device/dtype
                weighting_loss = 0.2 * F.binary_cross_entropy(mask, torch.ones_like(mask))
                loss += weighting_loss.mean()

            if self.opt.avg_reprojection:
                reprojection_loss = reprojection_losses.mean(1, keepdim=True)
            else:
                reprojection_loss = reprojection_losses

            if not self.opt.disable_automasking:
                # add random numbers to break ties
                identity_reprojection_loss = identity_reprojection_loss + torch.randn(
                    identity_reprojection_loss.shape, device=self.device) * 0.00001

                combined = torch.cat((identity_reprojection_loss, reprojection_loss), dim=1)
            else:
                combined = reprojection_loss

            if combined.shape[1] == 1:
                to_optimise = combined
            else:
                to_optimise, idxs = torch.min(combined, dim=1)

            if not self.opt.disable_automasking:
                # 1 where the min picked a warped (non-identity) loss channel.
                outputs["identity_selection/{}".format(scale)] = (
                    idxs > identity_reprojection_loss.shape[1] - 1).float()

            loss += to_optimise.mean()

            # Smoothness on mean-normalised disparity (guided by `color`),
            # weighted by 1 / 2**scale.
            mean_disp = disp.mean(2, True).mean(3, True)
            norm_disp = disp / (mean_disp + 1e-7)
            smooth_loss = get_smooth_loss(norm_disp, color)

            loss += self.opt.disparity_smoothness * smooth_loss / (2 ** scale)
            total_loss += loss
            losses["loss/{}".format(scale)] = loss

        total_loss /= self.num_scales
        losses["loss"] = total_loss
        return losses

    def compute_depth_losses(self, inputs, outputs, losses):
        """Compute depth metrics, to allow monitoring during training

        This isn't particularly accurate as it averages over the entire batch,
        so is only used to give an indication of validation performance.

        The steps are:

        1. Resize the full-resolution prediction to 375x1242.
        2. Restrict it to valid ground-truth pixels inside a fixed crop.
        3. Median-scale it to the ground truth and clamp it to [1e-3, 80].
        4. Store the metrics in ``losses`` under ``self.depth_metric_names``.
        """
        depth_pred = outputs[("depth", 0, 0)]
        # NOTE(review): the 375x1242 size and the crop below are hard-coded,
        # presumably for KITTI ground truth — confirm before using other datasets.
        depth_pred = torch.clamp(F.interpolate(
            depth_pred, [375, 1242], mode="bilinear", align_corners=False), 1e-3, 80)
        depth_pred = depth_pred.detach()

        depth_gt = inputs["depth_gt"]
        mask = depth_gt > 0

        # garg/eigen crop
        crop_mask = torch.zeros_like(mask)
        crop_mask[:, :, 153:371, 44:1197] = 1
        mask = mask & crop_mask

        depth_gt = depth_gt[mask]
        depth_pred = depth_pred[mask]
        # Median scaling: rescale predictions so their median matches the ground truth.
        depth_pred *= torch.median(depth_gt) / torch.median(depth_pred)

        depth_pred = torch.clamp(depth_pred, min=1e-3, max=80)

        depth_errors = compute_depth_errors(depth_gt, depth_pred)

        for i, metric in enumerate(self.depth_metric_names):
            losses[metric] = np.array(depth_errors[i].cpu())

    def log_time(self, batch_idx, duration, loss):
        """Print a logging statement to the terminal

        Args:
            batch_idx: index of the batch within the current epoch.
            duration: seconds spent on this batch's forward/backward/update.
            loss: scalar loss value to print.
        """
        samples_per_sec = self.opt.batch_size / duration
        time_sofar = time.time() - self.start_time
        # Linear extrapolation from the steps completed so far.
        training_time_left = (
            self.num_total_steps / self.step - 1.0) * time_sofar if self.step > 0 else 0
        print_string = "epoch {:>3} | batch {:>6} | examples/s: {:5.1f}" + \
            " | loss: {:.5f} | time elapsed: {} | time left: {}"
        print(print_string.format(self.epoch, batch_idx, samples_per_sec, loss,
                                  sec_to_hm_str(time_sofar), sec_to_hm_str(training_time_left)))

    def log(self, mode, inputs, outputs, losses):
        """Write an event to the tensorboard events file

        Every entry of ``losses`` is written as a scalar. For up to four batch
        items, the writer also receives:

        - the input and predicted colour images;
        - the disparities;
        - either the predictive masks or the automask selection maps.
        """
        writer = self.writers[mode]
        for name, value in losses.items():
            writer.add_scalar("{}".format(name), value, self.step)

        for j in range(min(4, self.opt.batch_size)):  # write a maximum of four images
            for s in self.opt.scales:
                for frame_id in self.opt.frame_ids:
                    writer.add_image(
                        "color_{}_{}/{}".format(frame_id, s, j),
                        inputs[("color", frame_id, s)][j].data, self.step)
                    if s == 0 and frame_id != 0:
                        writer.add_image(
                            "color_pred_{}_{}/{}".format(frame_id, s, j),
                            outputs[("color", frame_id, s)][j].data, self.step)

                writer.add_image(
                    "disp_{}/{}".format(s, j),
                    normalize_image(outputs[("disp", s)][j]), self.step)

                if self.opt.predictive_mask:
                    for f_idx, frame_id in enumerate(self.opt.frame_ids[1:]):
                        writer.add_image(
                            "predictive_mask_{}_{}/{}".format(frame_id, s, j),
                            outputs["predictive_mask"][("disp", s)][j, f_idx][None, ...],
                            self.step)

                elif not self.opt.disable_automasking:
                    writer.add_image(
                        "automask_{}/{}".format(s, j),
                        outputs["identity_selection/{}".format(s)][j][None, ...], self.step)

    def save_opts(self):
        """Save options to disk so we know what we ran this experiment with

        Writes ``<log_path>/models/opt.json``.
        """
        models_dir = os.path.join(self.log_path, "models")
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
        to_save = self.opt.__dict__.copy()

        with open(os.path.join(models_dir, 'opt.json'), 'w') as f:
            json.dump(to_save, f, indent=2)

    def save_model(self):
        """Save model weights to disk

        Writes ``<log_path>/models/weights_<epoch>/<model_name>.pth`` for every
        model, plus ``adam.pth`` with the optimizer state.
        """
        save_folder = os.path.join(self.log_path, "models", "weights_{}".format(self.epoch))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        for model_name, model in self.models.items():
            save_path = os.path.join(save_folder, "{}.pth".format(model_name))
            to_save = model.state_dict()
            if model_name == 'encoder':
                # save the sizes - these are needed at prediction time
                to_save['height'] = self.opt.height
                to_save['width'] = self.opt.width
                to_save['use_stereo'] = self.opt.use_stereo
            torch.save(to_save, save_path)

        save_path = os.path.join(save_folder, "{}.pth".format("adam"))
        torch.save(self.model_optimizer.state_dict(), save_path)

    def load_model(self):
        """Load model(s) from disk

        Loads the weights of every model named in ``opt.models_to_load`` from
        ``opt.load_weights_folder``; checkpoint keys unknown to the model are
        ignored. Then loads the Adam state from ``adam.pth`` if that file exists.
        Checkpoints are read onto the CPU first, so GPU-saved weights also load
        when ``opt.no_cuda`` is set; ``load_state_dict`` copies them onto the
        parameters' devices.
        """
        self.opt.load_weights_folder = os.path.expanduser(self.opt.load_weights_folder)

        assert os.path.isdir(self.opt.load_weights_folder), \
            "Cannot find folder {}".format(self.opt.load_weights_folder)
        print("loading model from folder {}".format(self.opt.load_weights_folder))

        for n in self.opt.models_to_load:
            print("Loading {} weights...".format(n))
            path = os.path.join(self.opt.load_weights_folder, "{}.pth".format(n))
            model_dict = self.models[n].state_dict()
            # NOTE: torch.load unpickles the file — only load trusted checkpoints.
            pretrained_dict = torch.load(path, map_location="cpu")
            # Drop keys the model does not have (e.g. the height/width/use_stereo
            # entries save_model adds to the encoder checkpoint).
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.models[n].load_state_dict(model_dict)

        # loading adam state
        optimizer_load_path = os.path.join(self.opt.load_weights_folder, "adam.pth")
        if os.path.isfile(optimizer_load_path):
            print("Loading Adam weights")
            optimizer_dict = torch.load(optimizer_load_path, map_location="cpu")
            self.model_optimizer.load_state_dict(optimizer_dict)
        else:
            print("Cannot find Adam weights so Adam is randomly initialized")