-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_combo_supervision.py
362 lines (270 loc) · 14.6 KB
/
train_combo_supervision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
from __future__ import print_function, division
import sys
sys.path.append('core')
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import argparse
import os
import random
import time
import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from combo_model import RAFT
from utils.utils import InputPadder, coords_grid, bilinear_sampler
import evaluate
import datasets
from torch.utils.tensorboard import SummaryWriter
print('CUDA AVAILABLE ', torch.cuda.is_available())


class FallbackGradScaler:
    """No-op stand-in for ``torch.cuda.amp.GradScaler`` (PyTorch < 1.6).

    Implements only the subset of the GradScaler API used by ``train``:
    ``scale``, ``unscale_``, ``step`` and ``update``. Loss scaling is never
    performed, so training runs in plain fp32 regardless of ``enabled``.
    """

    def __init__(self, enabled=True, **kwargs):
        # ``train`` builds the scaler as GradScaler(enabled=...); accept and
        # ignore the same keywords as the real class instead of raising TypeError.
        self.enabled = enabled

    def scale(self, loss):
        """Return ``loss`` unchanged (no loss scaling)."""
        return loss

    def unscale_(self, optimizer):
        """Nothing to unscale because nothing was scaled."""
        pass

    def step(self, optimizer):
        """Take a plain optimizer step."""
        optimizer.step()

    def update(self):
        """No scale factor to update."""
        pass


try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # Only a missing GradScaler should trigger the fallback, not arbitrary errors.
    GradScaler = FallbackGradScaler

