Skip to content

Commit

Permalink
implement document indexing and query searching
Browse files Browse the repository at this point in the history
  • Loading branch information
Jabbawukis committed Oct 11, 2024
1 parent 7d1d2f9 commit 2de5a38
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 10 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ help:
## run-workflow: Run the workflow pipeline locally for quick evaluation.
.PHONY: run-workflow
run-workflow:
pip install .
pytest tests/
black src/
pylint src/
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"transformers>=4.38",
"tqdm",
"datasets",
"Whoosh-Reloaded"
]
dynamic = ["version"]

Expand Down
88 changes: 80 additions & 8 deletions src/sample_efficiency_evaluation/fact_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
"""

import logging
import os.path
import hashlib

from abc import ABC, abstractmethod

from tqdm import tqdm
from typing_extensions import overload
from whoosh.index import create_in, FileIndex
from whoosh.fields import Schema, TEXT, ID
from whoosh.writing import SegmentWriter
from whoosh.qparser import QueryParser, query

from utility import utility

Expand All @@ -20,6 +26,10 @@ def __init__(self, **kwargs):
self.bear_relation_info_dict: dict = utility.load_json_dict(kwargs.get("bear_relation_info_path"))
self.entity_relation_info_dict: dict = self.extract_entity_information(kwargs.get("bear_data_path"))

index_path = kwargs.get("file_index_dir", "indexdir")
self.writer, self.indexer = self.initialize_index(index_path)
self.query_parser = QueryParser("content", schema=self.indexer.schema)

def extract_entity_information(self, bear_data_path: str) -> dict:
"""
Extract entity information from bear data.
Expand All @@ -34,19 +44,60 @@ def extract_entity_information(self, bear_data_path: str) -> dict:
except FileNotFoundError:
logging.error("File not found: %s/%s.jsonl", bear_data_path, relation_key)
continue
for fact in tqdm(fact_list, desc=f"Extracting entity information for {relation_key}"):
for fact in fact_list:
logging.info("Extracting entity information for %s", relation_key)
fact_dict = utility.load_json_str(fact)
relation_dict[relation_key][fact_dict["sub_label"]] = {
"aliases": fact_dict["sub_aliases"],
"obj_label": fact_dict["obj_label"],
}
return relation_dict

def index_file(self, file_content: str) -> None:
    """
    Add a single document to the search index.

    The document's title and path are derived from the SHA-256 hex digest of
    the content, so identical contents map to the same title/path. The write
    is buffered by the Whoosh writer; call ``commit_index`` to persist it.

    :param file_content: Text content of the document to index.
    :return: None
    """
    # hexdigest() already returns a str, so the former str() wrapping was redundant.
    doc_hash = hashlib.sha256(file_content.encode()).hexdigest()
    self.writer.add_document(title=doc_hash, path=f"/{doc_hash}", content=file_content)

def index_dataset(self, file_contents: list[dict], text_key: str = "text") -> None:
    """
    Index an entire dataset and commit the result.

    Each element of the dataset is a mapping (for example a HuggingFace
    dataset row); the text to be indexed is looked up under ``text_key``.

    :param file_contents: List of document dictionaries to index.
    :param text_key: Key under which each dictionary stores its text. Needed
        because the dataset is a list of mappings rather than raw strings.
    :return: None
    """
    progress = tqdm(file_contents, desc="Indexing dataset")
    for entry in progress:
        self.index_file(entry[text_key])
    self.commit_index()

def commit_index(self) -> None:
    """
    Persist all buffered index writes to disk via the index writer.

    NOTE(review): Whoosh writers are single-use — after commit() this writer
    presumably cannot accept further add_document calls; verify that callers
    obtain a fresh writer before indexing again.
    """
    self.writer.commit()

@staticmethod
def initialize_index(index_path) -> tuple[SegmentWriter, FileIndex]:
    """
    Create a fresh Whoosh index at the given path and return a writer for it.

    The schema stores three fields per document: title, path and content.
    Note that ``create_in`` always creates a brand-new index, so any index
    already stored at ``index_path`` is overwritten.

    :param index_path: Directory in which the index files are stored; created
        (including missing parent directories) if it does not exist yet.
    :return: Tuple of (writer, index) for the newly created index.
    """
    indexing_schema = Schema(title=TEXT(stored=True), path=ID(stored=True), content=TEXT(stored=True))
    # makedirs with exist_ok=True handles missing parent directories (os.mkdir
    # would raise FileNotFoundError) and avoids the race between the previous
    # os.path.exists check and the directory creation.
    os.makedirs(index_path, exist_ok=True)
    indexer = create_in(index_path, indexing_schema)
    writer = indexer.writer()
    return writer, indexer

@abstractmethod
def match_facts(self) -> dict:
def search_index(self, main_query: str, sub_query: str = "") -> list[dict]:
"""
Match facts
:return: Matched facts
Search index
:param main_query: The main query
:param sub_query: The sub query
:return: List of search results
"""


