Skip to content

Commit

Permalink
[bug] adding breakout for if there's only 1 label in the batch, retur…
Browse files Browse the repository at this point in the history
…n contrastive loss
  • Loading branch information
ctr26 committed Aug 14, 2024
1 parent 1245191 commit 584b135
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions bioimage_embed/lightning/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ def loss_function(self, model_output, batch_idx):
# Scale is used as the rest of the loss functions are sums rather than means, which may mean we need to scale up the contrastive loss

scale = torch.prod(torch.tensor(model_output.z.shape[1:]))
if model_output.target.unique().size(0) == 1:
return loss
pairs = create_label_based_pairs(model_output.z.squeeze(), model_output.target)
contrastive_loss = self.criteron(*pairs)
loss["contrastive_loss"] = scale * contrastive_loss
Expand Down

0 comments on commit 584b135

Please sign in to comment.