# exclude extremely large displacements
MAX_FLOW = 400
# number of pushed steps per Logger print/TensorBoard summary window
SUM_FREQ = 100
# NOTE: train() in this file uses its own local validation interval, not this value
VAL_FREQ = 1
def sequence_loss(epoch, flow_preds_p, flow_preds_a, uncert, flow_gt, wp_star, wa_star, gt_uncert, valid, gamma=0.8, max_flow=MAX_FLOW):
    """Supervised loss over the sequence of refined flow predictions.

    Each refinement iteration ``i`` is weighted by
    ``gamma ** (n_predictions - i - 1)`` (later iterations weigh more) and adds
    four masked L1 terms:

    * ``uncert_loss``: predicted uncertainty ``uncert[i][:, 0]`` vs ``gt_uncert``;
    * ``wp_flow_loss``: ``flow_preds_p[i]`` vs ``wp_star``;
    * ``wa_flow_loss``: ``flow_preds_a[i]`` vs ``wa_star``;
    * ``final_flow_loss``: blended flow ``(1 - u) * wp + u * wa`` vs ``flow_gt``.

    For the blended flow, ``u`` is the ground-truth uncertainty with
    probability ``max(0.01, 1 - epoch / 50)`` (scheduled sampling). Otherwise
    it is the predicted uncertainty. This lets training rely on GT uncertainty
    early and on the learned one later.

    Args:
        epoch: current epoch; drives the scheduled-sampling probability.
        flow_preds_p, flow_preds_a: per-iteration flow predictions
            (presumably (B, 2, H, W) -- confirm against the model).
        uncert: per-iteration uncertainty maps, e.g. 12 x (B, 1, H, W).
            A single-element list is reused for every iteration.
        flow_gt, wp_star, wa_star: targets with the same shape as the predictions.
        gt_uncert: ground-truth uncertainty, e.g. (B, H, W).
        valid: validity map (B, H, W). A pixel is supervised when
            ``valid >= 0.5`` and ``|flow_gt| < max_flow``.
        gamma: per-iteration exponential decay.
        max_flow: displacements at least this large are excluded.

    Returns:
        ``(final_flow_loss, wp_flow_loss, wa_flow_loss, uncert_loss, metrics)``.
        ``metrics`` holds EPE and 1/3/5-px accuracies of the last blended
        prediction (blended with the last predicted uncertainty) over valid pixels.

    Raises:
        ValueError: if ``flow_preds_p`` is empty.
    """
    n_predictions = len(flow_preds_p)
    if n_predictions == 0:
        raise ValueError("sequence_loss needs at least one flow prediction")

    scheduled_sampling = np.maximum(0.01, 1 - epoch / 50)
    final_flow_loss, wp_flow_loss, wa_flow_loss, uncert_loss = 0.0, 0.0, 0.0, 0.0

    if len(uncert) == 1:  # one (e.g. GT) uncertainty map shared by all iterations
        uncert = [uncert[0]] * n_predictions

    # exclude invalid pixels and extremely large displacements
    mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
    valid = (valid >= 0.5) & (mag < max_flow)   # (B, H, W)
    valid_b1hw = valid[:, None]                 # (B, 1, H, W) for 2-channel flow terms

    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)

        # scheduled_sampling decays towards 0.01: start with GT uncertainty,
        # rely more and more on the learned one.
        if random.random() < scheduled_sampling:
            uncertainty_b2hw = gt_uncert.unsqueeze(1).repeat(1, 2, 1, 1)
        else:
            uncertainty_b2hw = uncert[i].repeat(1, 2, 1, 1)

        # uncertainty loss: the residual is (B, H, W), so mask it with the
        # (B, H, W) `valid` directly. valid[:, None] would broadcast to
        # (B, B, H, W) and pair each sample's residual with every sample's mask.
        i_loss = (uncert[i][:, 0, :, :] - gt_uncert).abs()
        uncert_loss += i_weight * (valid * i_loss).mean()

        # wp flow loss
        i_loss = (flow_preds_p[i] - wp_star).abs()
        wp_flow_loss += i_weight * (valid_b1hw * i_loss).mean()

        # wa flow loss
        i_loss = (flow_preds_a[i] - wa_star).abs()
        wa_flow_loss += i_weight * (valid_b1hw * i_loss).mean()

        # final (uncertainty-blended) flow loss
        flow_pred = (1 - uncertainty_b2hw) * flow_preds_p[i] + uncertainty_b2hw * flow_preds_a[i]
        i_loss = (flow_pred - flow_gt).abs()
        final_flow_loss += i_weight * (valid_b1hw * i_loss).mean()

    # Metrics only: no need to extend the autograd graph.
    with torch.no_grad():
        last_u = uncert[n_predictions - 1]
        flow_pred = (1 - last_u) * flow_preds_p[-1] + last_u * flow_preds_a[-1]
        epe = torch.sum((flow_pred - flow_gt) ** 2, dim=1).sqrt()
        epe = epe.view(-1)[valid.view(-1)]
        metrics = {
            'epe': epe.mean().item(),
            '1px': (epe < 1).float().mean().item(),
            '3px': (epe < 3).float().mean().item(),
            '5px': (epe < 5).float().mean().item(),
        }

    return final_flow_loss, wp_flow_loss, wa_flow_loss, uncert_loss, metrics
def photometric_sequence_loss(flow_preds_p, image1, image2, uncert, gamma=0.8):
    """Uncertainty-weighted photometric L1 loss over the ``wp`` flow predictions.

    For each iteration ``i``, ``image2`` is bilinearly sampled at the pixel
    grid displaced by ``flow_preds_p[i]`` (backward warp towards ``image1``).
    The warped image is then compared with ``image1`` by L1 loss, after both
    are multiplied by ``1 - uncert[i]``, so pixels with high predicted
    uncertainty contribute less. Iterations are weighted by
    ``gamma ** (n_predictions - i - 1)``.

    Args:
        flow_preds_p: per-iteration flows, shape [batch, 2, H, W].
        image1, image2: image batches; ``image2`` is indexed as (B, C, H, W).
        uncert: per-iteration uncertainty maps (B, 1, H, W) broadcast over the
            colour channels; a single-element list is reused for every iteration.
        gamma: per-iteration exponential decay.

    Returns:
        The weighted photometric loss (0.0 if ``flow_preds_p`` is empty).
    """
    n_predictions = len(flow_preds_p)
    photometric_loss = 0.0
    criterion = nn.L1Loss()
    if len(uncert) == 1:  # one (e.g. GT) uncertainty map shared by all iterations
        uncert = [uncert[0]] * n_predictions

    # The base pixel grid depends only on the image size, so build it once and
    # put it on the images' device (instead of a hard-coded .cuda() per iteration).
    coords2 = coords_grid(image2.shape[0], image2.shape[2], image2.shape[3]).to(image2.device)

    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        coords21 = coords2 + flow_preds_p[i]
        im21 = bilinear_sampler(image2, coords21.permute(0, 2, 3, 1))
        certainty = 1 - uncert[i]
        photometric_loss += i_weight * criterion(certainty * im21, certainty * image1)
    return photometric_loss
