miracl evaluation and fix beir
NohTow committed Aug 2, 2024
1 parent bad5788 commit a0c64e3
Showing 2 changed files with 61 additions and 1 deletion.
2 changes: 1 addition & 1 deletion evaluation_beir.py
@@ -5,7 +5,7 @@
model = models.ColBERT(
    model_name_or_path="NohTow/colbertv2_sentence_transformer",
)
-index = indexes.Weaviate(recreate=True, max_doc_length=model.document_length)
+index = indexes.Weaviate(override_collection=True, max_doc_length=model.document_length)

retriever = retrieve.ColBERT(index=index)

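For readers updating their own scripts: the only behavioral change in evaluation_beir.py is the renamed Weaviate keyword argument (recreate becomes override_collection). Below is a minimal sketch of the updated setup, using only names that appear in this commit; treat it as illustrative rather than the library's documented API.

from giga_cherche import indexes, models, retrieve

# Same ColBERT checkpoint as the BEIR evaluation script.
model = models.ColBERT(
    model_name_or_path="NohTow/colbertv2_sentence_transformer",
)

# `override_collection=True` replaces the former `recreate=True` flag; judging from
# the name, it overrides any existing Weaviate collection before indexing.
index = indexes.Weaviate(
    override_collection=True,
    max_doc_length=model.document_length,
)

retriever = retrieve.ColBERT(index=index)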
60 changes: 60 additions & 0 deletions evaluation_miracl.py
@@ -0,0 +1,60 @@
"""Evaluation script for the miracl_fr dataset using the Beir library."""

from beir.datasets.data_loader import GenericDataLoader

from giga_cherche import evaluation, indexes, models, retrieve, utils

model = models.ColBERT(
    model_name_or_path="NohTow/colbert_xml-r-english",
    document_length=300,
)
index = indexes.Weaviate(override_collection=True, max_doc_length=model.document_length)
retriever = retrieve.ColBERT(index=index)

documents, queries, qrels = GenericDataLoader("datasets/miracl_fr").load(split="dev")

documents = [
    {
        "id": document_id,
        "title": document["title"],
        "text": document["text"],
    }
    for document_id, document in documents.items()
]

qrels = {
    queries[query_id]: query_documents for query_id, query_documents in qrels.items()
}
queries = list(qrels.keys())

for batch in utils.iter_batch(documents, batch_size=500):
    documents_embeddings = model.encode(
        [document["title"] + " " + document["text"] for document in batch],
        convert_to_numpy=True,
        is_query=False,
    )

    index.add_documents(
        doc_ids=[document["id"] for document in batch],
        doc_embeddings=documents_embeddings,
    )

scores = []

for batch in utils.iter_batch(queries, batch_size=5):
    queries_embeddings = model.encode(
        sentences=batch,
        convert_to_numpy=True,
        is_query=True,
    )

    scores.extend(retriever.retrieve(queries=queries_embeddings, k=10))

print(
    evaluation.evaluate(
        scores=scores,
        qrels=qrels,
        queries=queries,
        metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"],
    )
)
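One step in the new script worth a note is the qrels remapping: BEIR's GenericDataLoader keys qrels by query id, whereas the evaluation call above receives query texts, so the script rekeys qrels by the query string and then derives the query list from those keys. A toy illustration with hypothetical ids and texts (not real MIRACL data):

# Shapes as returned by GenericDataLoader, with made-up values.
queries = {"q1": "qui a écrit les misérables ?", "q2": "capitale du sénégal"}
qrels = {"q1": {"doc42": 1}, "q2": {"doc7": 1, "doc13": 1}}

# Rekey qrels by query text, as the script does, then derive the query list.
qrels = {queries[query_id]: relevant_docs for query_id, relevant_docs in qrels.items()}
queries = list(qrels.keys())

assert queries == ["qui a écrit les misérables ?", "capitale du sénégal"]
assert qrels["capitale du sénégal"] == {"doc7": 1, "doc13": 1}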
