import kornia
import torch
import torch.jit as jit
import torch.nn.functional as F
from torch import Tensor, nn
from utils.generic_utils import pyrdown
from utils.geometry_utils import BackprojectDepth, Project3D


class StableBCELogitsLoss(nn.Module):
    """Binary cross entropy on logits, with the logits clamped to
    [-clip_value, clip_value] before the loss is computed."""

    def __init__(self, pos_weight, clip_value=100.0):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
        self.clip_value = clip_value
        # for debugging: how many logits were clamped on the last call.
        self.num_outside_range = None

    def forward(self, pred, target):
        # Build an additive offset that clamps pred into [-clip_value, clip_value].
        offset = torch.zeros_like(pred)
        zeros_mask = pred > self.clip_value
        offset[zeros_mask] = self.clip_value - pred[zeros_mask]
        ones_mask = pred < -self.clip_value
        offset[ones_mask] = -self.clip_value - pred[ones_mask]
        pred = pred + offset

        loss = self.bce_loss(pred, target)
        self.num_outside_range = (offset != 0).sum()
        return loss
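

# Usage sketch for StableBCELogitsLoss; the shapes and the pos_weight value
# below are illustrative assumptions, not values used elsewhere in this repo.
def _example_stable_bce_logits_loss():
    loss_fn = StableBCELogitsLoss(pos_weight=torch.tensor(2.0))
    logits = 200.0 * torch.randn(4, 1, 8, 8)  # deliberately extreme logits
    target = (torch.rand(4, 1, 8, 8) > 0.5).float()
    per_pixel_loss = loss_fn(logits, target)  # reduction="none": same shape as inputs
    return per_pixel_loss.mean(), loss_fn.num_outside_range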


class StableBCELoss(nn.Module):
    """Weighted binary cross entropy on probabilities, with the sigmoid
    output clamped to [epsilon, 1 - epsilon] so that log() stays finite."""

    def __init__(self, positive_weight, epsilon=1e-4):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.positive_weight = positive_weight
        self.epsilon = epsilon
        self.bce = nn.BCELoss(reduction="none")
        # for debugging: how many probabilities were clamped on the last call.
        self.num_outside_range = None

    def forward(self, pred, target):
        pred = self.sigmoid(pred)

        # Build an additive offset that clamps pred into [epsilon, 1 - epsilon].
        offset = torch.zeros_like(pred)
        zeros_mask = pred < self.epsilon
        offset[zeros_mask] = self.epsilon - pred[zeros_mask]
        ones_mask = pred > 1 - self.epsilon
        offset[ones_mask] = 1 - self.epsilon - pred[ones_mask]
        pred = pred + offset

        # Upweight the positive class.
        weights = torch.ones_like(target)
        weights[target == 1] = self.positive_weight
        loss = self.bce(pred, target) * weights

        self.num_outside_range = ones_mask.sum() + zeros_mask.sum()
        return loss
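

# Usage sketch for StableBCELoss; positive_weight and the shapes are
# illustrative assumptions. Note the loss is returned unreduced.
def _example_stable_bce_loss():
    loss_fn = StableBCELoss(positive_weight=5.0)
    logits = torch.randn(4, 1, 8, 8)
    target = (torch.rand(4, 1, 8, 8) > 0.5).float()
    return loss_fn(logits, target).mean()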


class BinaryL1Loss(nn.Module):
    """L1 loss on sigmoid probabilities, with the positive class upweighted."""

    def __init__(self, positive_weight):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.positive_weight = positive_weight

    def forward(self, prediction, target):
        prediction = self.sigmoid(prediction)
        loss = torch.abs(prediction - target)
        weighting = torch.ones_like(loss)
        weighting[target == 1] = self.positive_weight
        return (loss * weighting).mean()
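

# Usage sketch for BinaryL1Loss; the shapes and weight are illustrative
# assumptions. Unlike the BCE losses above, this one returns a scalar.
def _example_binary_l1_loss():
    loss_fn = BinaryL1Loss(positive_weight=3.0)
    logits = torch.randn(4, 1, 8, 8)
    target = (torch.rand(4, 1, 8, 8) > 0.5).float()
    return loss_fn(logits, target)  # weighted mean over all pixels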


class MSGradientLoss(nn.Module):
    """Multi-scale gradient loss: L1 between the spatial gradients of the
    predicted and ground-truth depth maps over a downsampled pyramid."""

    def __init__(self, num_scales: int = 4):
        super().__init__()
        self.num_scales = num_scales

    def forward(self, depth_gt: Tensor, depth_pred: Tensor) -> Tensor:
        # Create the depth pyramids.
        depth_pred_pyr = pyrdown(depth_pred, self.num_scales)
        depth_gtn_pyr = pyrdown(depth_gt, self.num_scales)

        grad_loss = torch.tensor(0, dtype=depth_gt.dtype, device=depth_gt.device)
        for depth_pred_down, depth_gtn_down in zip(depth_pred_pyr, depth_gtn_pyr):
            depth_gtn_grad = kornia.filters.spatial_gradient(depth_gtn_down)
            # Only penalize locations where the ground-truth gradient is finite.
            mask_down_b = depth_gtn_grad.isfinite().all(dim=1, keepdim=True)
            depth_pred_grad = kornia.filters.spatial_gradient(depth_pred_down).masked_select(
                mask_down_b
            )
            grad_error = torch.abs(depth_pred_grad - depth_gtn_grad.masked_select(mask_down_b))
            grad_loss += torch.mean(grad_error)
        return grad_loss
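

# Usage sketch for MSGradientLoss; depth values and shapes are illustrative
# assumptions. Inputs are plain (not log) depth maps, and the spatial size
# should survive num_scales halvings.
def _example_ms_gradient_loss():
    loss_fn = MSGradientLoss(num_scales=4)
    depth_gt = torch.rand(2, 1, 64, 64) + 0.5
    depth_pred = torch.rand(2, 1, 64, 64) + 0.5
    return loss_fn(depth_gt, depth_pred)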


class ScaleInvariantLoss(jit.ScriptModule):
    """Scale-invariant log-depth loss from Eigen et al.; the implementation
    follows AdaBins. Expects log depths as inputs."""

    def __init__(self, si_lambda: float = 0.85):
        super().__init__()
        self.si_lambda = si_lambda

    @jit.script_method
    def forward(self, log_depth_gt: Tensor, log_depth_pred: Tensor) -> Tensor:
        log_diff = log_depth_gt - log_depth_pred
        si_loss = torch.sqrt((log_diff**2).mean() - self.si_lambda * (log_diff.mean() ** 2))
        return si_loss
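

# Usage sketch for ScaleInvariantLoss; shapes are illustrative assumptions.
# The module expects *log* depths, so take the log before calling.
def _example_scale_invariant_loss():
    loss_fn = ScaleInvariantLoss(si_lambda=0.85)
    depth_gt = torch.rand(2, 1, 32, 32) + 0.1
    depth_pred = torch.rand(2, 1, 32, 32) + 0.1
    return loss_fn(depth_gt.log(), depth_pred.log())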


class NormalsLoss(nn.Module):
    """Cosine-similarity loss on unit normals, 0.5 * (1 - n_pred . n_gt),
    averaged over locations where both normal maps are finite."""

    def forward(self, normals_gt_b3hw: Tensor, normals_pred_b3hw: Tensor) -> Tensor:
        normals_mask_b1hw = torch.logical_and(
            normals_gt_b3hw.isfinite().all(dim=1, keepdim=True),
            normals_pred_b3hw.isfinite().all(dim=1, keepdim=True),
        )

        # Replace invalid entries with a harmless constant so the einsum
        # below cannot produce NaNs; they are masked out of the mean anyway.
        normals_pred_b3hw = normals_pred_b3hw.masked_fill(~normals_mask_b1hw, 1.0)
        normals_gt_b3hw = normals_gt_b3hw.masked_fill(~normals_mask_b1hw, 1.0)

        # Compute the dot product in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            normals_dot_b1hw = 0.5 * (
                1.0
                - torch.einsum(
                    "bchw, bchw -> bhw",
                    normals_pred_b3hw,
                    normals_gt_b3hw,
                )
            ).unsqueeze(1)

        normals_loss = normals_dot_b1hw.masked_select(normals_mask_b1hw).mean()
        return normals_loss
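

# Usage sketch for NormalsLoss; the random unit normals are illustrative.
# The loss is 0 for perfectly aligned normals and 1 for opposite ones.
def _example_normals_loss():
    loss_fn = NormalsLoss()
    normals_gt_b3hw = F.normalize(torch.randn(2, 3, 32, 32), dim=1)
    normals_pred_b3hw = F.normalize(torch.randn(2, 3, 32, 32), dim=1)
    return loss_fn(normals_gt_b3hw, normals_pred_b3hw)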


class MVDepthLoss(nn.Module):
    """Multi-view depth consistency loss: the predicted depth map is warped
    into each source view and compared, in log space, against the source
    view's depth wherever the reprojection is valid."""

    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.backproject = BackprojectDepth(self.height, self.width)
        self.project = Project3D()

    def get_valid_mask(
        self,
        cur_depth_b1hw,
        src_depth_b1hw,
        cur_invK_b44,
        src_K_b44,
        cur_world_T_cam_b44,
        src_cam_T_world_b44,
    ):
        depth_height, depth_width = cur_depth_b1hw.shape[2:]

        # Lift the current view's ground-truth depth to world space.
        cur_cam_points_b4N = self.backproject(cur_depth_b1hw, cur_invK_b44)
        world_points_b4N = cur_world_T_cam_b44 @ cur_cam_points_b4N

        # Project those points into the source view.
        src_cam_points_b3N = self.project(world_points_b4N, src_K_b44, src_cam_T_world_b44)
        cam_points_b3hw = src_cam_points_b3N.view(-1, 3, depth_height, depth_width)
        pix_coords_b2hw = cam_points_b3hw[:, :2]
        proj_src_depths_b1hw = cam_points_b3hw[:, 2:]

        # Normalize pixel coordinates to [-1, 1] for grid_sample.
        uv_coords = pix_coords_b2hw.permute(0, 2, 3, 1) / torch.tensor(
            [depth_width, depth_height]
        ).view(1, 1, 1, 2).type_as(pix_coords_b2hw)
        uv_coords = 2 * uv_coords - 1

        src_depth_sampled_b1hw = F.grid_sample(
            input=src_depth_b1hw,
            grid=uv_coords,
            padding_mode="zeros",
            mode="nearest",
            align_corners=False,
        )

        # Valid when the reprojected point is in front of the camera, lands on
        # a valid source-depth pixel, and is not occluded (i.e. its depth is
        # within 5% of the source view's depth).
        valid_mask_b1hw = proj_src_depths_b1hw < 1.05 * src_depth_sampled_b1hw
        valid_mask_b1hw = torch.logical_and(valid_mask_b1hw, proj_src_depths_b1hw > 0)
        valid_mask_b1hw = torch.logical_and(valid_mask_b1hw, src_depth_sampled_b1hw > 0)

        return valid_mask_b1hw, src_depth_sampled_b1hw

    def get_error_for_pair(
        self,
        depth_pred_b1hw,
        cur_depth_b1hw,
        src_depth_b1hw,
        cur_invK_b44,
        src_K_b44,
        cur_world_T_cam_b44,
        src_cam_T_world_b44,
    ):
        depth_height, depth_width = cur_depth_b1hw.shape[2:]

        valid_mask_b1hw, src_depth_sampled_b1hw = self.get_valid_mask(
            cur_depth_b1hw,
            src_depth_b1hw,
            cur_invK_b44,
            src_K_b44,
            cur_world_T_cam_b44,
            src_cam_T_world_b44,
        )

        # Warp the *predicted* depth into the source view.
        pred_cam_points_b4N = self.backproject(depth_pred_b1hw, cur_invK_b44)
        pred_world_points_b4N = cur_world_T_cam_b44 @ pred_cam_points_b4N
        src_cam_points_b3N = self.project(pred_world_points_b4N, src_K_b44, src_cam_T_world_b44)
        pred_cam_points_b3hw = src_cam_points_b3N.view(-1, 3, depth_height, depth_width)
        pred_src_depths_b1hw = pred_cam_points_b3hw[:, 2:]

        # Log-space L1 over valid pixels; nanmean skips NaNs that appear when
        # a warped predicted depth is negative.
        depth_diff_b1hw = torch.abs(
            torch.log(src_depth_sampled_b1hw) - torch.log(pred_src_depths_b1hw)
        ).masked_select(valid_mask_b1hw)
        depth_loss = depth_diff_b1hw.nanmean()
        return depth_loss

    def forward(
        self,
        depth_pred_b1hw,
        cur_depth_b1hw,
        src_depth_bk1hw,
        cur_invK_b44,
        src_K_bk44,
        cur_world_T_cam_b44,
        src_cam_T_world_bk44,
    ):
        # Unbind the source-view dimension (k) and average the pairwise errors.
        src_to_iterate = [
            torch.unbind(src_depth_bk1hw, dim=1),
            torch.unbind(src_K_bk44, dim=1),
            torch.unbind(src_cam_T_world_bk44, dim=1),
        ]
        num_src_frames = src_depth_bk1hw.shape[1]

        loss = 0
        for src_depth_b1hw, src_K_b44, src_cam_T_world_b44 in zip(*src_to_iterate):
            error = self.get_error_for_pair(
                depth_pred_b1hw,
                cur_depth_b1hw,
                src_depth_b1hw,
                cur_invK_b44,
                src_K_b44,
                cur_world_T_cam_b44,
                src_cam_T_world_b44,
            )
            loss += error
        return loss / num_src_frames
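

# Usage sketch for MVDepthLoss. The intrinsics and identity poses below are
# made up, and the exact conventions of BackprojectDepth / Project3D are
# assumed from the _b44 / _bk44 suffixes; treat this as shape documentation
# rather than a definitive example.
def _example_mv_depth_loss():
    b, k, h, w = 2, 3, 48, 64
    loss_fn = MVDepthLoss(height=h, width=w)

    # Hypothetical pinhole intrinsics as a 4x4 matrix.
    K_b44 = torch.eye(4).repeat(b, 1, 1)
    K_b44[:, 0, 0] = K_b44[:, 1, 1] = 50.0
    K_b44[:, 0, 2], K_b44[:, 1, 2] = w / 2, h / 2

    return loss_fn(
        depth_pred_b1hw=torch.rand(b, 1, h, w) + 1.0,
        cur_depth_b1hw=torch.rand(b, 1, h, w) + 1.0,
        src_depth_bk1hw=torch.rand(b, k, 1, h, w) + 1.0,
        cur_invK_b44=torch.inverse(K_b44),
        src_K_bk44=K_b44.unsqueeze(1).repeat(1, k, 1, 1),
        cur_world_T_cam_b44=torch.eye(4).repeat(b, 1, 1),
        src_cam_T_world_bk44=torch.eye(4).repeat(b, k, 1, 1),
    )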