diff --git a/baselines/EMNLP2019/entitylinker.py b/baselines/EMNLP2019/entitylinker.py index ce307b2..1eebaac 100644 --- a/baselines/EMNLP2019/entitylinker.py +++ b/baselines/EMNLP2019/entitylinker.py @@ -8,6 +8,7 @@ import time import diffbot_nlapi import logging +import pathlib from config import MODEL, NUMBER_URI_CANDIDATES, SOFT_COREF_CANDIDATES @@ -15,6 +16,7 @@ Span.set_extension("el_candidates", default=[]) Span.set_extension("uri_candidates", default=[]) +pathlib.Path('tmp').mkdir(parents=True, exist_ok=True) db = SqliteDict(os.path.join('tmp','el.db'), autocommit=True) configuration = diffbot_nlapi.Configuration() diff --git a/baselines/EMNLP2019/prepare_data.py b/baselines/EMNLP2019/prepare_data.py index 80a2577..e7bb2fb 100644 --- a/baselines/EMNLP2019/prepare_data.py +++ b/baselines/EMNLP2019/prepare_data.py @@ -12,7 +12,7 @@ import evaluator from instance import Instance import numpy as np -from run import generate_candidates, classify_instances +from run import generate_candidates from vocab import Vocab from config import NUMBER_URI_CANDIDATES_TO_CONSIDER, URI_THRESHOLD, MODEL, KNOWLEDGE_NET_DIR, MULTITASK diff --git a/baselines/EMNLP2019/run.py b/baselines/EMNLP2019/run.py index 6abfcf1..8333e0f 100644 --- a/baselines/EMNLP2019/run.py +++ b/baselines/EMNLP2019/run.py @@ -5,6 +5,7 @@ import entitylinker import bert_wrapper import yaml +import sys import spacy nlp = spacy.load('en_core_web_lg') @@ -122,7 +123,7 @@ def generate_candidates(text): return instances -def classify_instances(instances, predicate_ids_to_classify=predicate_thresholds.keys(), includeUri=True): +def classify_instances(instances, predicate_ids_to_classify=None, includeUri=True): if len(instances) == 0: return max_length = max([ len(instance.get_words()) for instance in instances ]) @@ -148,6 +149,8 @@ def classify_instances(instances, predicate_ids_to_classify=predicate_thresholds global_features_batch.append(global_features) for sess, model_preds, predicate_thresholds in models: + if predicate_ids_to_classify is None: + predicate_ids_to_classify = predicate_thresholds.keys() model_preds = predicate_ids_to_classify if CANDIDATE_RECALL else model_preds if len(set(model_preds).intersection(predicate_ids_to_classify)) == 0: continue @@ -195,6 +198,9 @@ def run_batch(texts): return ret if __name__ == "__main__": + if len(models) == 0: + print("No trained model") + sys.exit(0) import fileinput for line in fileinput.input(): for fact in run_batch([line])[0]: