-
Notifications
You must be signed in to change notification settings - Fork 196
/
Copy pathloss_function.py
executable file
·137 lines (114 loc) · 4.6 KB
/
loss_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from torch import nn
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def cross_entropy_3D(input, target, weight=None, size_average=True):
n, c, h, w, s = input.size()
log_p = F.log_softmax(input, dim=1)
log_p = log_p.transpose(1, 2).transpose(2, 3).transpose(3, 4).contiguous().view(-1, c)
target = target.view(target.numel())
loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
if size_average:
loss /= float(target.numel())
return loss
class Binary_Loss(nn.Module):
def __init__(self):
super(Binary_Loss, self).__init__()
self.criterion = nn.BCEWithLogitsLoss()
def forward(self, model_output, targets):
#targets[targets == 0] = -1
# torch.empty(3, dtype=torch.long)
# model_output = model_output.long()
# targets = targets.long()
# print(model_output)
# print(F.sigmoid(model_output))
# print(targets)
# print('kkk')
# model_output =torch.LongTensor(model_output.cpu())
# targets =torch.LongTensor(targets.cpu())
# model_output = model_output.type(torch.LongTensor)
# targets = targets.type(torch.LongTensor)
loss = self.criterion(model_output, targets)
return loss
def make_one_hot(input, num_classes):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [N, 1, *]
num_classes: An int of number of class
Returns:
A tensor of shape [N, num_classes, *]
"""
shape = np.array(input.shape)
shape[1] = num_classes
shape = tuple(shape)
result = torch.zeros(shape)
result = result.scatter_(1, input.cpu(), 1)
return result
class BinaryDiceLoss(nn.Module):
"""Dice loss of binary class
Args:
smooth: A float number to smooth loss, and avoid NaN error, default: 1
p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
predict: A tensor of shape [N, *]
target: A tensor of shape same with predict
reduction: Reduction method to apply, return mean over batch if 'mean',
return sum if 'sum', return a tensor of shape [N,] if 'none'
Returns:
Loss tensor according to arg reduction
Raise:
Exception if unexpected reduction
"""
def __init__(self, smooth=1, p=2, reduction='mean'):
super(BinaryDiceLoss, self).__init__()
self.smooth = smooth
self.p = p
self.reduction = reduction
def forward(self, predict, target):
assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
predict = predict.contiguous().view(predict.shape[0], -1)
target = target.contiguous().view(target.shape[0], -1)
num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
loss = 1 - num / den
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
elif self.reduction == 'none':
return loss
else:
raise Exception('Unexpected reduction {}'.format(self.reduction))
class DiceLoss(nn.Module):
def __init__(self, n_classes):
super(DiceLoss, self).__init__()
self.n_classes = n_classes
def _one_hot_encoder(self, input_tensor):
tensor_list = []
for i in range(self.n_classes):
temp_prob = input_tensor == i * torch.ones_like(input_tensor)
tensor_list.append(temp_prob)
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()
def _dice_loss(self, score, target):
target = target.float()
smooth = 1e-5
intersect = torch.sum(score * target)
y_sum = torch.sum(target * target)
z_sum = torch.sum(score * score)
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
loss = 1 - loss
return loss
def forward(self, inputs, target, weight=None, softmax=False):
if softmax:
inputs = torch.softmax(inputs, dim=1)
target = self._one_hot_encoder(target)
if weight is None:
weight = [1] * self.n_classes
assert inputs.size() == target.size(), 'predict & target shape do not match'
class_wise_dice = []
loss = 0.0
for i in range(0, self.n_classes):
dice = self._dice_loss(inputs[:, i], target[:, i])
class_wise_dice.append(1.0 - dice.item())
loss += dice * weight[i]
return loss / self.n_classes