# train_tilde.py
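
# Training script for TILDE: trains a BERT-based term likelihood model with
# PyTorch Lightning on MS MARCO passage-query pairs read from a tab-separated
# file ("passage<TAB>query" per line).
#
# Example invocation (the paths below are placeholders, not files shipped with
# the repo):
#   python train_tilde.py --train_path data/passage_query_pairs.train.tsv --save_path ./checkpoints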
import os
import random
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import loggers as pl_loggers
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer

from modeling import TILDE
from tools import get_stop_ids

MODEL_TYPE = 'bert-base-uncased'
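# Module-level tokenizer used by collate_fn below (the Dataset builds its own instance).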
tokenizer = BertTokenizer.from_pretrained(MODEL_TYPE, cache_dir='./cache')
class CheckpointEveryEpoch(pl.Callback):
    """
    Save a checkpoint at the end of every training epoch, instead of Lightning's
    default behaviour of checkpointing based on validation loss.
    """

    def __init__(
        self,
        start_epoc,
        save_path,
    ):
        self.start_epoc = start_epoc
        self.save_path = save_path

    def on_epoch_end(self, trainer: pl.Trainer, _):
        """Save a checkpoint after every train epoch once start_epoc is reached."""
        epoch = trainer.current_epoch
        if epoch >= self.start_epoc:
            ckpt_path = os.path.join(self.save_path, f"epoch_{epoch + 1}.ckpt")
            trainer.save_checkpoint(ckpt_path)
class MsmarcoDocumentQueryPair(Dataset):
    def __init__(self, path):
        self.tokenizer = BertTokenizer.from_pretrained(MODEL_TYPE, cache_dir='./cache')
        self.path = path
        self.queries = []
        self.passages = []
        self.stop_ids = list(get_stop_ids(self.tokenizer))

        # Each line of the training file is a tab-separated "passage\tquery" pair.
        with open(path, 'r') as f:
            contents = f.readlines()
        for line in contents:
            passage, query = line.strip().split('\t')
            self.queries.append(query)
            self.passages.append(passage)

    def __getitem__(self, index):
        query = self.queries[index]
        passage = self.passages[index]

        # Bag-of-words target over the vocabulary for the query:
        # 1 for its non-stop-word token ids, -1 at stop-word positions.
        ind = self.tokenizer(query, add_special_tokens=False)['input_ids']
        cleaned_ids = [token_id for token_id in ind if token_id not in self.stop_ids]
        yq = torch.zeros(self.tokenizer.vocab_size, dtype=torch.float32)
        yq[cleaned_ids] = 1
        yq[self.stop_ids] = -1

        # The same construction for the passage.
        ind = self.tokenizer(passage, add_special_tokens=False)['input_ids']
        cleaned_ids = [token_id for token_id in ind if token_id not in self.stop_ids]
        yd = torch.zeros(self.tokenizer.vocab_size, dtype=torch.float32)
        yd[cleaned_ids] = 1
        yd[self.stop_ids] = -1

        return passage, yq, query, yd

    def __len__(self):
        return len(self.queries)
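
# In-batch negatives: for each example, another example's label vector is sampled
# uniformly at random (an example is never paired with itself).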
def make_negative_labels(ys):
    batch_size = len(ys)
    neg_ys = []
    for i in range(batch_size):
        weights = [1 / (batch_size - 1)] * batch_size
        weights[i] = 0
        neg_ys.append(random.choices(ys, weights=weights)[0])
    return neg_ys
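
# Batch collation: tokenize the raw passages and queries, build in-batch negative
# labels, and overwrite the leading [CLS] id with TILDE's [DOC] (id 1) / [QRY] (id 2) markers.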
def collate_fn(batch):
    passages = []
    queries = []
    yqs = []
    yds = []
    for passage, yq, query, yd in batch:
        passages.append(passage)
        yqs.append(yq)
        queries.append(query)
        yds.append(yd)

    passage_inputs = tokenizer(passages, return_tensors="pt", padding=True, truncation=True)
    passage_input_ids = passage_inputs["input_ids"]
    passage_token_type_ids = passage_inputs["token_type_ids"]
    passage_attention_mask = passage_inputs["attention_mask"]
    neg_yqs = make_negative_labels(yqs)

    query_inputs = tokenizer(queries, return_tensors="pt", padding=True, truncation=True)
    query_input_ids = query_inputs["input_ids"]
    query_token_type_ids = query_inputs["token_type_ids"]
    query_attention_mask = query_inputs["attention_mask"]
    neg_yds = make_negative_labels(yds)

    passage_input_ids[:, 0] = 1  # 1 is the token id for [DOC]
    query_input_ids[:, 0] = 2    # 2 is the token id for [QRY]

    return passage_input_ids, passage_token_type_ids, passage_attention_mask, torch.stack(yqs), torch.stack(neg_yqs), \
           query_input_ids, query_token_type_ids, query_attention_mask, torch.stack(yds), torch.stack(neg_yds)
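
# Training entry point: single-GPU Trainer, batch size 128, 10 epochs; a checkpoint
# is written after every epoch, and the last one (epoch_10.ckpt) is reloaded and
# exported in HuggingFace format under <save_path>/TILDE.
# Note: the Trainer arguments (gpus, checkpoint_callback) and Trainer.add_argparse_args
# follow the older, pre-2.0 PyTorch Lightning API.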
def main(args):
    seed_everything(313)
    tb_logger = pl_loggers.TensorBoardLogger('logs/{}'.format(MODEL_TYPE))
    model = TILDE(MODEL_TYPE, gradient_checkpointing=args.gradient_checkpoint)
    dataset = MsmarcoDocumentQueryPair(args.train_path)
    loader = DataLoader(dataset,
                        batch_size=128,
                        drop_last=True,
                        pin_memory=True,
                        shuffle=True,
                        num_workers=10,
                        collate_fn=collate_fn)

    trainer = Trainer(max_epochs=10,
                      gpus=1,
                      checkpoint_callback=False,
                      logger=tb_logger,
                      # accelerator="ddp",
                      # plugins='ddp_sharded',
                      callbacks=[CheckpointEveryEpoch(0, args.save_path)])
    trainer.fit(model, loader)

    print("Saving the final checkpoint as a huggingface model...")
    model_to_save = TILDE.load_from_checkpoint(model_type=MODEL_TYPE,
                                               checkpoint_path=os.path.join(args.save_path, 'epoch_10.ckpt'))
    model_to_save.save(os.path.join(args.save_path, 'TILDE'))
if __name__ == '__main__':
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--train_path", required=True)
    parser.add_argument("--save_path", required=True)
    parser.add_argument("--gradient_checkpoint", action='store_true',
                        help='Trade training speed for a larger batch size by enabling gradient checkpointing.')
    args = parser.parse_args()
    main(args)