-
Notifications
You must be signed in to change notification settings - Fork 17
/
prepro.py
101 lines (88 loc) · 3.35 KB
/
prepro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
preprocess NLVR annotations into LMDB
"""
import argparse
import json
import os
from os.path import exists
from cytoolz import curry
from tqdm import tqdm
from pytorch_pretrained_bert import BertTokenizer
from data.data import open_lmdb
@curry
def bert_tokenize(tokenizer, text):
    """Turn ``text`` into one flat list of BERT wordpiece ids.

    The text is stripped and split on whitespace, and every word is
    tokenized separately. A word that yields no wordpieces (e.g. some special
    characters) is silently dropped instead of being mapped to an id.

    Because of ``@curry``, ``bert_tokenize(tokenizer)`` returns a
    one-argument callable ``text -> ids`` (this is how ``main`` uses it).

    Args:
        tokenizer: object providing ``tokenize(word)`` and
            ``convert_tokens_to_ids(tokens)`` (a ``BertTokenizer`` in ``main``).
        text: raw sentence string.

    Returns:
        list of the ids returned by ``convert_tokens_to_ids``, concatenated
        in word order.
    """
    token_ids = []
    for pieces in map(tokenizer.tokenize, text.strip().split()):
        if pieces:
            token_ids += tokenizer.convert_tokens_to_ids(pieces)
    return token_ids
def process_nlvr2(jsonl, db, tokenizer, missing=None):
    """Tokenize NLVR2 examples from ``jsonl`` and store them in ``db``.

    Each non-blank line is parsed as one JSON example. The example is
    augmented with ``input_ids``, ``img_fname`` (the pair of image feature
    filenames) and ``target``, then written as ``db[identifier] = example``.

    Args:
        jsonl: iterable of JSON-encoded lines (``main`` passes the open
            annotation file).
        db: mapping-like store; item assignment is used once per kept example.
        tokenizer: callable mapping a sentence string to a list of token ids.
        missing: optional container of image feature filenames. An example
            is skipped when either of its two feature files is in it.

    Returns:
        ``(id2len, txt2img)``: dicts keyed by example identifier, giving the
        number of token ids and the ``(img0, img1)`` filename pair.
    """
    id2len = {}
    txt2img = {}  # not sure if useful downstream; main dumps it to JSON
    for line in tqdm(jsonl, desc='processing NLVR2'):
        if not line.strip():
            # json.loads('') raises; tolerate blank lines such as stray
            # trailing newlines in the annotation file
            continue
        example = json.loads(line)
        id_ = example['identifier']
        # Drop the last '-'-separated field of the identifier to get the id
        # shared by the image pair (presumably that field indexes the
        # sentence -- verify against the NLVR2 annotation format).
        img_id = '-'.join(id_.split('-')[:-1])
        img_fname = (f'nlvr2_{img_id}-img0.npz', f'nlvr2_{img_id}-img1.npz')
        # skip examples whose image features are known to be unusable
        if missing and any(fname in missing for fname in img_fname):
            continue
        input_ids = tokenizer(example['sentence'])
        if 'label' in example:
            # label is compared against the string 'True'; anything else -> 0
            target = 1 if example['label'] == 'True' else 0
        else:
            # no 'label' key (presumably an unlabeled split): no target
            target = None
        txt2img[id_] = img_fname
        id2len[id_] = len(input_ids)
        example['input_ids'] = input_ids
        example['img_fname'] = img_fname
        example['target'] = target
        db[id_] = example
    return id2len, txt2img
def main(opts):
    """Build the NLVR2 text LMDB plus its JSON side files in ``opts.output``.

    Files written under ``opts.output``:
      * the LMDB opened with ``open_lmdb(opts.output, readonly=False)`` and
        filled by ``process_nlvr2``;
      * ``meta.json``: the parsed CLI options plus tokenizer info
        (``tokenizer``, ``UNK``/``CLS``/``SEP``/``MASK`` ids, ``v_range``);
      * ``id2len.json``: example identifier -> number of token ids;
      * ``txt2img.json``: example identifier -> pair of image feature names.

    Args:
        opts: parsed CLI namespace with ``annotation``, ``missing_imgs``,
            ``output`` and ``toker`` attributes.

    Raises:
        ValueError: if ``opts.output`` already exists (refuses to overwrite).
    """
    if exists(opts.output):
        raise ValueError('Found existing DB. Please explicitly remove '
                         'for re-processing')
    os.makedirs(opts.output)

    # vars() returns opts.__dict__ itself (not a copy), so every key added
    # to ``meta`` below is also set on ``opts``.
    meta = vars(opts)
    meta['tokenizer'] = opts.toker
    toker = BertTokenizer.from_pretrained(
        opts.toker, do_lower_case='uncased' in opts.toker)
    tokenizer = bert_tokenize(toker)  # curried: text -> token ids
    meta['UNK'] = toker.convert_tokens_to_ids(['[UNK]'])[0]
    meta['CLS'] = toker.convert_tokens_to_ids(['[CLS]'])[0]
    meta['SEP'] = toker.convert_tokens_to_ids(['[SEP]'])[0]
    meta['MASK'] = toker.convert_tokens_to_ids(['[MASK]'])[0]
    # (id of '!', vocab size); pass a one-token list like the calls above
    # rather than a bare string
    meta['v_range'] = (toker.convert_tokens_to_ids(['!'])[0],
                       len(toker.vocab))
    with open(f'{opts.output}/meta.json', 'w') as f:
        json.dump(meta, f, indent=4)

    # Load the skip-list before opening the DB, closing the file via `with`
    # instead of leaving the handle open.
    missing_imgs = None
    if opts.missing_imgs is not None:
        with open(opts.missing_imgs) as f:
            missing_imgs = set(json.load(f))

    with open_lmdb(opts.output, readonly=False) as db:
        with open(opts.annotation) as ann:
            id2lens, txt2img = process_nlvr2(ann, db, tokenizer,
                                             missing_imgs)

    with open(f'{opts.output}/id2len.json', 'w') as f:
        json.dump(id2lens, f)
    with open(f'{opts.output}/txt2img.json', 'w') as f:
        json.dump(txt2img, f)
if __name__ == '__main__':
    # Command-line entry point: parse options and build the DB via main().
    parser = argparse.ArgumentParser(
        description='Preprocess NLVR2 annotations into an LMDB.')
    parser.add_argument('--annotation', required=True,
                        help='annotation JSONL file (one JSON example '
                             'per line)')
    parser.add_argument('--missing_imgs',
                        help='JSON list of image feature filenames to skip '
                             '(some training image features are corrupted)')
    parser.add_argument('--output', required=True,
                        help='output dir of DB')
    parser.add_argument('--toker', default='bert-base-cased',
                        help='which BERT tokenizer to use')
    args = parser.parse_args()
    main(args)