Skip to content

Commit

Permalink
Merge pull request #27 from lightonai/constrastive-loss-label
Browse files Browse the repository at this point in the history
add-labels
  • Loading branch information
raphaelsty authored Aug 8, 2024
2 parents da24370 + 0d3ec82 commit f2c3e83
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions giga_cherche/losses/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def __init__(
self.size_average = size_average

def forward(
self, sentence_features: Iterable[dict[str, Tensor]], **kwargs
self,
sentence_features: Iterable[dict[str, Tensor]],
labels: torch.Tensor | None = None,
) -> torch.Tensor:
"""Compute the Constrastive loss.
Expand Down Expand Up @@ -157,5 +159,7 @@ def forward(
# compute constrastive loss using cross-entropy over the distances

return F.cross_entropy(
distances, labels, reduction="mean" if self.size_average else "sum"
input=distances,
target=labels,
reduction="mean" if self.size_average else "sum",
)

0 comments on commit f2c3e83

Please sign in to comment.