From 3f071736d41e76f0620b1063dc34e4f04e34b5be Mon Sep 17 00:00:00 2001
From: ZiyiXia
Date: Mon, 11 Nov 2024 17:29:15 +0800
Subject: [PATCH] evaluation

---
 FlagEmbedding/abc/evaluation/data_loader.py   |  2 +-
 .../evaluation/miracl/data_loader.py          | 47 ++++++++++
 FlagEmbedding/evaluation/miracl/runner.py     |  8 ++
 FlagEmbedding/evaluation/mkqa/__init__.py     |  2 +
 FlagEmbedding/evaluation/mkqa/data_loader.py  | 59 +++++++++++++
 FlagEmbedding/evaluation/mkqa/evaluator.py    | 30 +++++++
 FlagEmbedding/evaluation/mkqa/runner.py       | 13 +++
 docs/source/API/evaluation.rst                |  6 +-
 docs/source/API/evaluation/miracl.rst         | 48 ++++++++++
 .../API/evaluation/miracl/data_loader.rst     | 13 +++
 docs/source/API/evaluation/miracl/runner.rst  |  5 ++
 docs/source/API/evaluation/mkqa.rst           | 87 +++++++++++++++++++
 .../API/evaluation/mkqa/data_loader.rst       | 15 ++++
 docs/source/API/evaluation/mkqa/evaluator.rst |  5 ++
 docs/source/API/evaluation/mkqa/runner.rst    |  4 +
 15 files changed, 342 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/API/evaluation/miracl.rst
 create mode 100644 docs/source/API/evaluation/miracl/data_loader.rst
 create mode 100644 docs/source/API/evaluation/miracl/runner.rst
 create mode 100644 docs/source/API/evaluation/mkqa.rst
 create mode 100644 docs/source/API/evaluation/mkqa/data_loader.rst
 create mode 100644 docs/source/API/evaluation/mkqa/evaluator.rst
 create mode 100644 docs/source/API/evaluation/mkqa/runner.rst

diff --git a/FlagEmbedding/abc/evaluation/data_loader.py b/FlagEmbedding/abc/evaluation/data_loader.py
index 02f3588b..f55e62db 100644
--- a/FlagEmbedding/abc/evaluation/data_loader.py
+++ b/FlagEmbedding/abc/evaluation/data_loader.py
@@ -113,7 +113,7 @@ def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic
         return self._load_remote_corpus(dataset_name=dataset_name)
 
     def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
-        """Load the corpus from the dataset.
+        """Load the qrels from the dataset.
 
         Args:
             dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
diff --git a/FlagEmbedding/evaluation/miracl/data_loader.py b/FlagEmbedding/evaluation/miracl/data_loader.py
index 700781fd..cb508707 100644
--- a/FlagEmbedding/evaluation/miracl/data_loader.py
+++ b/FlagEmbedding/evaluation/miracl/data_loader.py
@@ -11,10 +11,28 @@
 
 
 class MIRACLEvalDataLoader(AbsEvalDataLoader):
+    """
+    Data loader class for MIRACL.
+    """
     def available_dataset_names(self) -> List[str]:
+        """
+        Get the available dataset names.
+
+        Returns:
+            List[str]: All the available dataset names.
+        """
         return ["ar", "bn", "en", "es", "fa", "fi", "fr", "hi", "id", "ja", "ko", "ru", "sw", "te", "th", "zh", "de", "yo"]
 
     def available_splits(self, dataset_name: str) -> List[str]:
+        """
+        Get the available splits.
+
+        Args:
+            dataset_name (str): Dataset name.
+
+        Returns:
+            List[str]: All the available splits for the dataset.
+        """
         if dataset_name in ["de", "yo"]:
             return ["dev"]
         else:
@@ -25,6 +43,15 @@ def _load_remote_corpus(
         dataset_name: str,
         save_dir: Optional[str] = None
     ) -> datasets.DatasetDict:
+        """Load the corpus dataset from HF.
+
+        Args:
+            dataset_name (str): Name of the dataset.
+            save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
+
+        Returns:
+            datasets.DatasetDict: Loaded datasets instance of corpus.
+ """ corpus = datasets.load_dataset( "miracl/miracl-corpus", dataset_name, cache_dir=self.cache_dir, @@ -60,6 +87,16 @@ def _load_remote_qrels( split: str = 'dev', save_dir: Optional[str] = None ) -> datasets.DatasetDict: + """Load the qrels from HF. + + Args: + dataset_name (str): Name of the dataset. + split (str, optional): Split of the dataset. Defaults to ``'dev'``. + save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``. + + Returns: + datasets.DatasetDict: Loaded datasets instance of qrel. + """ endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/miracl/miracl" qrels_download_url = f"{endpoint}/resolve/main/miracl-v1.0-{dataset_name}/qrels/qrels.miracl-v1.0-{dataset_name}-{split}.tsv" @@ -101,6 +138,16 @@ def _load_remote_queries( split: str = 'dev', save_dir: Optional[str] = None ) -> datasets.DatasetDict: + """Load the queries from HF. + + Args: + dataset_name (str): Name of the dataset. + split (str, optional): Split of the dataset. Defaults to ``'dev'``. + save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``. + + Returns: + datasets.DatasetDict: Loaded datasets instance of queries. + """ endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/miracl/miracl" queries_download_url = f"{endpoint}/resolve/main/miracl-v1.0-{dataset_name}/topics/topics.miracl-v1.0-{dataset_name}-{split}.tsv" diff --git a/FlagEmbedding/evaluation/miracl/runner.py b/FlagEmbedding/evaluation/miracl/runner.py index c6702a81..bab0a681 100644 --- a/FlagEmbedding/evaluation/miracl/runner.py +++ b/FlagEmbedding/evaluation/miracl/runner.py @@ -4,7 +4,15 @@ class MIRACLEvalRunner(AbsEvalRunner): + """ + Evaluation runner of MIRACL. + """ def load_data_loader(self) -> MIRACLEvalDataLoader: + """Load the data loader instance by args. + + Returns: + MIRACLEvalDataLoader: The MIRACL data loader instance. + """ data_loader = MIRACLEvalDataLoader( eval_name=self.eval_args.eval_name, dataset_dir=self.eval_args.dataset_dir, diff --git a/FlagEmbedding/evaluation/mkqa/__init__.py b/FlagEmbedding/evaluation/mkqa/__init__.py index 87c20deb..072cfebf 100644 --- a/FlagEmbedding/evaluation/mkqa/__init__.py +++ b/FlagEmbedding/evaluation/mkqa/__init__.py @@ -4,6 +4,7 @@ ) from .data_loader import MKQAEvalDataLoader +from .evaluator import MKQAEvaluator from .runner import MKQAEvalRunner __all__ = [ @@ -11,4 +12,5 @@ "MKQAEvalModelArgs", "MKQAEvalRunner", "MKQAEvalDataLoader", + "MKQAEvaluator" ] diff --git a/FlagEmbedding/evaluation/mkqa/data_loader.py b/FlagEmbedding/evaluation/mkqa/data_loader.py index 25bdab1a..af2180d6 100644 --- a/FlagEmbedding/evaluation/mkqa/data_loader.py +++ b/FlagEmbedding/evaluation/mkqa/data_loader.py @@ -13,13 +13,39 @@ class MKQAEvalDataLoader(AbsEvalDataLoader): + """ + Data loader class for MKQA. + """ def available_dataset_names(self) -> List[str]: + """ + Get the available dataset names. + + Returns: + List[str]: All the available dataset names. + """ return ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw'] def available_splits(self, dataset_name: Optional[str] = None) -> List[str]: + """ + Get the avaialble splits. + + Args: + dataset_name (str): Dataset name. + + Returns: + List[str]: All the available splits for the dataset. + """ return ["test"] def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict: + """Load the corpus. 
+
+        Args:
+            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
+
+        Returns:
+            datasets.DatasetDict: Loaded datasets instance of corpus.
+        """
         if self.dataset_dir is not None:
             # same corpus for all languages
             save_dir = self.dataset_dir
@@ -28,6 +54,19 @@ def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic
         return self._load_remote_corpus(dataset_name=dataset_name)
 
     def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
+        """Try to load qrels from local datasets.
+
+        Args:
+            save_dir (str): Directory that saves the data files.
+            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
+            split (str, optional): Split of the dataset. Defaults to ``'test'``.
+
+        Raises:
+            ValueError: If the given split is not found in the dataset.
+
+        Returns:
+            datasets.DatasetDict: Loaded datasets instance of qrels.
+        """
         checked_split = self.check_splits(split)
         if len(checked_split) == 0:
             raise ValueError(f"Split {split} not found in the dataset.")
@@ -96,6 +135,16 @@ def _load_remote_qrels(
         split: str = 'test',
         save_dir: Optional[str] = None
     ) -> datasets.DatasetDict:
+        """Load remote qrels from HF.
+
+        Args:
+            dataset_name (str): Name of the dataset.
+            split (str, optional): Split of the dataset. Defaults to ``'test'``.
+            save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
+
+        Returns:
+            datasets.DatasetDict: Loaded datasets instance of qrel.
+        """
         endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
         queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"
 
@@ -137,6 +186,16 @@ def _load_remote_queries(
         split: str = 'test',
         save_dir: Optional[str] = None
     ) -> datasets.DatasetDict:
+        """Load the queries from HF.
+
+        Args:
+            dataset_name (str): Name of the dataset.
+            split (str, optional): Split of the dataset. Defaults to ``'test'``.
+            save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
+
+        Returns:
+            datasets.DatasetDict: Loaded datasets instance of queries.
+        """
         endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
         queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"
 
diff --git a/FlagEmbedding/evaluation/mkqa/evaluator.py b/FlagEmbedding/evaluation/mkqa/evaluator.py
index d6a4756a..e65e7eae 100644
--- a/FlagEmbedding/evaluation/mkqa/evaluator.py
+++ b/FlagEmbedding/evaluation/mkqa/evaluator.py
@@ -8,12 +8,25 @@
 
 
 class MKQAEvaluator(AbsEvaluator):
+    """
+    The evaluator class of MKQA.
+    """
     def get_corpus_embd_save_dir(
         self,
         retriever_name: str,
         corpus_embd_save_dir: Optional[str] = None,
         dataset_name: Optional[str] = None
     ):
+        """Get the directory to save the corpus embedding.
+
+        Args:
+            retriever_name (str): Name of the retriever.
+            corpus_embd_save_dir (Optional[str], optional): Directory to save the corpus embedding. Defaults to ``None``.
+            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
+
+        Returns:
+            str: The final directory to save the corpus embedding.
+        """
         if corpus_embd_save_dir is not None:
             # Save the corpus embeddings in the same directory for all dataset_name
             corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, retriever_name)
@@ -24,6 +37,15 @@ def evaluate_results(
         search_results_save_dir: str,
         k_values: List[int] = [1, 3, 5, 10, 100, 1000]
     ):
+        """Compute the metrics and get the eval results.
+
+        Args:
+            search_results_save_dir (str): Directory that saves the search results.
+            k_values (List[int], optional): Cutoffs. Defaults to ``[1, 3, 5, 10, 100, 1000]``.
+
+        Returns:
+            dict: The evaluation results.
+        """
         eval_results_dict = {}
 
         corpus = self.data_loader.load_corpus()
@@ -70,6 +92,14 @@ def compute_metrics(
     ):
         """
         Compute Recall@k for QA task. The definition of recall in QA task is different from the one in IR task. Please refer to the paper of RocketQA: https://aclanthology.org/2021.naacl-main.466.pdf.
+
+        Args:
+            corpus_dict (Dict[str, str]): Dictionary of the corpus with doc id and contents.
+            qrels (Dict[str, List[str]]): Relevance labels of queries and passages.
+            search_results (Dict[str, Dict[str, float]]): Search results of the model to evaluate.
+
+        Returns:
+            dict: The model's scores of the metrics.
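+
+        Example:
+            An illustrative sketch of the definition only (the values below are
+            hypothetical, not actual library I/O): a query counts as answered at
+            cutoff k if any of its gold answer strings appears in any of its
+            top-k retrieved passages::
+
+                answers = ["Dave Edmunds"]
+                top_k_passages = ["Dave Edmunds recorded it in 1970.", "..."]
+                hit = any(a.lower() in p.lower() for a in answers for p in top_k_passages)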
         """
         contexts = []
         answers = []
diff --git a/FlagEmbedding/evaluation/mkqa/runner.py b/FlagEmbedding/evaluation/mkqa/runner.py
index 74683aac..69902765 100644
--- a/FlagEmbedding/evaluation/mkqa/runner.py
+++ b/FlagEmbedding/evaluation/mkqa/runner.py
@@ -5,7 +5,15 @@
 
 
 class MKQAEvalRunner(AbsEvalRunner):
+    """
+    Evaluation runner of MKQA.
+    """
     def load_data_loader(self) -> MKQAEvalDataLoader:
+        """Load the data loader instance by args.
+
+        Returns:
+            MKQAEvalDataLoader: The MKQA data loader instance.
+        """
         data_loader = MKQAEvalDataLoader(
             eval_name=self.eval_args.eval_name,
             dataset_dir=self.eval_args.dataset_dir,
@@ -16,6 +24,11 @@ def load_data_loader(self) -> MKQAEvalDataLoader:
         return data_loader
 
     def load_evaluator(self) -> MKQAEvaluator:
+        """Load the evaluator instance by args.
+
+        Returns:
+            MKQAEvaluator: The MKQA evaluator instance.
+        """
         evaluator = MKQAEvaluator(
             eval_name=self.eval_args.eval_name,
             data_loader=self.data_loader,
diff --git a/docs/source/API/evaluation.rst b/docs/source/API/evaluation.rst
index 0e6b32f8..6b52886d 100644
--- a/docs/source/API/evaluation.rst
+++ b/docs/source/API/evaluation.rst
@@ -1,2 +1,6 @@
 Evaluation
-==========
\ No newline at end of file
+==========
+
+.. toctree::
+   evaluation/miracl
+   evaluation/mkqa
\ No newline at end of file
diff --git a/docs/source/API/evaluation/miracl.rst b/docs/source/API/evaluation/miracl.rst
new file mode 100644
index 00000000..132bcf7c
--- /dev/null
+++ b/docs/source/API/evaluation/miracl.rst
@@ -0,0 +1,48 @@
+MIRACL
+======
+
+`MIRACL <https://project-miracl.github.io/>`_ (Multilingual Information Retrieval Across a Continuum of Languages)
+is a WSDM 2023 Cup challenge that focuses on search across 18 different languages.
+It releases a multilingual retrieval dataset containing train and dev sets for 16 "known languages" and only a dev set for 2 "surprise languages".
+The topics are generated by native speakers of each language, who also label the relevance between the topics and a given document list.
+You can find the `dataset <https://huggingface.co/datasets/miracl/miracl>`_ on HuggingFace.
+
+You can evaluate a model's performance on MIRACL simply by running our provided shell script:
+
+.. code:: bash
+
+    chmod +x ./examples/evaluation/miracl/eval_miracl.sh
+    ./examples/evaluation/miracl/eval_miracl.sh
+
+Or by running:
+
+.. code:: bash
+
+    python -m FlagEmbedding.evaluation.miracl \
+        --eval_name miracl \
+        --dataset_dir ./miracl/data \
+        --dataset_names bn hi sw te th yo \
+        --splits dev \
+        --corpus_embd_save_dir ./miracl/corpus_embd \
+        --output_dir ./miracl/search_results \
+        --search_top_k 1000 \
+        --rerank_top_k 100 \
+        --cache_path /root/.cache/huggingface/hub \
+        --overwrite False \
+        --k_values 10 100 \
+        --eval_output_method markdown \
+        --eval_output_path ./miracl/miracl_eval_results.md \
+        --eval_metrics ndcg_at_10 recall_at_100 \
+        --embedder_name_or_path BAAI/bge-m3 \
+        --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
+        --devices cuda:0 cuda:1 \
+        --cache_dir /root/.cache/huggingface/hub \
+        --reranker_max_length 1024
+
+Change the embedder, reranker, devices, and cache directory to your preference.
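+
+You can also launch the evaluation from Python. The snippet below is a minimal
+sketch, assuming the package exposes ``MIRACLEvalArgs`` and ``MIRACLEvalModelArgs``
+dataclasses mirroring the CLI flags above (check the package's ``__init__`` for
+the exact names):
+
+.. code:: python
+
+    from transformers import HfArgumentParser
+
+    from FlagEmbedding.evaluation.miracl import (
+        MIRACLEvalArgs,       # assumed: dataclass of evaluation arguments
+        MIRACLEvalModelArgs,  # assumed: dataclass of model arguments
+        MIRACLEvalRunner,
+    )
+
+    # Parse the same flags as the CLI into the two dataclasses.
+    parser = HfArgumentParser((MIRACLEvalArgs, MIRACLEvalModelArgs))
+    eval_args, model_args = parser.parse_args_into_dataclasses()
+
+    # The runner wires together the data loader and evaluator, then runs the evaluation.
+    runner = MIRACLEvalRunner(eval_args=eval_args, model_args=model_args)
+    runner.run()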
+
+.. toctree::
+   :hidden:
+
+   miracl/data_loader
+   miracl/runner
\ No newline at end of file
diff --git a/docs/source/API/evaluation/miracl/data_loader.rst b/docs/source/API/evaluation/miracl/data_loader.rst
new file mode 100644
index 00000000..7dbcfced
--- /dev/null
+++ b/docs/source/API/evaluation/miracl/data_loader.rst
@@ -0,0 +1,13 @@
+data_loader
+===========
+
+.. autoclass:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader
+
+Methods
+-------
+
+.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader.available_dataset_names
+.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader.available_splits
+.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_corpus
+.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_qrels
+.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_queries
\ No newline at end of file
diff --git a/docs/source/API/evaluation/miracl/runner.rst b/docs/source/API/evaluation/miracl/runner.rst
new file mode 100644
index 00000000..b77da139
--- /dev/null
+++ b/docs/source/API/evaluation/miracl/runner.rst
@@ -0,0 +1,5 @@
+runner
+======
+
+.. autoclass:: FlagEmbedding.evaluation.miracl.MIRACLEvalRunner
+    :members:
\ No newline at end of file
diff --git a/docs/source/API/evaluation/mkqa.rst b/docs/source/API/evaluation/mkqa.rst
new file mode 100644
index 00000000..0f242362
--- /dev/null
+++ b/docs/source/API/evaluation/mkqa.rst
@@ -0,0 +1,87 @@
+MKQA
+====
+
+`MKQA <https://github.com/apple/ml-mkqa>`_ is an open-domain question answering evaluation set comprising 10k question-answer pairs aligned across 26 typologically diverse languages.
+Each example in the dataset has the following structure:
+
+.. code:: python
+
+    {
+        'example_id': 563260143484355911,
+        'queries': {
+            'en': "who sings i hear you knocking but you can't come in",
+            'ru': "кто поет i hear you knocking but you can't come in",
+            'ja': '「 I hear you knocking」は誰が歌っていますか',
+            'zh_cn': "《i hear you knocking but you can't come in》是谁演唱的",
+            ...
+        },
+        'query': "who sings i hear you knocking but you can't come in",
+        'answers': {
+            'en': [{
+                'type': 'entity',
+                'entity': 'Q545186',
+                'text': 'Dave Edmunds',
+                'aliases': [],
+            }],
+            'ru': [{
+                'type': 'entity',
+                'entity': 'Q545186',
+                'text': 'Эдмундс, Дэйв',
+                'aliases': ['Эдмундс', 'Дэйв Эдмундс', 'Эдмундс Дэйв', 'Dave Edmunds'],
+            }],
+            'ja': [{
+                'type': 'entity',
+                'entity': 'Q545186',
+                'text': 'デイヴ・エドモンズ',
+                'aliases': ['デーブ・エドモンズ', 'デイブ・エドモンズ'],
+            }],
+            'zh_cn': [{
+                'type': 'entity',
+                'text': '戴维·埃德蒙兹 ',
+                'entity': 'Q545186',
+            }],
+            ...
+        },
+    }
+
+
+You can evaluate a model's performance on MKQA simply by running our provided shell script:
+
+.. code:: bash
+
+    chmod +x ./examples/evaluation/mkqa/eval_mkqa.sh
+    ./examples/evaluation/mkqa/eval_mkqa.sh
+
+Or by running:
+
+.. code:: bash
+
+    python -m FlagEmbedding.evaluation.mkqa \
+        --eval_name mkqa \
+        --dataset_dir ./mkqa/data \
+        --dataset_names en zh_cn \
+        --splits test \
+        --corpus_embd_save_dir ./mkqa/corpus_embd \
+        --output_dir ./mkqa/search_results \
+        --search_top_k 1000 \
+        --rerank_top_k 100 \
+        --cache_path /root/.cache/huggingface/hub \
+        --overwrite False \
+        --k_values 20 \
+        --eval_output_method markdown \
+        --eval_output_path ./mkqa/mkqa_eval_results.md \
+        --eval_metrics qa_recall_at_20 \
+        --embedder_name_or_path BAAI/bge-m3 \
+        --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
+        --devices cuda:0 cuda:1 \
+        --cache_dir /root/.cache/huggingface/hub \
+        --reranker_max_length 1024
+
+Change the embedder, reranker, devices, and cache directory to your preference.
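+
+You can also launch the evaluation from Python. The snippet below is a minimal
+sketch; ``MKQAEvalModelArgs`` is exported by this package, and ``MKQAEvalArgs``
+is assumed to be its companion dataclass of evaluation arguments (check
+``FlagEmbedding.evaluation.mkqa.__all__`` for the exact names):
+
+.. code:: python
+
+    from transformers import HfArgumentParser
+
+    from FlagEmbedding.evaluation.mkqa import (
+        MKQAEvalArgs,       # assumed: dataclass of evaluation arguments
+        MKQAEvalModelArgs,  # dataclass of model arguments
+        MKQAEvalRunner,
+    )
+
+    # Parse the same flags as the CLI into the two dataclasses.
+    parser = HfArgumentParser((MKQAEvalArgs, MKQAEvalModelArgs))
+    eval_args, model_args = parser.parse_args_into_dataclasses()
+
+    # The runner wires together the MKQA data loader and evaluator, then runs the evaluation.
+    runner = MKQAEvalRunner(eval_args=eval_args, model_args=model_args)
+    runner.run()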
+
+.. toctree::
+   :hidden:
+
+   mkqa/data_loader
+   mkqa/evaluator
+   mkqa/runner
\ No newline at end of file
diff --git a/docs/source/API/evaluation/mkqa/data_loader.rst b/docs/source/API/evaluation/mkqa/data_loader.rst
new file mode 100644
index 00000000..94c62b22
--- /dev/null
+++ b/docs/source/API/evaluation/mkqa/data_loader.rst
@@ -0,0 +1,15 @@
+data_loader
+===========
+
+.. autoclass:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader
+
+Methods
+-------
+
+.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader.available_dataset_names
+.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader.available_splits
+.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader.load_corpus
+.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader._load_local_qrels
+.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader._load_remote_corpus
+.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader._load_remote_qrels
+.. automethod:: FlagEmbedding.evaluation.mkqa.MKQAEvalDataLoader._load_remote_queries
\ No newline at end of file
diff --git a/docs/source/API/evaluation/mkqa/evaluator.rst b/docs/source/API/evaluation/mkqa/evaluator.rst
new file mode 100644
index 00000000..c46fc2f9
--- /dev/null
+++ b/docs/source/API/evaluation/mkqa/evaluator.rst
@@ -0,0 +1,5 @@
+evaluator
+=========
+
+.. autoclass:: FlagEmbedding.evaluation.mkqa.MKQAEvaluator
+    :members:
\ No newline at end of file
diff --git a/docs/source/API/evaluation/mkqa/runner.rst b/docs/source/API/evaluation/mkqa/runner.rst
new file mode 100644
index 00000000..bddedfcb
--- /dev/null
+++ b/docs/source/API/evaluation/mkqa/runner.rst
@@ -0,0 +1,4 @@
+runner
+======
+.. autoclass:: FlagEmbedding.evaluation.mkqa.MKQAEvalRunner
+    :members:
\ No newline at end of file