predict_sparc.py
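"""Generate SQL predictions for SParC/CoSQL interactions from a pretrained checkpoint.

Loads a saved model, preprocesses an interaction-format JSON file, and writes
one predicted SQL query per line to the output file, with a blank line
separating consecutive interactions.
"""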
import os
import copy
import json
import torch
import pprint
import utils
import random
import dataset
import argparse
from model.model import Module
mydir = os.path.dirname(__file__)
parser = argparse.ArgumentParser()
parser.add_argument('resume')
parser.add_argument('input')
parser.add_argument('--dataset', default='sparc', choices=['cosql', 'sparc'])
parser.add_argument('--tables', default='tables.json')
parser.add_argument('--db', default='database')
parser.add_argument('--dcache', default='cache')
parser.add_argument('--batch', type=int, default=6)
parser.add_argument('--output', default='output.txt')
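
# The input file is expected to follow the SParC/CoSQL interaction format.
# A minimal sketch (hypothetical values; only the fields read in main() are
# required, and any gold 'query'/'sql' annotations are stripped if present):
#
# [
#   {
#     "database_id": "concert_singer",
#     "interaction": [
#       {"utterance": "How many singers do we have?"},
#       {"utterance": "Which of them are from France?"}
#     ]
#   }
# ]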
def main(orig_args):
    # load pretrained model
    fresume = os.path.abspath(orig_args.resume)
    # print('resuming from {}'.format(fresume))
    assert os.path.isfile(fresume), '{} does not exist'.format(fresume)
    orig_args.input = os.path.abspath(orig_args.input)
    orig_args.tables = os.path.abspath(orig_args.tables)
    orig_args.db = os.path.abspath(orig_args.db)
    orig_args.dcache = os.path.abspath(orig_args.dcache)

    # the checkpoint bundles the training-time args and auxiliary data ('ext')
    binary = torch.load(fresume, map_location=torch.device('cpu'))
    args = binary['args']
    ext = binary['ext']
    args.gpu = torch.cuda.is_available()
    # override the data paths and batch size with the command-line values
    args.tables = orig_args.tables
    args.db = orig_args.db
    args.dcache = orig_args.dcache
    args.batch = orig_args.batch

    Model = utils.load_module(args.model)
    if args.model == 'nl2sql':
        # nl2sql models carry a separate beam reranker module
        Reranker = utils.load_module(args.beam_rank)
        ext['reranker'] = Reranker(args, ext)
    m = Model(args, ext).place_on_device()
    m.load_save(fname=fresume)

    # preprocess data
    data = dataset.Dataset()
    # pick the preprocessor matching the dataset recorded in the checkpoint
    if args.dataset == 'sparc':
        import preprocess_nl2sql_sparc as preprocess
    elif args.dataset == 'cosql':
        import preprocess_nl2sql_cosql as preprocess
    else:
        raise NotImplementedError()
    proc_errors = set()
    with open(orig_args.input) as f:
        C = preprocess.SQLDataset
        raw = json.load(f)
        # make contexts and populate vocab
        for i, ex in enumerate(raw):
            for turn_i, turn in enumerate(ex['interaction']):
                # give each turn a unique id and drop any gold SQL annotations
                turn['id'] = '{}/{}:{}'.format(ex['database_id'], i, turn_i)
                turn['db_id'] = ex['database_id']
                for k in ['query', 'query_toks', 'query_toks_no_value', 'sql']:
                    if k in turn:
                        del turn[k]
                # tokenize the utterance with the model's BERT tokenizer and
                # wrap it in [CLS] ... [SEP] to form the value context
                turn['question'] = turn['utterance']
                turn['g_question_toks'] = C.tokenize_question(turn['utterance'].split(), m.bert_tokenizer)
                turn['value_context'] = [m.bert_tokenizer.cls_token] + turn['g_question_toks'] + [m.bert_tokenizer.sep_token]
                turn['turn_i'] = turn_i
                data.append(turn)

    # run preds
    preds = m.run_interactive_pred(data, args, verbose=True)
    assert len(preds) == len(data), 'got {} predictions for {} examples'.format(len(preds), len(data))
    # print('writing to {}'.format(orig_args.output))

    # one predicted query per line; a blank line separates interactions
    with open(orig_args.output, 'wt') as f:
        for i, ex in enumerate(data):
            if i != 0 and ex['turn_i'] == 0:
                f.write('\n')
            if ex['id'] in proc_errors:
                # turns recorded as preprocessing failures fall back to a placeholder
                s = 'ERROR'
            else:
                p = preds[ex['id']]
                s = p['query']
            f.write(s + '\n')
        f.flush()
if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
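
# Example invocation (hypothetical paths; 'resume' is a saved checkpoint and
# 'input' is a SParC/CoSQL-style interaction JSON file):
#
#   python predict_sparc.py checkpoints/nl2sql.pt data/sparc/dev.json \
#       --dataset sparc --tables data/sparc/tables.json \
#       --db data/sparc/database --output preds.txt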