Skip to content

Commit

Permalink
Fix bug in scan phrase matching
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Dec 3, 2023
1 parent 7dce77c commit 04d03bd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
14 changes: 13 additions & 1 deletion searcharray/postings.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,11 @@ def memory_usage(self, deep=False):

@property
def nbytes(self):
    """Total memory footprint of this array in bytes.

    Sums the underlying term-frequency matrix, the encoded positions,
    and every per-document positions-lookup array.

    Returns
    -------
    int
        Combined ``nbytes`` of ``term_freqs``, ``posns`` and all
        entries of ``posns_lookup``.
    """
    # posns_lookup is a sequence of numpy arrays, so its footprint must
    # be accumulated element by element (there is no aggregate .nbytes).
    posns_lookup_bytes = sum(x.nbytes for x in self.posns_lookup)
    # NOTE(review): removed leftover debug print() statements — nbytes is
    # queried by pandas memory accounting and must stay side-effect free.
    return self.term_freqs.nbytes + self.posns.nbytes + posns_lookup_bytes

def __getitem__(self, key):
key = pd.api.indexers.check_array_indexer(self, key)
Expand Down Expand Up @@ -724,6 +728,8 @@ def bm25(self, token, doc_stats=None, k1=1.2, b=0.75):
k1 : float, optional BM25 param. Defaults to 1.2.
b : float, optional BM25 param. Defaults to 0.75.
"""
# Get term freqs per token
token = self._check_token_arg(token)
return self.bm25_idf(token, doc_stats=doc_stats) * self.bm25_tf(token)

def _posns_lookup_to_csr(self):
Expand Down Expand Up @@ -886,16 +892,22 @@ def phrase_freq_scan_old(self, tokens, mask=None, slop=1):
# Find insert position of every next term in prior term's positions
# Intuition:
# https://colab.research.google.com/drive/1EeqHYuCiqyptd-awS67Re78pqVdTfH4A
if len(prior_posns[idx]) == 0:
bigram_freqs[idx] = 0
cont_posns.append([])
continue
priors_in_self = self_adjs(prior_posns[idx], term_posns[idx])
takeaway = 0
satisfies_slop = None
cont_indices = None
# Different term
if len(priors_in_self) == 0:
ins_posns = np.searchsorted(prior_posns[idx], term_posns[idx], side='right')
prior_adjacents = prior_posns[idx][ins_posns - 1]
adjacents = term_posns[idx] - prior_adjacents
satisfies_slop = (adjacents <= slop) & ~(ins_posns == 0)
cont_indices = np.argwhere(satisfies_slop).flatten()
# Overlapping term
else:
adjacents = np.diff(priors_in_self)
satisfies_slop = adjacents <= slop
Expand Down
16 changes: 10 additions & 6 deletions test/test_msmarco.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,14 @@ def ws_punc_tokenizer(text):
# msmarco phrase search: 1.5184s

@pytest.mark.skip
@pytest.mark.parametrize("phrase_search", ["what is", "what is the", "what is the purpose", "what is the purpose of", "what is the purpose of cats", "star trek", "star trek the next generation"])
def test_msmarco(phrase_search, msmarco100k):
    """Profile BM25 phrase search over the 100k MSMARCO sample.

    Parametrized over phrases of increasing length to observe how
    phrase-matching cost scales with the number of terms.
    Skipped by default — this is a manual profiling harness, not a
    correctness test.
    """
    # Local import: cProfile is only needed by this profiling harness.
    import cProfile
    # Fixture parameter arrives as a single string; the array API wants tokens.
    phrase_search = phrase_search.split()
    # print(f"Memory Usage (BODY): {msmarco100k['body_ws'].array.memory_usage() / 1024 ** 2:.2f} MB")
    # print(f"Memory Usage (TITLE): {msmarco100k['title_ws'].array.memory_usage() / 1024 ** 2:.2f} MB")
    start = perf_counter()
    with cProfile.Profile() as pr:
        msmarco100k['body_ws'].array.bm25(phrase_search)
        pr.print_stats(sort="cumtime")
    print(f"msmarco phrase search {phrase_search}: {perf_counter() - start:.4f}s")

0 comments on commit 04d03bd

Please sign in to comment.