forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lovasz_loss.py
323 lines (285 loc) · 11.9 KB
/
lovasz_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from
https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py

Lovasz-Softmax and Jaccard hinge loss in PyTorch.
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.utils import is_list_of
from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss
def lovasz_grad(gt_sorted):
    """Compute the gradient of the Lovasz extension w.r.t. sorted errors.

    See Alg. 1 in the paper.

    Args:
        gt_sorted (torch.Tensor): Ground truth values indexed along dim 0.
            The callers in this file pass labels reordered by decreasing
            prediction error.

    Returns:
        torch.Tensor: Tensor with the same length as ``gt_sorted`` holding
            the successive differences of the Jaccard-loss sequence.
    """
    num_elems = len(gt_sorted)
    total_fg = gt_sorted.sum()
    # Running counts of foreground / background elements seen so far.
    fg_seen = gt_sorted.float().cumsum(0)
    bg_seen = (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - (total_fg - fg_seen) / (total_fg + bg_seen)
    if num_elems > 1:  # a single element has no predecessor to subtract
        # The right-hand side is fully evaluated before the assignment, so
        # every entry is diffed against the *original* previous value.
        jaccard[1:num_elems] = jaccard[1:num_elems] - jaccard[0:-1]
    return jaccard
def flatten_binary_logits(logits, labels, ignore_index=None):
    """Flatten binary predictions and labels to 1-D, dropping ignored ones.

    Args:
        logits (torch.Tensor): Predicted logits, any shape.
        labels (torch.Tensor): Ground truth with the same number of elements
            as ``logits``.
        ignore_index (int | None): Label value to drop. If None, nothing is
            filtered. Default: None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Flattened logits and labels,
            restricted to positions whose label differs from
            ``ignore_index``.
    """
    flat_logits = logits.view(-1)
    flat_labels = labels.view(-1)
    if ignore_index is not None:
        keep = flat_labels != ignore_index
        flat_logits, flat_labels = flat_logits[keep], flat_labels[keep]
    return flat_logits, flat_labels
def flatten_probs(probs, labels, ignore_index=None):
    """Flatten per-pixel predictions in the batch to shape [P, C].

    Args:
        probs (torch.Tensor): [B, C, H, W] class probabilities, or
            [B, H, W] (treated as a single-channel sigmoid output).
        labels (torch.Tensor): Ground truth labels with B * H * W elements.
        ignore_index (int | None): Label value to drop. If None, nothing is
            filtered. Default: None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Probabilities of shape [P, C] and
            labels of shape [P], where P is the number of kept pixels. The
            probabilities stay 2-D even when P is 0 or 1.
    """
    if probs.dim() == 3:
        # assumes output of a sigmoid layer
        batch, height, width = probs.size()
        probs = probs.view(batch, 1, height, width)
    _, num_classes, _, _ = probs.size()
    # B, C, H, W -> B*H*W, C
    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, num_classes)
    labels = labels.view(-1)
    if ignore_index is None:
        return probs, labels
    valid = labels != ignore_index
    # Index with the boolean mask directly. Going through
    # ``valid.nonzero().squeeze()`` yields a 0-dim index when exactly one
    # pixel is valid, which silently drops the pixel dimension of ``probs``.
    return probs[valid], labels[valid]
def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss on flattened predictions.

    Args:
        logits (torch.Tensor): [P], logits at each prediction
            (between -infty and +infty).
        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).

    Returns:
        torch.Tensor: The calculated loss (a scalar tensor).
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    # Map labels {0, 1} to signs {-1, +1} and form the hinge margins.
    margins = 1. - logits * (2. * labels.float() - 1.)
    sorted_margins, order = torch.sort(margins, dim=0, descending=True)
    sorted_labels = labels[order.data]
    return torch.dot(F.relu(sorted_margins), lovasz_grad(sorted_labels))
