Skip to content

Commit

Permalink
Fixed an issue where preparing the data would fail when you didn't have the pretrained models
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidek committed Nov 27, 2019
1 parent 5eb188f commit 6b1a7de
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
2 changes: 2 additions & 0 deletions baselines/EMNLP2019/entitylinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import time
import diffbot_nlapi
import logging
import pathlib

from config import MODEL, NUMBER_URI_CANDIDATES, SOFT_COREF_CANDIDATES

# el_candidate has types, uri, score
# Register custom spaCy Span extensions used to carry entity-linking results:
# ._.el_candidates holds linker candidates, ._.uri_candidates holds URI candidates.
# NOTE(review): default=[] is a single shared list across all Spans until a span
# assigns its own value — presumably intentional (spaCy convention); verify callers
# always assign rather than append to the default.
Span.set_extension("el_candidates", default=[])
Span.set_extension("uri_candidates", default=[])

# Create the tmp/ directory before opening the SqliteDict cache below; without it,
# SqliteDict would fail on a fresh checkout (the failure this commit fixes).
pathlib.Path('tmp').mkdir(parents=True, exist_ok=True)
db = SqliteDict(os.path.join('tmp','el.db'), autocommit=True)

configuration = diffbot_nlapi.Configuration()
Expand Down
2 changes: 1 addition & 1 deletion baselines/EMNLP2019/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import evaluator
from instance import Instance
import numpy as np
from run import generate_candidates, classify_instances
from run import generate_candidates
from vocab import Vocab
from config import NUMBER_URI_CANDIDATES_TO_CONSIDER, URI_THRESHOLD, MODEL, KNOWLEDGE_NET_DIR, MULTITASK

Expand Down
8 changes: 7 additions & 1 deletion baselines/EMNLP2019/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import entitylinker
import bert_wrapper
import yaml
import sys

import spacy
# Load the large English spaCy pipeline once at import time (module-level side
# effect; requires the en_core_web_lg model to be installed).
nlp = spacy.load('en_core_web_lg')
Expand Down Expand Up @@ -122,7 +123,7 @@ def generate_candidates(text):

return instances

def classify_instances(instances, predicate_ids_to_classify=predicate_thresholds.keys(), includeUri=True):
def classify_instances(instances, predicate_ids_to_classify=None, includeUri=True):
if len(instances) == 0:
return
max_length = max([ len(instance.get_words()) for instance in instances ])
Expand All @@ -148,6 +149,8 @@ def classify_instances(instances, predicate_ids_to_classify=predicate_thresholds
global_features_batch.append(global_features)

for sess, model_preds, predicate_thresholds in models:
if predicate_ids_to_classify is None:
predicate_ids_to_classify = predicate_thresholds.keys()
model_preds = predicate_ids_to_classify if CANDIDATE_RECALL else model_preds
if len(set(model_preds).intersection(predicate_ids_to_classify)) == 0:
continue
Expand Down Expand Up @@ -195,6 +198,9 @@ def run_batch(texts):
return ret

if __name__ == "__main__":
if len(models) == 0:
print("No trained model")
sys.exit(0)
import fileinput
for line in fileinput.input():
for fact in run_batch([line])[0]:
Expand Down

0 comments on commit 6b1a7de

Please sign in to comment.