Fixing ColBERTDistillationEvaluator
NohTow committed Aug 9, 2024
1 parent 3553963 commit 96670d6
Showing 1 changed file with 28 additions and 24 deletions.
giga_cherche/evaluation/colbert_distillation.py (28 additions, 24 deletions)
@@ -1,3 +1,4 @@
+import itertools
 import logging
 import os
 from contextlib import nullcontext
@@ -90,10 +91,13 @@ def __init__(
         show_progress_bar: bool = False,
         write_csv: bool = True,
         truncate_dim: int | None = None,
+        normalize_scores: bool = True,
     ) -> None:
         super().__init__()
+        assert len(queries) == len(documents)
         self.queries = queries
-        self.documents = documents
+        # Flatten the documents list
+        self.documents = list(itertools.chain.from_iterable(documents))
         self.scores = scores
         self.name = name
         self.truncate_dim = truncate_dim
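
The flattening works together with a reshape later in `__call__`: `documents` arrives as one list of candidate documents per query, the flat list lets `model.encode` embed everything in a single pass, and the embeddings are regrouped per query afterwards with a `view`. A minimal sketch of that round trip, with toy sizes that are not taken from the commit:

```python
import itertools

import torch

# Hypothetical toy setup: 2 queries with 3 candidate documents each.
documents = [["d00", "d01", "d02"], ["d10", "d11", "d12"]]

# The same flattening the commit applies in __init__.
flat_documents = list(itertools.chain.from_iterable(documents))
assert flat_documents == ["d00", "d01", "d02", "d10", "d11", "d12"]

# After encoding the flat list, the embeddings can be regrouped per query,
# as done later in __call__ (random tensors stand in for real embeddings).
embeddings = torch.randn(len(flat_documents), 16, 128)  # (docs, tokens, dim)
grouped = embeddings.view(len(documents), -1, *embeddings.shape[1:])
print(grouped.shape)  # torch.Size([2, 3, 16, 128])
```
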
@@ -102,7 +106,7 @@ def __init__(
         self.write_csv = write_csv
 
         self.loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
-
+        self.normalize_scores = normalize_scores
         if show_progress_bar is None:
             show_progress_bar = (
                 logger.getEffectiveLevel() == logging.INFO
@@ -145,40 +149,40 @@ def __call__(
                 )
             )
 
-        documents_embeddings = model.encode(
-            self.documents,
-            batch_size=self.batch_size,
-            show_progress_bar=self.show_progress_bar,
-            convert_to_tensor=True,
-            is_query=False,
+        documents_embeddings = torch.nn.utils.rnn.pad_sequence(
+            model.encode(
+                self.documents,
+                batch_size=self.batch_size,
+                show_progress_bar=self.show_progress_bar,
+                convert_to_numpy=False,
+                is_query=False,
+            ),
+            batch_first=True,
+            padding_value=0,
         )
 
-        max_number_tokens = max([batch.shape[1] for batch in documents_embeddings])
-
-        documents_embeddings = [
-            torch.nn.functional.pad(
-                batch,
-                (0, 0, 0, max_number_tokens - batch.size(1), 0, 0),
-                mode="constant",
-                value=0,
-            )
-            for batch in documents_embeddings
-        ]
-
-        documents_embeddings = torch.stack(documents_embeddings, dim=0)
-
+        documents_embeddings = documents_embeddings.view(
+            queries_embeddings.size(0), -1, *documents_embeddings.shape[1:]
+        )
         scores = colbert_kd_scores(
             queries_embeddings=queries_embeddings,
             documents_embeddings=documents_embeddings,
         )
+        if self.normalize_scores:
+            # Compute max and min along the num_scores dimension (dim=1)
+            max_scores, _ = torch.max(scores, dim=1, keepdim=True)
+            min_scores, _ = torch.min(scores, dim=1, keepdim=True)
+
+            # Avoid division by zero by adding a small epsilon
+            epsilon = 1e-8
+
+            # Normalize the scores
+            scores = (scores - min_scores) / (max_scores - min_scores + epsilon)
         kl_divergence = self.loss(
             torch.nn.functional.log_softmax(scores, dim=-1),
             torch.nn.functional.log_softmax(
                 torch.tensor(self.scores, device=scores.device), dim=-1
             ),
         ).item()
 
         metrics = self.prefix_name_to_metrics(
             {"kl_divergence": kl_divergence}, self.name
         )
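
The encoding rewrite above replaces the manual padding (compute the max token count, `torch.nn.functional.pad` every batch, then `torch.stack`) with a single `torch.nn.utils.rnn.pad_sequence` call over the per-document token embeddings. A minimal sketch of what `pad_sequence` does here, assuming toy lengths and a 128-dimensional embedding:

```python
import torch

# Hypothetical per-document token embeddings with different lengths.
embeddings = [torch.randn(4, 128), torch.randn(7, 128), torch.randn(5, 128)]

# One call pads every document up to the longest one (7 tokens here).
padded = torch.nn.utils.rnn.pad_sequence(
    embeddings, batch_first=True, padding_value=0
)
print(padded.shape)  # torch.Size([3, 7, 128])
assert torch.equal(padded[0, :4], embeddings[0])  # original content preserved
assert torch.all(padded[0, 4:] == 0)  # trailing positions zero-padded
```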
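
The new `normalize_scores` branch rescales each query's candidate scores into [0, 1] before the softmax, which makes the KL divergence less sensitive to the raw scale of the model's scores. A standalone sketch of that row-wise min-max normalization, with made-up scores:

```python
import torch

# Hypothetical scores of shape (num_queries, num_documents_per_query).
scores = torch.tensor([[12.0, 9.0, 3.0], [40.0, 10.0, 25.0]])

max_scores, _ = torch.max(scores, dim=1, keepdim=True)
min_scores, _ = torch.min(scores, dim=1, keepdim=True)
epsilon = 1e-8  # guards against division by zero when max == min
normalized = (scores - min_scores) / (max_scores - min_scores + epsilon)
print(normalized)  # tensor([[1.0000, 0.6667, 0.0000], [1.0000, 0.0000, 0.5000]])
```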

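The metric itself is unchanged: `torch.nn.KLDivLoss(reduction="batchmean", log_target=True)` expects both arguments as log-probabilities, which is why both the model's scores and the teacher's scores go through `log_softmax`. A minimal sketch with hypothetical student and teacher scores:

```python
import torch

loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

# Hypothetical normalized scores, one row per query.
student_scores = torch.tensor([[0.9, 0.4, 0.1], [0.8, 0.7, 0.2]])
teacher_scores = torch.tensor([[1.0, 0.5, 0.0], [0.9, 0.6, 0.1]])

# With log_target=True, input and target are both log-probabilities.
kl_divergence = loss(
    torch.nn.functional.log_softmax(student_scores, dim=-1),
    torch.nn.functional.log_softmax(teacher_scores, dim=-1),
).item()
print(kl_divergence)  # small positive value; 0.0 iff the distributions match
```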