Skip to content

Commit

Permalink
fix return of already collected candidates
Browse files Browse the repository at this point in the history
  • Loading branch information
rwood-97 committed Oct 3, 2023
1 parent 0d939c2 commit 81bc751
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 55 deletions.
5 changes: 1 addition & 4 deletions experiments/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,7 @@ def prepare_data(self) -> dict:
# Obtain candidates per sentence:
for sentence_id in tqdm(dMentionsPred):
pred_mentions_sent = dMentionsPred[sentence_id]
(
wk_cands,
self.myranker.already_collected_cands,
) = self.myranker.find_candidates(pred_mentions_sent)
wk_cands = self.myranker.find_candidates(pred_mentions_sent)
dCandidates[sentence_id] = wk_cands

# -------------------------------------------
Expand Down
8 changes: 2 additions & 6 deletions geoparser/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,7 @@ def run_sentence(
rmentions = [{"mention": y["mention"]} for y in mentions]

# Perform candidate ranking:
wk_cands, self.myranker.already_collected_cands = self.myranker.find_candidates(
rmentions
)
wk_cands = self.myranker.find_candidates(rmentions)

mentions_dataset = dict()
mentions_dataset["linking"] = []
Expand Down Expand Up @@ -685,9 +683,7 @@ def run_candidate_selection(self, document_dataset: List[dict]) -> dict:
mentions = [{"mention": m} for m in mentions]

# Perform candidate ranking:
wk_cands, self.myranker.already_collected_cands = self.myranker.find_candidates(
mentions
)
wk_cands = self.myranker.find_candidates(rmentions)
return wk_cands

def run_disambiguation(
Expand Down
24 changes: 14 additions & 10 deletions geoparser/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ class Ranker:
>>> # Perform candidate selection
>>> queries = ['London', 'Paraguay']
>>> candidates, already_collected = ranker.run(queries)
>>> candidates = ranker.run(queries)
>>> # Find candidates for mentions
>>> mentions = [{'mention': 'London'}, {'mention': 'Paraguay'}]
>>> mention_candidates, mention_already_collected = ranker.find_candidates(mentions)
>>> mention_candidates = ranker.find_candidates(mentions)
>>> # Print the results
>>> print("Candidate Selection Results:")
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(
"overwrite_training": False,
"do_test": False,
},
already_collected_cands: Optional[dict] = dict(),
already_collected_cands: Optional[dict] = None,
):
"""
Initialize a Ranker object.
Expand All @@ -147,7 +147,11 @@ def __init__(
self.wikidata_to_mentions = wikidata_to_mentions
self.strvar_parameters = strvar_parameters
self.deezy_parameters = deezy_parameters
self.already_collected_cands = already_collected_cands

if already_collected_cands:
self.already_collected_cands = already_collected_cands
else:
self.already_collected_cands = dict()

def __str__(self) -> str:
"""
Expand Down Expand Up @@ -462,7 +466,7 @@ def partial_match(self, queries: List[str], damlev: bool) -> Tuple[dict, dict]:

self.already_collected_cands[query] = mention_df

return candidates, self.already_collected_cands
return candidates

def deezy_on_the_fly(self, queries: List[str]) -> Tuple[dict, dict]:
"""
Expand Down Expand Up @@ -511,7 +515,7 @@ def deezy_on_the_fly(self, queries: List[str]) -> Tuple[dict, dict]:
dm_output = self.deezy_parameters["dm_output"]

# first we fill in the perfect matches and already collected queries
cands_dict, self.already_collected_cands = self.perfect_match(queries)
cands_dict = self.perfect_match(queries)

# the rest go through
remainers = [x for x, y in cands_dict.items() if len(y) == 0]
Expand Down Expand Up @@ -561,7 +565,7 @@ def deezy_on_the_fly(self, queries: List[str]) -> Tuple[dict, dict]:

self.already_collected_cands[row["query"]] = returned_cands

return cands_dict, self.already_collected_cands
return cands_dict

def run(self, queries: List[str]) -> Tuple[dict, dict]:
"""
Expand All @@ -581,7 +585,7 @@ def run(self, queries: List[str]) -> Tuple[dict, dict]:
>>> myranker = Ranker(method="perfectmatch", ...)
>>> myranker.load_resources()
>>> queries = ['London', 'Barcelona', 'Bologna']
>>> candidates, already_collected = myranker.run(queries)
>>> candidates = myranker.run(queries)
>>> print(candidates)
{'London': {'London': 1.0}, 'Barcelona': {'Barcelona': 1.0}, 'Bologna': {'Bologna': 1.0}}
>>> print(already_collected)
Expand Down Expand Up @@ -670,7 +674,7 @@ def find_candidates(self, mentions: List[dict]) -> Tuple[dict, dict]:
queries = list(set([mention["mention"] for mention in mentions]))

# Pass the mentions to :py:meth:`geoparser.ranking.Ranker.run`
cands, self.already_collected_cands = self.run(queries)
cands = self.run(queries)

# Get Wikidata candidates
wk_cands = dict()
Expand Down Expand Up @@ -698,4 +702,4 @@ def find_candidates(self, mentions: List[dict]) -> Tuple[dict, dict]:
"Candidates": found_cands,
}

return wk_cands, self.already_collected_cands
return wk_cands
52 changes: 18 additions & 34 deletions tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def test_perfect_match():
resources_path="resources/wikidata/",
)
myranker.load_resources()
candidates, already_collected_cands = myranker.perfect_match(["London"])
candidates = myranker.perfect_match(["London"])
assert candidates["London"]["London"] == 1.0

