diff --git a/climatenet/utils/losses.py b/climatenet/utils/losses.py index 42e0389..bd96652 100644 --- a/climatenet/utils/losses.py +++ b/climatenet/utils/losses.py @@ -18,7 +18,8 @@ def jaccard_loss(logits, true, eps=1e-7): jacc_loss: the Jaccard loss. """ num_classes = logits.shape[1] - true_1_hot = torch.eye(num_classes)[true.squeeze(1)] + # Keep on same device + true_1_hot = torch.eye(num_classes).to(true.device)[true.squeeze(1)] true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() probas = F.softmax(logits, dim=1) true_1_hot = true_1_hot.type(logits.type())