def count_parameters(model):
    """Return how many scalar parameters of ``model`` are trainable (``requires_grad``)."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def fetch_optimizer(args, model):
    """Build the AdamW optimizer and a plateau-driven learning-rate scheduler.

    ``args`` must provide ``lr``, ``wdecay`` and ``epsilon``. The scheduler
    minimises the metric passed to ``scheduler.step(metric)``. It halves the
    learning rate once that metric has not improved for more than 10
    consecutive steps.

    Returns:
        ``(optimizer, scheduler)``.
    """
    adamw = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.wdecay,
        eps=args.epsilon,
    )
    plateau = ReduceLROnPlateau(
        adamw,
        mode='min',
        factor=0.5,
        patience=10,
        verbose=True,
    )
    return adamw, plateau
class Logger:
    """Accumulate training metrics and write them to TensorBoard.

    ``push`` adds one step's metrics to running sums and, every ``SUM_FREQ``
    steps, prints and logs their averages. ``write_dict`` logs arbitrary
    scalars (e.g. validation results) at a given step. The ``SummaryWriter``
    is created lazily on first use.
    """

    def __init__(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.writer = None
        self.T0 = time.time()  # start time of the current print window

    def _ensure_writer(self):
        """Create the TensorBoard writer on first use and return it."""
        if self.writer is None:
            self.writer = SummaryWriter()
        return self.writer

    def _current_lr(self):
        """Return the current learning rate of the first parameter group."""
        try:
            return self.scheduler.get_last_lr()[0]
        except AttributeError:
            # Some PyTorch versions/states of ReduceLROnPlateau have no usable
            # get_last_lr(); fall back to the optimizer it wraps.
            return self.scheduler.optimizer.param_groups[0]['lr']

    def _print_training_status(self):
        """Print and log the averaged running metrics, then reset the sums."""
        metrics_data = [self.running_loss[k] / SUM_FREQ for k in sorted(self.running_loss.keys())]
        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps + 1, self._current_lr())
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        print(training_str + metrics_str, ' time = ', time.time() - self.T0)
        self.T0 = time.time()

        writer = self._ensure_writer()
        for k in self.running_loss:
            writer.add_scalar(k, self.running_loss[k] / SUM_FREQ, self.total_steps)
            self.running_loss[k] = 0.0

    def push(self, metrics):
        """Add one step's ``metrics`` (name -> number) to the running sums."""
        self.total_steps += 1

        for key in metrics:
            if key not in self.running_loss:
                self.running_loss[key] = 0.0
            self.running_loss[key] += metrics[key]

        if self.total_steps % SUM_FREQ == SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, step, results):
        """Log every ``name -> value`` pair of ``results`` as a scalar at ``step``."""
        writer = self._ensure_writer()
        for key in results:
            writer.add_scalar(key, results[key], step)

    def close(self):
        """Close the TensorBoard writer if one was ever created."""
        # Without this guard, a run that never logged anything crashes on None.close().
        if self.writer is not None:
            self.writer.close()
