-
Notifications
You must be signed in to change notification settings - Fork 206
/
Copy pathmultiscaleloss.py
64 lines (49 loc) · 2.14 KB
/
multiscaleloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn.functional as F
def EPE(input_flow, target_flow, sparse=False, mean=True):
    """Compute the end-point error (EPE) between a predicted and a target flow.

    The per-element error is the Euclidean (L2) norm of
    ``target_flow - input_flow`` taken over dimension 1. That dimension holds the
    flow components; the sparse mask reads channels 0 and 1. Presumably the
    inputs are (batch, 2, height, width) — confirm against callers.

    Args:
        input_flow: predicted flow.
        target_flow: ground-truth flow, same shape as ``input_flow``.
        sparse: if true, drop every element whose target has both flow
            components exactly 0, since such elements are treated as invalid.
        mean: if true, return the mean error over the kept elements. Otherwise
            return their summed error divided by the batch size (dimension 0).

    Returns:
        A scalar tensor. With ``sparse=True`` and ``mean=True``, it is NaN when
        every element is masked out (mean of an empty tensor).
    """
    residual = target_flow - input_flow
    error_map = torch.norm(residual, p=2, dim=1)
    # Read the batch size before masking: boolean indexing flattens the map.
    n_batch = error_map.size(0)

    if sparse:
        # Invalid flow is encoded as both coordinates being exactly zero.
        invalid = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        error_map = error_map[~invalid]

    return error_map.mean() if mean else error_map.sum() / n_batch
def sparse_max_pool(input, size):
    """Downsample ``input`` to spatial ``size``, treating 0 values as invalid.

    No generic interpolation mode resizes a sparse map correctly. Instead, the
    positive values are max-pooled and the negative values are "min-pooled"
    (max-pooling their magnitudes, then negating). The two pooled maps are then
    added.

    If a pooling window contains both positive and negative values, the output
    is the largest positive plus the most negative value in that window.
    Compared with nearest-neighbour interpolation, this keeps sparsity low and
    does not lose isolated data points.

    Args:
        input: tensor accepted by ``F.adaptive_max_pool2d``.
        size: target output size passed to ``F.adaptive_max_pool2d``.

    Returns:
        The pooled tensor.
    """
    pos_mask = (input > 0).float()
    neg_mask = (input < 0).float()

    # Largest positive value per window (non-positive entries become 0).
    pooled_pos = F.adaptive_max_pool2d(input * pos_mask, size)
    # Largest negative magnitude per window; subtracting it equals min-pooling.
    pooled_neg_magnitude = F.adaptive_max_pool2d(-input * neg_mask, size)

    return pooled_pos - pooled_neg_magnitude
def multiscaleEPE(network_output, target_flow, weights=None, sparse=False):
    """Return the weighted sum of end-point errors over several prediction scales.

    Each prediction in ``network_output`` is compared against ``target_flow``
    resized to that prediction's spatial size. ``sparse_max_pool`` is used when
    ``sparse`` is true, and area interpolation otherwise. Flow values are
    resized but not rescaled.

    Each per-scale term is ``EPE(pred, resized_target, sparse, mean=False)``,
    i.e. the summed error divided by the batch size.

    Args:
        network_output: one prediction tensor, or a tuple/list of them. Each
            prediction is unpacked as a 4-D ``(batch, channels, h, w)`` size.
        target_flow: full-resolution ground-truth flow.
        weights: one weight per prediction, paired with the predictions in
            order. Defaults to ``[0.005, 0.01, 0.02, 0.08, 0.32]``, which
            requires exactly five predictions.
        sparse: if true, target pixels whose two flow components are both
            exactly 0 are treated as invalid.

    Returns:
        ``sum(weight_i * EPE_i)`` over all predictions (the int ``0`` when both
        ``network_output`` and ``weights`` are empty).

    Raises:
        ValueError: if the number of weights differs from the number of
            predictions.
    """
    def one_scale(output, target, sparse):
        # Bring the target to this prediction's resolution before comparing.
        _, _, h, w = output.size()
        if sparse:
            # Max/min pooling keeps isolated valid (non-zero) samples; area
            # averaging would blend them with the zero "invalid" marker.
            target_scaled = sparse_max_pool(target, (h, w))
        else:
            target_scaled = F.interpolate(target, (h, w), mode="area")
        return EPE(output, target_scaled, sparse, mean=False)

    if not isinstance(network_output, (tuple, list)):
        network_output = [network_output]
    if weights is None:
        # Default weights as in the original article. Presumably ordered to
        # match the network's output order; confirm against the model.
        weights = [0.005, 0.01, 0.02, 0.08, 0.32]
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, and `zip` would then silently drop unmatched scales.
    if len(weights) != len(network_output):
        raise ValueError(
            f"expected one weight per prediction, got {len(weights)} weights "
            f"for {len(network_output)} predictions"
        )

    loss = 0
    for output, weight in zip(network_output, weights):
        loss += weight * one_scale(output, target_flow, sparse)
    return loss
def realEPE(output, target, sparse=False):
    """Mean end-point error of ``output`` measured at ``target``'s resolution.

    ``output`` is resized with bilinear interpolation (``align_corners=False``)
    to the spatial size of the 4-D ``target`` before comparison. Flow values
    are resized but not rescaled.

    Args:
        output: predicted flow, possibly at a lower resolution.
        target: ground-truth flow; its last two dimensions set the evaluation size.
        sparse: forwarded to ``EPE``; drops target elements whose two flow
            components are both exactly 0.

    Returns:
        Scalar tensor: ``EPE(..., mean=True)`` of the resized prediction.
    """
    _, _, height, width = target.size()
    resized = F.interpolate(
        output, size=(height, width), mode="bilinear", align_corners=False
    )
    return EPE(resized, target, sparse, mean=True)