candidates, already_collected_cands = myranker.perfect_match(["Lvndon"])
candidates = myranker.perfect_match(["Lvndon"])
assert candidates["Lvndon"] == {}

candidates, already_collected_cands = myranker.perfect_match(["Paperopoli"])
candidates = myranker.perfect_match(["Paperopoli"])
assert candidates["Paperopoli"] == {}


Expand Down Expand Up @@ -103,15 +103,15 @@ def test_partial_match():

# Test that perfect_match acts before partial match
myranker.mentions_to_wikidata = {"London": "Q84"}
candidates, already_collected_cands = myranker.partial_match(
candidates = myranker.partial_match(
["London"], damlev=False
)
assert candidates["London"]["London"] == 1.0

# Test that damlev works
myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.partial_match(
candidates = myranker.partial_match(
["Lvndvn"], damlev=True
)
assert candidates["Lvndvn"]["London"] == 0.6666666567325592
Expand All @@ -120,22 +120,22 @@ def test_partial_match():
myranker.mentions_to_wikidata = {"New York City": "Q60"}
myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.partial_match(
candidates = myranker.partial_match(
["New York"], damlev=False
)
assert candidates["New York"]["New York City"] == 0.6153846153846154

myranker.mentions_to_wikidata = {"New York City": "Q60"}
myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.partial_match(
candidates = myranker.partial_match(
["Lvndvn"], damlev=False
)
assert candidates["Lvndvn"] == {}

myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.partial_match(
candidates = myranker.partial_match(
["asdasd"], damlev=True
)
assert candidates["asdasd"] == {"New York City": 0.0}
Expand Down Expand Up @@ -177,13 +177,13 @@ def test_deezy_on_the_fly():

# Test that perfect_match acts before deezy
myranker.load_resources()
candidates, already_collected_cands = myranker.deezy_on_the_fly(["London"])
candidates = myranker.deezy_on_the_fly(["London"])
assert candidates["London"]["London"] == 1.0

# Test that deezy works
myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.deezy_on_the_fly(
candidates = myranker.deezy_on_the_fly(
["Ashton-cnderLyne"]
)
assert (
Expand Down Expand Up @@ -228,18 +228,14 @@ def test_find_candidates():

# Test that perfect_match acts before deezy
myranker.load_resources()
candidates, already_collected_cands = myranker.find_candidates(
[{"mention": "London"}]
)
candidates = myranker.find_candidates([{"mention": "London"}])
assert candidates["London"]["London"]["Score"] == 1.0
assert "Q84" in candidates["London"]["London"]["Candidates"]

# Test that deezy works
myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.find_candidates(
[{"mention": "Sheftield"}]
)
candidates = myranker.find_candidates([{"mention": "Sheftield"}])
assert (
candidates["Sheftield"]["Sheffield"]["Score"] > 0.0
and candidates["Sheftield"]["Sheffield"]["Score"] < 1.0
Expand All @@ -251,17 +247,13 @@ def test_find_candidates():

# Test that perfect_match acts before deezy
myranker.load_resources()
candidates, already_collected_cands = myranker.find_candidates(
[{"mention": "Sheffield"}]
)
candidates = myranker.find_candidates([{"mention": "Sheffield"}])
assert candidates["Sheffield"]["Sheffield"]["Score"] == 1.0
assert "Q42448" in candidates["Sheffield"]["Sheffield"]["Candidates"]

myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.find_candidates(
[{"mention": "Sheftield"}]
)
candidates = myranker.find_candidates([{"mention": "Sheftield"}])
assert candidates["Sheftield"] == {}

# Test that check if contained works
Expand All @@ -270,17 +262,13 @@ def test_find_candidates():
# Test that perfect_match acts before partialmatch
myranker.load_resources()

candidates, already_collected_cands = myranker.find_candidates(
[{"mention": "Sheffield"}]
)
candidates = myranker.find_candidates([{"mention": "Sheffield"}])
assert candidates["Sheffield"]["Sheffield"]["Score"] == 1.0
assert "Q42448" in candidates["Sheffield"]["Sheffield"]["Candidates"]

myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.find_candidates(
[{"mention": "Sheftield"}]
)
candidates = myranker.find_candidates([{"mention": "Sheftield"}])
assert "Sheffield" not in candidates["Sheftield"]

# Test that levenshtein works
Expand All @@ -289,17 +277,13 @@ def test_find_candidates():
# Test that perfect_match acts before partialmatch
myranker.load_resources()

candidates, already_collected_cands = myranker.find_candidates(
[{"mention": "Sheffield"}]
)
candidates = myranker.find_candidates([{"mention": "Sheffield"}])
assert candidates["Sheffield"]["Sheffield"]["Score"] == 1.0
assert "Q42448" in candidates["Sheffield"]["Sheffield"]["Candidates"]

myranker.already_collected_cands = {}

candidates, already_collected_cands = myranker.find_candidates(
[{"mention": "Sheftield"}]
)
candidates = myranker.find_candidates([{"mention": "Sheftield"}])
assert (
candidates["Sheftield"]["Sheffield"]["Score"] > 0.0
and candidates["Sheftield"]["Sheffield"]["Score"] < 1.0
Expand Down
2 changes: 1 addition & 1 deletion utils/rel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def prepare_rel_trainset(
# Format the mentions are required by the ranker:
all_mentions = [{"mention": mention} for mention in all_mentions]
# Use the ranker to find candidates:
wk_cands, myranker.already_collected_cands = myranker.find_candidates(all_mentions)
wk_cands = myranker.find_candidates(all_mentions)
# Rank the candidates:
rel_json = rank_candidates(
rel_json,
Expand Down

0 comments on commit 81bc751

Please sign in to comment.