Skip to content

Commit

Permalink
Correct field-centric mm
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Dec 28, 2023
1 parent bea2dbf commit 652124e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
12 changes: 7 additions & 5 deletions searcharray/solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _edismax_term_centric(frame: pd.DataFrame,
query_fields: Dict[str, float],
num_search_terms: int,
search_terms: Dict[str, List[str]],
min_should_match: int) -> Tuple[np.ndarray, str]:
mm: str) -> Tuple[np.ndarray, str]:
explain = []
term_scores = []
for term_posn in range(num_search_terms):
Expand All @@ -127,6 +127,7 @@ def _edismax_term_centric(frame: pd.DataFrame,
term_scores.append(max_scores)
explain.append("(" + " | ".join(term_explain) + ")")

min_should_match = parse_min_should_match(num_search_terms, spec=mm)
qf_scores = np.asarray(term_scores)
matches_gt_mm = np.sum(qf_scores > 0, axis=0) >= min_should_match
qf_scores = np.sum(term_scores, axis=0)
Expand All @@ -138,16 +139,18 @@ def _edismax_field_centric(frame: pd.DataFrame,
query_fields: Dict[str, float],
num_search_terms: int,
search_terms: Dict[str, List[str]],
min_should_match: int) -> Tuple[np.ndarray, str]:
mm: str) -> Tuple[np.ndarray, str]:
field_scores = []
explain = []
for field, boost in query_fields.items():
post_arr = get_field(frame, field)
term_scores = np.array([post_arr.bm25(term) for term in search_terms[field]])
min_should_match = parse_min_should_match(len(search_terms[field]), spec=mm)
exp = " ".join([f"{field}:{term}" for term in search_terms[field]])
boost_exp = f"{boost}" if boost is not None else "1"
exp = "(" + exp + f")~{min(min_should_match, len(search_terms[field]))}"
exp = "(" + exp + f")^{boost_exp}"

matches_gt_mm = np.sum(term_scores > 0, axis=0) >= min(min_should_match, len(search_terms[field]))
sum_terms_bm25 = np.sum(term_scores, axis=0)
sum_terms_bm25[~matches_gt_mm] = 0
Expand Down Expand Up @@ -205,11 +208,10 @@ def listify(x):
# trigram_fields = parse_field_boosts(pf3) if pf3 else {}

num_search_terms, search_terms, term_centric = parse_query_terms(frame, q, list(query_fields.keys()))
min_should_match = parse_min_should_match(num_search_terms, spec=mm)
if term_centric:
qf_scores, explain = _edismax_term_centric(frame, query_fields, num_search_terms, search_terms, min_should_match)
qf_scores, explain = _edismax_term_centric(frame, query_fields, num_search_terms, search_terms, mm)
else:
qf_scores, explain = _edismax_field_centric(frame, query_fields, num_search_terms, search_terms, min_should_match)
qf_scores, explain = _edismax_field_centric(frame, query_fields, num_search_terms, search_terms, mm)

phrase_scores = []
for field, boost in phrase_fields.items():
Expand Down
14 changes: 14 additions & 0 deletions test/test_solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ def just_lowercasing_tokenizer(text: str) -> List[str]:
0],
"params": {'q': "foo bar", 'qf': ["title", "body"], 'mm': "2"},
},
"field_centric_mm_opp": {
"frame": {
'title': lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
'body': lambda: PostingsArray.index(["foo bar", "data2", "data3 bar", "bunny funny wunny"],
tokenizer=just_lowercasing_tokenizer)
},
"expected": [lambda frame: max(sum([frame['title'].array.bm25("foo")[0],
frame['title'].array.bm25("bar")[0]]),
frame['body'].array.bm25("foo bar")[0]),
0,
0,
0],
"params": {'q': "foo bar", 'qf': ["body", "title"], 'mm': "2"},
},
"boost_title": {
"frame": {
'title': lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"]),
Expand Down

0 comments on commit 652124e

Please sign in to comment.