forked from AnthonyMRios/relation-extraction-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pred.py
81 lines (64 loc) · 2.56 KB
/
pred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Usage:
pred.py [options]
Options:
-h --help show this help message and exit
--word2vec=<file> word vectors in gensim format
--dataset=<file> dataset (see data folder for example)
--test_ids=<file> ids of test examples (see data folder for example)
--model=<file> filename to use to save model
--mini_batch_size=<arg> Minibatch size [default: 32]
--num_classes=<arg> Total number of classes for training [default: 5]
--lstm_hidden_state=<arg> lstm hidden state size [default: 256]
--random_seed=<arg> random seed [default: 42]
"""
import logging
import pickle
import random
import sys
from models.bilstm import BiLSTM
import docopt
import numpy as np
from sklearn.metrics import f1_score
def main(argv):
    """Evaluate a pickled BiLSTM model on the test examples and print its F1.

    Parses the command line with docopt (driven by the module docstring),
    seeds the ``random`` and NumPy generators, loads the pickled bundle
    (``'token'`` loader object plus ``'model_params'``), transforms the
    dataset for the requested test ids, predicts in mini-batches and prints
    the binary F1 score to stdout.

    Args:
        argv: Command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).

    Raises:
        ValueError: If ``--mini_batch_size`` or ``--random_seed`` is not an
            integer, or the mini-batch size is smaller than 1.
    """
    args = docopt.docopt(__doc__, argv=argv)

    # docopt hands every option value back as a string; convert before use,
    # otherwise np.random.seed() and range() reject the values.
    random_seed = int(args['--random_seed'])
    np.random.seed(random_seed)
    random.seed(random_seed)

    mini_batch_size = int(args['--mini_batch_size'])
    if mini_batch_size < 1:
        raise ValueError('--mini_batch_size must be a positive integer')

    def read_ids(path):
        """Return the stripped lines of *path* as a list of ids."""
        with open(path, 'r') as fp:
            return [row.strip() for row in fp]

    # The ids file is passed as the option --test_ids; there is no
    # positional <test_ids> argument in the usage text.
    test_ids = read_ids(args['--test_ids'])

    # Pickle data is binary, so the file must be opened in 'rb' mode.
    # SECURITY: pickle.load executes arbitrary code embedded in the file;
    # only load model files from a trusted source.
    with open(args['--model'], 'rb') as fp:
        tmp = pickle.load(fp)

    ld = tmp['token']
    # NOTE(review): --num_classes and --lstm_hidden_state are declared in the
    # usage text but are not read here; nc and nh stay hard-coded, presumably
    # to match the saved parameters -- confirm before wiring the options in.
    mod = BiLSTM(ld.embs, ld.pos, ld.pospeech, ld.chunk, nc=5, nh=2048,
                 de=ld.embs.shape[1])
    mod.__setstate__(tmp['model_params'])

    # Only pairs_idx, pos_e1_idx, pos_e2_idx and pred_y are used below; the
    # remaining names just document the tuple returned by ld.transform.
    (pairs_idx, chunk_idx, pos_idx, pos_e1_idx, pos_e2_idx, _,
     subj_y, pred_y, obj_y, idents, e1_ids, e2_ids) = ld.transform(
        args['--dataset'], test_ids)

    # NOTE(review): the original code scored against an undefined name `y`
    # (NameError). pred_y is assumed to hold the gold labels that match the
    # thresholded predictions -- confirm against ld.transform.
    y_true = pred_y

    num_examples = len(pairs_idx)
    all_test_preds = []
    for start in range(0, num_examples, mini_batch_size):
        # Clamp so the final, possibly shorter, batch stays in range.
        batch = range(start, min(start + mini_batch_size, num_examples))
        tpairs = ld.pad_data([pairs_idx[i] for i in batch])
        te1 = ld.pad_data([pos_e1_idx[i] for i in batch])
        te2 = ld.pad_data([pos_e2_idx[i] for i in batch])
        preds = mod.predict_proba(tpairs, te1, te2, np.float32(1.))
        # Threshold each predicted score at 0.5 to obtain a 0/1 label.
        all_test_preds.extend(1 if x > 0.5 else 0 for x in preds)

    test_f1 = f1_score(y_true, all_test_preds, average='binary')
    print("test_f1: %.4f" % (test_f1))
    sys.stdout.flush()
if __name__ == '__main__':
    # Script entry point: turn on DEBUG-level logging, then run the
    # evaluation with the command-line arguments (program name stripped).
    cli_args = sys.argv[1:]
    logging.basicConfig(level=logging.DEBUG)
    main(cli_args)