Expand All @@ -55,6 +106,27 @@ class FactMatcherSimpleHeuristic(FactMatcherBase):
FactMatcherSimpleHeuristic
"""

@overload
def match_facts(self) -> dict:
pass
def search_index(self, main_query: str, sub_query: str = "") -> list[dict[str, str]]:
    """
    Search index for main-query and sub query.

    If the sub-query is not provided, it will only search for the main query.
    A simple heuristic is used to filter the search results where it is only
    considered a match if the query is found in the content field.

    :param main_query: The main query.
    :param sub_query: Optional term that matching documents' content field
        must also contain.
    :return: List of search results as plain dictionaries.
    """
    with self.indexer.searcher() as searcher:
        user_q = self.query_parser.parse(main_query)
        # Use the module's logging facility instead of print so diagnostics go
        # through the application's logging configuration (consistent with the
        # rest of this module).
        if sub_query:
            logging.info("Searching index for (%s) and (%s)", main_query, sub_query)
            sub_q = query.Term("content", sub_query)
            results = searcher.search(user_q, filter=sub_q)
        else:
            logging.info("Searching index for (%s)", main_query)
            results = searcher.search(user_q)
        # Materialize inside the `with` block: Whoosh results are only valid
        # while the searcher is open.
        return [dict(result) for result in results]
54 changes: 52 additions & 2 deletions tests/test_fact_matcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import unittest
from unittest.mock import patch
import shutil
from unittest.mock import patch, MagicMock

from sample_efficiency_evaluation.fact_matcher import FactMatcherSimpleHeuristic
from utility import utility
Expand Down Expand Up @@ -41,9 +42,15 @@ def setUp(self) -> None:
}
self.maxDiff = None
self.test_resources_abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_resources"))
self.indexer_mocked = MagicMock()
self.writer_mocked = MagicMock()
self.test_index_dir = f"{self.test_resources_abs_path}/test_index_dir"
if os.path.exists(self.test_index_dir):
shutil.rmtree(self.test_index_dir)

def test_extract_entity_information(self):
with patch.object(utility, "load_json_dict", return_value=self.test_relation_info_dict) as mock_load_json_dict:
with patch.object(utility, "load_json_dict", return_value=self.test_relation_info_dict) as mock_load_json_dict, \
patch.object(FactMatcherSimpleHeuristic, "initialize_index", return_value=(self.writer_mocked, self.indexer_mocked)):

fact_matcher = FactMatcherSimpleHeuristic(
bear_relation_info_path=f"{self.test_resources_abs_path}/relation_info.json",
Expand All @@ -53,3 +60,46 @@ def test_extract_entity_information(self):
self.assertEqual(fact_matcher.bear_relation_info_dict, self.test_relation_info_dict)
self.assertEqual(fact_matcher.entity_relation_info_dict, self.test_entity_relation_info_dict)
mock_load_json_dict.assert_called_once_with(f"{self.test_resources_abs_path}/relation_info.json")

def test_search_index(self):
    """End-to-end: index two documents and search for a shared term."""
    # BUG FIX: the extract_entity_information patch was previously created as
    # a bare expression and never started, so it patched nothing. It now
    # participates in the with-statement and is active for the constructor.
    with patch.object(utility, "load_json_dict", return_value=self.test_relation_info_dict), \
            patch.object(FactMatcherSimpleHeuristic, "extract_entity_information",
                         return_value=self.test_entity_relation_info_dict):

        fact_matcher = FactMatcherSimpleHeuristic(
            bear_relation_info_path=f"{self.test_resources_abs_path}/relation_info.json",
            bear_data_path=f"{self.test_resources_abs_path}/BEAR",
            file_index_dir=self.test_index_dir,
        )

        fact_matcher.index_file("Boeing is a company")
        fact_matcher.index_file("Boeing 747 is a plane")
        fact_matcher.commit_index()
        results = fact_matcher.search_index("Boeing")
        self.assertEqual(len(results), 2)
        self.assertEqual(
            results,
            [{'path': '/ddda5959a6a4f994ee6a55c0e60b6137ea776e79846fc5a35d58ef0840005905',
              'title': 'ddda5959a6a4f994ee6a55c0e60b6137ea776e79846fc5a35d58ef0840005905',
              'content': 'Boeing is a company'},
             {'path': '/1b4c34a604c95618ceb558da613bd8655d0a6a21ccaf0480dc150eff44d30047',
              'title': '1b4c34a604c95618ceb558da613bd8655d0a6a21ccaf0480dc150eff44d30047',
              'content': 'Boeing 747 is a plane'}])

def test_search_index_sub_query(self):
    """End-to-end: the sub-query filter narrows the results to one document."""
    # BUG FIX: the extract_entity_information patch was previously created as
    # a bare expression and never started, so it patched nothing. It now
    # participates in the with-statement and is active for the constructor.
    with patch.object(utility, "load_json_dict", return_value=self.test_relation_info_dict), \
            patch.object(FactMatcherSimpleHeuristic, "extract_entity_information",
                         return_value=self.test_entity_relation_info_dict):

        fact_matcher = FactMatcherSimpleHeuristic(
            bear_relation_info_path=f"{self.test_resources_abs_path}/relation_info.json",
            bear_data_path=f"{self.test_resources_abs_path}/BEAR",
            file_index_dir=self.test_index_dir,
        )

        fact_matcher.index_file("Boeing is a company")
        fact_matcher.index_file("Boeing 747 is a plane")
        fact_matcher.commit_index()
        results = fact_matcher.search_index("Boeing", sub_query="747")
        self.assertEqual(len(results), 1)
        self.assertEqual(
            results,
            [{'path': '/1b4c34a604c95618ceb558da613bd8655d0a6a21ccaf0480dc150eff44d30047',
              'title': '1b4c34a604c95618ceb558da613bd8655d0a6a21ccaf0480dc150eff44d30047',
              'content': 'Boeing 747 is a plane'}])

0 comments on commit 2de5a38

Please sign in to comment.