def _run_validation(model, scheduler, validation):
    """Evaluate ``model`` on each dataset named in ``validation`` and merge the metrics.

    ``model`` is the ``nn.DataParallel`` wrapper; evaluators receive ``model.module``.
    For 'chairs', ``scheduler.step`` is called with the returned ``'chairs'`` value.
    Unknown names are ignored, and ``validation=None`` means no validation.
    """
    results = {}
    for val_dataset in validation or []:
        if val_dataset == 'chairs':
            # The original called ``evaluate_gt.validate_chairs``, but ``evaluate_gt``
            # is never imported (NameError). ``evaluate`` serves every other dataset.
            res = evaluate.validate_chairs(model.module)
            scheduler.step(res['chairs'])
            results.update(res)
        elif val_dataset == 'sintel':
            results.update(evaluate.validate_sintel(model.module))
        elif val_dataset == 'kitti':
            results.update(evaluate.validate_kitti(model.module))
        elif val_dataset == 'sintel_resplit':
            results.update(evaluate.validate_sintel_resplit(model.module))
        elif val_dataset == 'kitti_resplit':
            results.update(evaluate.validate_kitti_resplit(model.module))
    return results


def train(args):
    """Train the combo RAFT model and return the path of the final checkpoint.

    The objective is a weighted sum of six terms, weighted by the
    ``args.lambda_*`` values:

    * the supervised terms from ``sequence_loss``: final, wp, wa and uncertainty;
    * a photometric term on the wp flows;
    * a norm regulariser on both flow components.

    Every 5000 steps a checkpoint ``<name>_<step>.pth`` is saved, the datasets
    in ``args.validation`` are evaluated, and results plus epoch-average losses
    go to TensorBoard. Training stops after ``args.num_steps`` steps.

    Checkpoints go to ``args.ckpt_dir`` if present and truthy. Otherwise they
    go to the original hard-coded directory, which is created if needed.
    """
    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)

    model.cuda()
    model.train()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_loader = datasets.fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)

    ckpt_dir = getattr(args, 'ckpt_dir', None) or '/raid/F07773/optical_flow/checkpoints'
    os.makedirs(ckpt_dir, exist_ok=True)

    total_steps = 0
    epoch = 0
    scaler = GradScaler(enabled=args.mixed_precision)
    logger = Logger(model, scheduler)

    val_freq = 5000  # validate / checkpoint interval in optimizer steps

    should_keep_training = True
    while should_keep_training:
        # Per-epoch running sums of each loss term.
        ep_loss = dict.fromkeys(('total', 'final', 'wp', 'wa', 'uncert', 'photo', 'wpa'), 0.0)

        for i_batch, data_blob in enumerate(train_loader):
            t0 = time.time()
            optimizer.zero_grad()
            image1, image2, flow, wp_star, wa_star, gt_uncert, valid, _ = [x.cuda() for x in data_blob]

            if args.add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(image1.shape, device=image1.device)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(image2.shape, device=image2.device)).clamp(0.0, 255.0)

            ## MODEL PREDICTION
            flow_predictions_p, flow_predictions_a, uncertainty_masks = model(image1, image2, iters=args.iters)

            ## SUPERVISED LOSS
            final_flow_loss, wp_flow_loss, wa_flow_loss, uncert_loss, metrics = sequence_loss(
                epoch, flow_predictions_p, flow_predictions_a, uncertainty_masks,
                flow, wp_star, wa_star, gt_uncert, valid, args.gamma)

            ## NORM Wpa: mean L2 norm of both flow components over the iterations
            norm_wpa = 0
            for k in range(len(flow_predictions_a)):
                norm_wpa += torch.norm(flow_predictions_a[k]) / len(flow_predictions_a)
                norm_wpa += torch.norm(flow_predictions_p[k]) / len(flow_predictions_p)

            ## PHOTOMETRIC LOSS
            # NOTE(review): gamma is fixed at 0.8 here while sequence_loss uses args.gamma.
            photometric_loss = photometric_sequence_loss(flow_predictions_p, image1, image2, uncertainty_masks, gamma=0.8)

            loss = (args.lambda_final * final_flow_loss
                    + args.lambda_wp * wp_flow_loss
                    + args.lambda_wa * wa_flow_loss
                    + args.lambda_uncert * uncert_loss
                    + args.lambda_photo * photometric_loss
                    + args.lambda_wpa * norm_wpa)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            ep_loss['total'] += loss.item()
            ep_loss['final'] += final_flow_loss.item()
            ep_loss['wp'] += wp_flow_loss.item()
            ep_loss['wa'] += wa_flow_loss.item()
            ep_loss['uncert'] += uncert_loss.item()
            ep_loss['photo'] += photometric_loss.item()
            ep_loss['wpa'] += norm_wpa.item()

            scaler.step(optimizer)
            scaler.update()

            # Batches accumulated this epoch. i_batch is 0-based, so dividing by
            # i_batch was off by one and raised ZeroDivisionError on the first batch.
            n_seen = i_batch + 1

            if total_steps % val_freq == val_freq - 1:
                ckpt_path = os.path.join(ckpt_dir, '%s_%d.pth' % (args.name, total_steps + 1))
                torch.save(model.state_dict(), ckpt_path)

                results = _run_validation(model, scheduler, args.validation)
                results.update({
                    'wp_loss': ep_loss['wp'] / n_seen,
                    'wa_loss': ep_loss['wa'] / n_seen,
                    'photo': ep_loss['photo'] / n_seen,
                    'loss_alpha': ep_loss['uncert'] / n_seen,
                    'wpa': ep_loss['wpa'] / n_seen,
                })
                logger.write_dict(total_steps + 1, results)

                # Validation switches the model to eval mode; restore training mode.
                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

            if total_steps % 500 == 0 and total_steps > 1:
                print('epoch ', epoch, i_batch, '/', len(train_loader), 'step ', total_steps,
                      ' time ', time.time() - t0,
                      ' loss ', ep_loss['total'] / n_seen,
                      ' final ', ep_loss['final'] / n_seen,
                      ' wp ', ep_loss['wp'] / n_seen,
                      'wa ', ep_loss['wa'] / n_seen,
                      ' photo ', ep_loss['photo'] / n_seen,
                      ' loss uncert ', ep_loss['uncert'] / n_seen,
                      ' wpa ', ep_loss['wpa'] / n_seen,
                      'certainty ', torch.max(uncertainty_masks[-1]).item(), torch.mean(uncertainty_masks[-1]).item(),
                      'gt uncert ', torch.max(gt_uncert), torch.mean(gt_uncert))

        epoch += 1

    logger.close()
    ckpt_path = os.path.join(ckpt_dir, '%s.pth' % args.name)
    torch.save(model.state_dict(), ckpt_path)

    return ckpt_path
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='raft', help="name your experiment")
    parser.add_argument('--stage', help="determines which dataset to use for training")
    parser.add_argument('--restore_ckpt', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    # Default to no validation sets so the validation loop never iterates over None.
    parser.add_argument('--validation', type=str, nargs='+', default=[])

    parser.add_argument('--lr', type=float, default=0.00002)
    parser.add_argument('--num_steps', type=int, default=250000)
    parser.add_argument('--batch_size', type=int, default=6)
    parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
    parser.add_argument('--gpus', type=int, nargs='+', default=[0, 1])
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')

    parser.add_argument('--iters', type=int, default=12)
    parser.add_argument('--wdecay', type=float, default=.00005)
    parser.add_argument('--epsilon', type=float, default=1e-8)
    parser.add_argument('--clip', type=float, default=1.0)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
    parser.add_argument('--add_noise', action='store_true')

    # Weights of the individual loss terms combined in train().
    parser.add_argument('--lambda_final', default=1, type=float, help='')
    parser.add_argument('--lambda_wp', default=1, type=float, help='')
    parser.add_argument('--lambda_wa', default=1, type=float, help='')
    parser.add_argument('--lambda_uncert', default=0.1, type=float, help='')
    parser.add_argument('--lambda_photo', default=0.1, type=float, help='')
    parser.add_argument('--lambda_wpa', default=0.01, type=float, help='')
    args = parser.parse_args()

    # Seed every RNG the script draws from. sequence_loss uses Python's
    # `random` for scheduled sampling, which was previously left unseeded.
    torch.manual_seed(1234)
    np.random.seed(1234)
    random.seed(1234)

    os.makedirs('checkpoints', exist_ok=True)

    train(args)