loss.py
# coding: utf-8
"""
Module implementing the training losses.
"""
from torch import nn, Tensor


class RegLoss(nn.Module):
    """
    Regression loss, masking out padded target positions.
    """

    def __init__(self, cfg: dict, target_pad: float = 0.0):
        super().__init__()

        self.loss = cfg["training"]["loss"].lower()
        if self.loss == "l1":
            self.criterion = nn.L1Loss()
        elif self.loss == "mse":
            self.criterion = nn.MSELoss()
        else:
            print("Loss not found - reverting to default L1 loss")
            self.criterion = nn.L1Loss()

        model_cfg = cfg["model"]
        # Value used to pad target sequences; padded positions are masked out
        self.target_pad = target_pad
        # Optional multiplier applied to the final loss (default: 1.0)
        self.loss_scale = model_cfg.get("loss_scale", 1.0)
    # pylint: disable=arguments-differ
    def forward(self, preds: Tensor, targets: Tensor) -> Tensor:
        # Mask out entries that equal the target padding value
        loss_mask = (targets != self.target_pad)

        # Zero out the masked predictions and targets using the loss mask
        preds_masked = preds * loss_mask
        targets_masked = targets * loss_mask

        # Calculate the loss only over the masked predictions
        # (note: with the default "mean" reduction, masked positions still
        # count towards the denominator, contributing zero error)
        loss = self.criterion(preds_masked, targets_masked)

        # Scale the loss if a loss_scale other than 1.0 is configured
        if self.loss_scale != 1.0:
            loss = loss * self.loss_scale

        return loss
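

# Hypothetical usage sketch (not part of the original module): the cfg
# dictionary shape is inferred from the lookups in RegLoss.__init__, and
# the tensor shapes below are illustrative assumptions, not requirements.
def _demo_reg_loss() -> Tensor:
    import torch

    cfg = {"training": {"loss": "mse"}, "model": {"loss_scale": 2.0}}
    reg_loss = RegLoss(cfg, target_pad=0.0)

    preds = torch.randn(2, 5, 3)       # (batch, seq_len, features)
    targets = torch.randn(2, 5, 3)
    targets[:, 3:, :] = 0.0            # trailing frames padded with target_pad

    # Padded positions are masked out and contribute zero error
    return reg_loss(preds, targets)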


class XentLoss(nn.Module):
    """
    Cross-entropy loss with optional label smoothing.

    Note: only the standard (non-smoothed) path is implemented below;
    `smoothing` is stored but not yet applied.
    """

    def __init__(self, pad_index: int, smoothing: float = 0.0):
        super().__init__()
        self.smoothing = smoothing
        self.pad_index = pad_index
        # Standard cross-entropy loss over log probabilities, summed over
        # all non-padding tokens
        self.criterion = nn.NLLLoss(ignore_index=self.pad_index,
                                    reduction='sum')
    # pylint: disable=arguments-differ
    def forward(self, log_probs: Tensor, targets: Tensor) -> Tensor:
        """
        Compute the cross-entropy between log probabilities and targets.

        If label smoothing is used, target distributions are not one-hot:
        the correct token receives probability "1 - smoothing" and the rest
        of the probability mass is spread uniformly across the other tokens.

        :param log_probs: log probabilities as predicted by the model,
            shape (batch, seq_len, vocab_size)
        :param targets: target indices, shape (batch, seq_len)
        :return: summed loss over all non-padding tokens
        """
        # Flatten targets to shape (batch * seq_len,)
        targets = targets.contiguous().view(-1)
        # Flatten log_probs to (batch * seq_len, vocab_size) to match
        loss = self.criterion(
            log_probs.contiguous().view(-1, log_probs.size(-1)), targets)
        return loss
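

# Hypothetical usage sketch (not part of the original module): the pad index,
# vocabulary size and tensor shapes are illustrative assumptions. NLLLoss
# expects log probabilities, so log_softmax is applied to raw logits first.
def _demo_xent_loss() -> Tensor:
    import torch
    import torch.nn.functional as F

    pad_index = 1
    xent = XentLoss(pad_index=pad_index, smoothing=0.0)

    logits = torch.randn(2, 4, 10)             # (batch, seq_len, vocab)
    log_probs = F.log_softmax(logits, dim=-1)  # convert logits to log probs
    targets = torch.randint(0, 10, (2, 4))
    targets[:, -1] = pad_index                 # padded positions are ignored

    # Loss is summed over all non-padding tokens (reduction='sum')
    return xent(log_probs, targets)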