utils.py
import os
import random

import numpy as np
import torch

def set_random_seed(seed):
    # Seed every common RNG source and force deterministic cuDNN behavior for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
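
# Usage sketch (illustrative, not part of the original module): call once at the start of a
# training script, before building datasets or models; the seed value 42 is an arbitrary example.
#
#     set_random_seed(42)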

def fix_tokenizer(tokenizer):
    # Fixing broken tokenizers: fill in special tokens that are missing or mapped outside the vocabulary.
    special_tokens = dict()

    # Scan the first token ids for plausible pad/bos/eos/unk/sep tokens.
    for token_id in range(1000):
        token = tokenizer.convert_ids_to_tokens(token_id)
        if tokenizer.pad_token_id in (None, tokenizer.vocab_size) and "pad" in token:
            special_tokens["pad_token"] = token
        if tokenizer.bos_token_id in (None, tokenizer.vocab_size) and "<s>" in token:
            special_tokens["bos_token"] = token
        if tokenizer.eos_token_id in (None, tokenizer.vocab_size) and "</s>" in token:
            special_tokens["eos_token"] = token
        if tokenizer.unk_token_id in (None, tokenizer.vocab_size) and "unk" in token:
            special_tokens["unk_token"] = token
        if tokenizer.sep_token_id in (None, tokenizer.vocab_size) and "sep" in token:
            special_tokens["sep_token"] = token

    # Fall back to other tokens (or fresh placeholders) when nothing suitable was found.
    if tokenizer.sep_token_id in (None, tokenizer.vocab_size) and "bos_token" in special_tokens:
        special_tokens["sep_token"] = special_tokens["bos_token"]

    if tokenizer.pad_token_id in (None, tokenizer.vocab_size) and "pad_token" not in special_tokens:
        if tokenizer.unk_token_id is not None:
            special_tokens["pad_token"] = tokenizer.unk_token
        else:
            special_tokens["pad_token"] = "<|pad|>"

    if tokenizer.sep_token_id in (None, tokenizer.vocab_size) and "sep_token" not in special_tokens:
        if tokenizer.bos_token_id is not None:
            special_tokens["sep_token"] = tokenizer.bos_token
        else:
            special_tokens["sep_token"] = "<|sep|>"

    print(special_tokens)
    tokenizer.add_special_tokens(special_tokens)

    print("Vocab size: ", tokenizer.vocab_size)
    print("PAD: ", tokenizer.pad_token_id, tokenizer.pad_token)
    print("BOS: ", tokenizer.bos_token_id, tokenizer.bos_token)
    print("EOS: ", tokenizer.eos_token_id, tokenizer.eos_token)
    print("UNK: ", tokenizer.unk_token_id, tokenizer.unk_token)
    print("SEP: ", tokenizer.sep_token_id, tokenizer.sep_token)
    return tokenizer

def fix_model(model, tokenizer, use_resize=True):
    # Align the model config with the (possibly fixed) tokenizer special tokens.
    model.config.pad_token_id = tokenizer.pad_token_id
    assert model.config.pad_token_id is not None

    # Pick the first available candidate for the BOS / decoder start token.
    bos_candidates = (
        tokenizer.bos_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.unk_token_id
    )
    for bos_candidate in bos_candidates:
        model.config.bos_token_id = bos_candidate
        if bos_candidate is not None:
            break
    assert model.config.bos_token_id is not None
    model.config.decoder_start_token_id = model.config.bos_token_id

    # Pick the first available candidate for the EOS token.
    eos_candidates = (tokenizer.eos_token_id, tokenizer.sep_token_id)
    for eos_candidate in eos_candidates:
        model.config.eos_token_id = eos_candidate
        if eos_candidate is not None:
            break
    assert model.config.eos_token_id is not None

    # Resize embeddings in case fix_tokenizer added new special tokens.
    if use_resize:
        model.resize_token_embeddings(len(tokenizer))
    return model
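
# Usage sketch (illustrative; "your-org/your-model" is a hypothetical placeholder checkpoint,
# not something this module prescribes): fix the tokenizer first, then align the model config
# and embeddings with it.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = fix_tokenizer(AutoTokenizer.from_pretrained("your-org/your-model"))
#     model = fix_model(AutoModelForCausalLM.from_pretrained("your-org/your-model"), tokenizer)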

def gen_batch(records, batch_size):
    # Yield consecutive slices of `records` with at most `batch_size` items each.
    batch_start = 0
    while batch_start < len(records):
        batch_end = batch_start + batch_size
        batch = records[batch_start: batch_end]
        batch_start = batch_end
        yield batch
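
# Usage sketch (illustrative): iterate over a list of records in fixed-size chunks;
# the final batch may be shorter than batch_size.
#
#     records = list(range(10))
#     for batch in gen_batch(records, batch_size=4):
#         print(batch)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]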

def print_special_tokens(tokenizer):
    # Dump the current special-token mapping for quick inspection.
    print("Vocab size: ", tokenizer.vocab_size)
    print("PAD: ", tokenizer.pad_token_id, tokenizer.pad_token)
    print("BOS: ", tokenizer.bos_token_id, tokenizer.bos_token)
    print("EOS: ", tokenizer.eos_token_id, tokenizer.eos_token)
    print("UNK: ", tokenizer.unk_token_id, tokenizer.unk_token)
    print("SEP: ", tokenizer.sep_token_id, tokenizer.sep_token)
    return tokenizer


# Expected OPT special-token mapping after fix_tokenizer_opt:
# PAD: 0 <unk>
# BOS: 1 <s>
# EOS: 2 </s>
# UNK: 0 <unk>
# SEP: 1 <s>
def fix_tokenizer_opt(tokenizer):
    # Fixing broken tokenizers: hard-code the special tokens used by OPT tokenizers.
    special_tokens = {
        'pad_token': '<unk>',
        'bos_token': '<s>',
        'eos_token': '</s>',
        'unk_token': '<unk>',
        'sep_token': '<s>'
    }
    tokenizer.add_special_tokens(special_tokens)

    print("Vocab size: ", tokenizer.vocab_size)
    print("PAD: ", tokenizer.pad_token_id, tokenizer.pad_token)
    print("BOS: ", tokenizer.bos_token_id, tokenizer.bos_token)
    print("EOS: ", tokenizer.eos_token_id, tokenizer.eos_token)
    print("UNK: ", tokenizer.unk_token_id, tokenizer.unk_token)
    print("SEP: ", tokenizer.sep_token_id, tokenizer.sep_token)
    return tokenizer
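
# Usage sketch (illustrative; "facebook/opt-1.3b" is only an example OPT checkpoint, not fixed
# by this module): apply the hard-coded OPT special tokens, then verify the mapping above.
#
#     from transformers import AutoTokenizer
#
#     opt_tokenizer = fix_tokenizer_opt(AutoTokenizer.from_pretrained("facebook/opt-1.3b"))
#     print_special_tokens(opt_tokenizer)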