indexingv2.py
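
"""Build a TILDEv2 passage index.

Encodes every passage of a JSONL collection (one {"pid": ..., "psg": ...}
record per line) with a trained TILDEv2 model and writes two files to
--output_path: docids.npy, mapping index rows to passage ids, and
tildev2_index.hdf5, holding per-token scores and token ids for each passage
with stopwords and special tokens removed.

Usage (arguments defined in the __main__ block at the bottom):
    python indexingv2.py --ckpt_path_or_name <model> --collection_path <dir> \
        [--output_path ./data/index/TILDEv2] [--batch_size 64] \
        [--p_max_len 192] [--num_workers 4]
"""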
import torch
import argparse
from tqdm import tqdm
import numpy as np
import json
import h5py
import os
from tools import get_stop_ids
from transformers import AutoTokenizer, PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding
from modelingv2 import TILDEv2
from torch.utils.data import Dataset, DataLoader
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class MsmarcoDataset(Dataset):
    def __init__(self, collection_path: str, tokenizer: PreTrainedTokenizer, p_max_len=192):
        self.collection = []
        self.docids = []
        # Load every file in the collection directory; each line is a JSON
        # record holding a passage ('psg') and its passage id ('pid').
        for filename in os.listdir(collection_path):
            with open(f"{collection_path}/{filename}", 'r') as f:
                lines = f.readlines()
                for line in tqdm(lines, desc="loading collection...."):
                    data = json.loads(line)
                    self.collection.append(data['psg'])
                    self.docids.append(data['pid'])
        self.tok = tokenizer
        self.p_max_len = p_max_len

    def __len__(self):
        return len(self.collection)

    def __getitem__(self, item) -> BatchEncoding:
        psg = self.collection[item]
        encoded_psg = self.tok.encode_plus(
            psg,
            max_length=self.p_max_len,
            truncation='only_first',
            return_attention_mask=False,
        )
        return encoded_psg

    def get_docids(self):
        return self.docids

def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.ckpt_path_or_name, use_fast=False, cache_dir='./cache')
    model = TILDEv2.from_pretrained(args.ckpt_path_or_name, cache_dir='./cache').eval().to(DEVICE)

    # Tokens in stop_ids are excluded from the index.
    special_token_ids = tokenizer.all_special_ids
    stop_ids = get_stop_ids(tokenizer)
    stop_ids = stop_ids.union(special_token_ids)  # add BERT special token ids as well.

    dataset = MsmarcoDataset(
        args.collection_path, tokenizer, p_max_len=args.p_max_len,
    )
    docids = dataset.get_docids()
    # Row i of the HDF5 index corresponds to docids[i].
    np.save(os.path.join(args.output_path, "docids.npy"), np.array(docids))

    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=DataCollatorWithPadding(
            tokenizer,
            max_length=args.p_max_len,
            padding='max_length'
        ),
        shuffle=False,    # important: preserve the docid order saved above
        drop_last=False,  # important: index every passage
        num_workers=args.num_workers,
    )

    # Each index entry stores two variable-length arrays: per-token scores
    # (float16) and the matching token ids (int16).
    dt_token_id = h5py.vlen_dtype(np.dtype('int16'))
    dt_embedding = h5py.vlen_dtype(np.dtype('float16'))
    dt_compound = np.dtype([('embedding', dt_embedding), ('token_ids', dt_token_id)])
    f = h5py.File(os.path.join(args.output_path, "tildev2_index.hdf5"), "w")
    dset = f.create_dataset("documents", (len(docids),), dtype=dt_compound)

    docno = 0
    for passage_inputs in tqdm(data_loader):
        passage_token_ids = passage_inputs["input_ids"].cpu().numpy().astype(np.int16)
        passage_inputs.to(DEVICE)  # BatchEncoding.to() moves the tensors in place
        with torch.no_grad():
            passage_outputs = model.encode(**passage_inputs)
            passage_outputs = passage_outputs.squeeze(1).cpu().numpy().astype(np.float16)
        for inbatch_idx in range(len(passage_token_ids)):
            # Drop stopwords and special tokens; keep the rest with their scores.
            token_scores = []
            token_ids = []
            for idx, token_id in enumerate(passage_token_ids[inbatch_idx]):
                if token_id in stop_ids:
                    continue
                score = passage_outputs[inbatch_idx][idx]
                token_scores.append(score)
                token_ids.append(token_id)
            token_scores = np.array(token_scores, dtype=np.float16)
            token_ids = np.array(token_ids, dtype=np.int16)
            doc = (token_scores, token_ids)
            dset[docno] = doc
            docno += 1
    f.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path_or_name", type=str, required=True)
    parser.add_argument("--collection_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./data/index/TILDEv2")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--p_max_len", type=int, default=192)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)
    main(args)
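
# A minimal sketch (not part of the original script) of how the index written
# above could be read back. The file names and the "documents" dataset match
# what main() produces under the default --output_path; the row shown is
# illustrative.
#
#     import h5py
#     import numpy as np
#
#     docids = np.load("./data/index/TILDEv2/docids.npy")
#     with h5py.File("./data/index/TILDEv2/tildev2_index.hdf5", "r") as index:
#         row = index["documents"][0]        # entry for docids[0]
#         token_scores = row["embedding"]    # float16 per-token scores
#         token_ids = row["token_ids"]       # int16 BERT token ids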