def lovasz_hinge(logits,
                 labels,
                 classes='present',
                 per_image=False,
                 class_weight=None,
                 reduction='mean',
                 avg_factor=None,
                 ignore_index=255):
    """Binary Lovasz hinge loss.

    Args:
        logits (torch.Tensor): [B, H, W], logits at each pixel
            (between -infty and +infty).
        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
        classes (str | list[int], optional): Placeholder, unused here; kept
            so the signature matches the multi-class loss.
            Default: 'present'.
        per_image (bool, optional): If True, compute the loss for each image
            separately and then reduce; otherwise compute one loss over the
            whole batch. Default: False.
        class_weight (list[float], optional): Placeholder, unused here; kept
            so the signature matches the multi-class loss. Default: None.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". Only used when per_image is True.
            Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Only used when per_image is True. Default: None.
        ignore_index (int | None): The label index to be ignored.
            Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        # Whole batch treated as one flat set of pixels.
        return lovasz_hinge_flat(
            *flatten_binary_logits(logits, labels, ignore_index))

    per_image_losses = []
    for logit, label in zip(logits, labels):
        flat_logit, flat_label = flatten_binary_logits(
            logit.unsqueeze(0), label.unsqueeze(0), ignore_index)
        per_image_losses.append(lovasz_hinge_flat(flat_logit, flat_label))
    return weight_reduce_loss(
        torch.stack(per_image_losses), None, reduction, avg_factor)
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
    """Multi-class Lovasz-Softmax loss on flattened predictions.

    Args:
        probs (torch.Tensor): [P, C], class probabilities at each prediction
            (between 0 and 1).
        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        class_weight (list[float], optional): The weight for each class.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss (a scalar tensor). It is zero,
            while staying attached to ``probs``, when there is no pixel or
            no selected class to evaluate.

    Raises:
        ValueError: If ``probs`` has a single channel but more than one
            class is selected.
    """
    if probs.numel() == 0:
        # only void pixels, the gradients should be 0; return a scalar so the
        # result can be stacked with the losses of other images
        return probs.sum() * 0.
    C = probs.size(1)
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    # Check the resolved class list, not ``classes`` itself: ``classes`` may
    # be the string 'all' / 'present', whose length is not a class count.
    if C == 1 and len(class_to_sum) > 1:
        raise ValueError('Sigmoid output possible only with 1 class')
    losses = []
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        class_pred = probs[:, 0] if C == 1 else probs[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
        if class_weight is not None:
            loss = loss * class_weight[c]
        losses.append(loss)
    if not losses:
        # No selected class occurs in ``labels``; torch.stack would raise on
        # an empty list, so contribute a zero loss instead.
        return probs.sum() * 0.
    return torch.stack(losses).mean()
def lovasz_softmax(probs,
                   labels,
                   classes='present',
                   per_image=False,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
                   ignore_index=255):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [B, C, H, W], class probabilities at each
            prediction (between 0 and 1).
        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
            C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If True, compute the loss for each image
            separately and then reduce; otherwise compute one loss over the
            whole batch. Default: False.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". Only used when per_image is True.
            Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Only used when per_image is True. Default: None.
        ignore_index (int | None): The label index to be ignored.
            Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        # Whole batch treated as one flat set of pixels.
        flat_probs, flat_labels = flatten_probs(probs, labels, ignore_index)
        return lovasz_softmax_flat(
            flat_probs,
            flat_labels,
            classes=classes,
            class_weight=class_weight)

    per_image_losses = []
    for prob, label in zip(probs, labels):
        flat_probs, flat_labels = flatten_probs(
            prob.unsqueeze(0), label.unsqueeze(0), ignore_index)
        per_image_losses.append(
            lovasz_softmax_flat(
                flat_probs,
                flat_labels,
                classes=classes,
                class_weight=class_weight))
    return weight_reduce_loss(
        torch.stack(per_image_losses), None, reduction, avg_factor)
@MODELS.register_module()
class LovaszLoss(nn.Module):
    """LovaszLoss.

    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure in neural
    networks <https://arxiv.org/abs/1705.08790>`_.

    Note:
        When ``per_image`` is False the constructor asserts that
        ``reduction`` is 'none', so the default ``reduction='mean'`` must be
        overridden in that case.

    Args:
        loss_type (str, optional): Binary or multi-class loss.
            Default: 'multi_class'. Options are "binary" and "multi_class".
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_lovasz'.
    """

    def __init__(self,
                 loss_type='multi_class',
                 classes='present',
                 per_image=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_lovasz'):
        super().__init__()
        assert loss_type in ('binary', 'multi_class'), (
            "loss_type should be 'binary' or 'multi_class'.")
        assert classes in ('all', 'present') or is_list_of(classes, int)
        if not per_image:
            assert reduction == 'none', (
                "reduction should be 'none' when per_image is False.")

        # Pick the functional loss matching the requested loss type.
        self.cls_criterion = (
            lovasz_hinge if loss_type == 'binary' else lovasz_softmax)
        self.classes = classes
        self.per_image = per_image
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): Raw predictions. For the multi-class
                loss they are turned into probabilities with a softmax over
                dim 1 before the loss is computed.
            label (torch.Tensor): Ground truth labels.
            weight: Unused in this method.
            avg_factor (int, optional): Average factor forwarded to the loss
                function. Default: None.
            reduction_override (str, optional): If set, replaces the
                reduction chosen at construction. Options are "none", "mean"
                and "sum". Default: None.
            **kwargs: Extra keyword arguments forwarded to the loss function.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # Materialise the class weights on the same device / dtype as the
        # predictions.
        class_weight = (
            None if self.class_weight is None else cls_score.new_tensor(
                self.class_weight))

        # if multi-class loss, transform logits to probs
        if self.cls_criterion is lovasz_softmax:
            cls_score = F.softmax(cls_score, dim=1)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            self.classes,
            self.per_image,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return self.loss_weight * raw_loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name