diff --git a/Makefile b/Makefile index 79ef772..079596b 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ help: ## run-workflow: Run the workflow pipeline locally for quick evaluation. .PHONY: run-workflow run-workflow: + pip install . pytest tests/ black src/ pylint src/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 282a972..09db71e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "transformers>=4.38", "tqdm", "datasets", + "Whoosh-Reloaded" ] dynamic = ["version"] diff --git a/src/sample_efficiency_evaluation/fact_matcher.py b/src/sample_efficiency_evaluation/fact_matcher.py index db84f50..84846a5 100644 --- a/src/sample_efficiency_evaluation/fact_matcher.py +++ b/src/sample_efficiency_evaluation/fact_matcher.py @@ -3,10 +3,16 @@ """ import logging +import os.path +import hashlib from abc import ABC, abstractmethod + from tqdm import tqdm -from typing_extensions import overload +from whoosh.index import create_in, FileIndex +from whoosh.fields import Schema, TEXT, ID +from whoosh.writing import SegmentWriter +from whoosh.qparser import QueryParser, query from utility import utility @@ -20,6 +26,10 @@ def __init__(self, **kwargs): self.bear_relation_info_dict: dict = utility.load_json_dict(kwargs.get("bear_relation_info_path")) self.entity_relation_info_dict: dict = self.extract_entity_information(kwargs.get("bear_data_path")) + index_path = kwargs.get("file_index_dir", "indexdir") + self.writer, self.indexer = self.initialize_index(index_path) + self.query_parser = QueryParser("content", schema=self.indexer.schema) + def extract_entity_information(self, bear_data_path: str) -> dict: """ Extract entity information from bear data. 
@@ -34,7 +44,8 @@ def extract_entity_information(self, bear_data_path: str) -> dict: except FileNotFoundError: logging.error("File not found: %s/%s.jsonl", bear_data_path, relation_key) continue - for fact in tqdm(fact_list, desc=f"Extracting entity information for {relation_key}"): + logging.info("Extracting entity information for %s", relation_key) + for fact in fact_list: fact_dict = utility.load_json_str(fact) relation_dict[relation_key][fact_dict["sub_label"]] = { "aliases": fact_dict["sub_aliases"], @@ -42,11 +53,51 @@ } return relation_dict + def index_file(self, file_content: str) -> None: + """ + Index file. + :param file_content: File content to index + :return: + """ + doc_hash = str(hashlib.sha256(file_content.encode()).hexdigest()) + self.writer.add_document(title=doc_hash, path=f"/{doc_hash}", content=file_content) + + def index_dataset(self, file_contents: list[dict], text_key: str = "text") -> None: + """ + Index dataset files, the dataset is a list of file contents. + :param text_key: Key to extract text from file content. Since the dataset is a list of file contents, we need to + specify the key to extract text from the file content. That would be the case if we pass a huggingface dataset. + :param file_contents: List of file contents + :return: + """ + for file_content in tqdm(file_contents, desc="Indexing dataset"): + self.index_file(file_content[text_key]) + self.commit_index() + + def commit_index(self) -> None: + self.writer.commit() + + @staticmethod + def initialize_index(index_path) -> tuple[SegmentWriter, FileIndex]: + """ + Initialize index writer and indexer. 
+ :param index_path: + :return: + """ + indexing_schema = Schema(title=TEXT(stored=True), path=ID(stored=True), content=TEXT(stored=True)) + if not os.path.exists(index_path): + os.mkdir(index_path) + indexer = create_in(index_path, indexing_schema) + writer = indexer.writer() + return writer, indexer + @abstractmethod - def match_facts(self) -> dict: + def search_index(self, main_query: str, sub_query: str = "") -> list[dict]: """ - Match facts - :return: Matched facts + Search index + :param main_query: The main query + :param sub_query: The sub query + :return: List of search results """ @@ -55,6 +106,27 @@ class FactMatcherSimpleHeuristic(FactMatcherBase): FactMatcherSimpleHeuristic """ - @overload - def match_facts(self) -> dict: - pass + def search_index(self, main_query: str, sub_query: str = "") -> list[dict[str, str]]: + """ + Search index for main-query and sub query. + + If the sub-query is not provided, it will only search for the main query. + A simple heuristic is used to filter the search results where it is only considered a match if the query is + found in the content field. 
+ :param main_query: The main query + :param sub_query: The sub query + :return: List of search results + """ + collected_results = [] + with self.indexer.searcher() as searcher: + user_q = self.query_parser.parse(main_query) + if sub_query != "": + logging.info("Searching index for (%s) and (%s)", main_query, sub_query) + sub_q = query.Term("content", sub_query) + results = searcher.search(user_q, filter=sub_q) + else: + logging.info("Searching index for (%s)", main_query) + results = searcher.search(user_q) + for result in results: + collected_results.append(dict(result)) + return collected_results diff --git a/tests/test_fact_matcher.py b/tests/test_fact_matcher.py index cd2a2a2..9a4d012 100644 --- a/tests/test_fact_matcher.py +++ b/tests/test_fact_matcher.py @@ -1,6 +1,7 @@ import os import unittest -from unittest.mock import patch +import shutil +from unittest.mock import patch, MagicMock from sample_efficiency_evaluation.fact_matcher import FactMatcherSimpleHeuristic from utility import utility @@ -41,9 +42,15 @@ def setUp(self) -> None: } self.maxDiff = None self.test_resources_abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_resources")) + self.indexer_mocked = MagicMock() + self.writer_mocked = MagicMock() + self.test_index_dir = f"{self.test_resources_abs_path}/test_index_dir" + if os.path.exists(self.test_index_dir): + shutil.rmtree(self.test_index_dir) def test_extract_entity_information(self): - with patch.object(utility, "load_json_dict", return_value=self.test_relation_info_dict) as mock_load_json_dict: + with patch.object(utility, "load_json_dict", return_value=self.test_relation_info_dict) as mock_load_json_dict, \ + patch.object(FactMatcherSimpleHeuristic, "initialize_index", return_value=(self.writer_mocked, self.indexer_mocked)): fact_matcher = FactMatcherSimpleHeuristic( bear_relation_info_path=f"{self.test_resources_abs_path}/relation_info.json", @@ -53,3 +60,46 @@ 
self.assertEqual(fact_matcher.bear_relation_info_dict, self.test_relation_info_dict) self.assertEqual(fact_matcher.entity_relation_info_dict, self.test_entity_relation_info_dict) mock_load_json_dict.assert_called_once_with(f"{self.test_resources_abs_path}/relation_info.json") + + def test_search_index(self): + with patch.object(utility, "load_json_dict", return_value=self.test_relation_info_dict), \ + patch.object(FactMatcherSimpleHeuristic, "extract_entity_information", return_value=self.test_entity_relation_info_dict): + + fact_matcher = FactMatcherSimpleHeuristic( + bear_relation_info_path=f"{self.test_resources_abs_path}/relation_info.json", + bear_data_path=f"{self.test_resources_abs_path}/BEAR", + file_index_dir=self.test_index_dir, + ) + + fact_matcher.index_file("Boeing is a company") + fact_matcher.index_file("Boeing 747 is a plane") + fact_matcher.commit_index() + results = fact_matcher.search_index("Boeing") + self.assertEqual(len(results), 2) + self.assertEqual(results, + [{'path': '/ddda5959a6a4f994ee6a55c0e60b6137ea776e79846fc5a35d58ef0840005905', + 'title': 'ddda5959a6a4f994ee6a55c0e60b6137ea776e79846fc5a35d58ef0840005905', + 'content': 'Boeing is a company'}, + {'path': '/1b4c34a604c95618ceb558da613bd8655d0a6a21ccaf0480dc150eff44d30047', + 'title': '1b4c34a604c95618ceb558da613bd8655d0a6a21ccaf0480dc150eff44d30047', + 'content': 'Boeing 747 is a plane'}]) + + def test_search_index_sub_query(self): + with patch.object(utility, "load_json_dict", return_value=self.test_relation_info_dict), \ + patch.object(FactMatcherSimpleHeuristic, "extract_entity_information", return_value=self.test_entity_relation_info_dict): + + fact_matcher = FactMatcherSimpleHeuristic( + bear_relation_info_path=f"{self.test_resources_abs_path}/relation_info.json", + bear_data_path=f"{self.test_resources_abs_path}/BEAR", + file_index_dir=self.test_index_dir, + ) + + fact_matcher.index_file("Boeing is a company") + fact_matcher.index_file("Boeing 747 is a plane") + 
fact_matcher.commit_index() + results = fact_matcher.search_index("Boeing", sub_query="747") + self.assertEqual(len(results), 1) + self.assertEqual(results, + [{'path': '/1b4c34a604c95618ceb558da613bd8655d0a6a21ccaf0480dc150eff44d30047', + 'title': '1b4c34a604c95618ceb558da613bd8655d0a6a21ccaf0480dc150eff44d30047', + 'content': 'Boeing 747 is a plane'}])