# tokenization.py
from transformers.file_utils import PaddingStrategy


def tokenized_dataset(dataset, tokenizer, is_inference=False, is_mlm=False):
    """
    Tokenizes the sentences according to the given tokenizer.
    """
    # Build "subject_entity[SEP]object_entity" strings to use as the first segment.
    concat_entity = list(
        dataset.apply(
            lambda x: x["subject_entity"] + "[SEP]" + x["object_entity"], axis=1
        )
    )
    # RoBERTa (but not XLM-RoBERTa) takes no token type ids, so they are dropped;
    # during training, padding is deferred to the collator (dynamic padding).
    if "roberta" in tokenizer.name_or_path and "xlm" not in tokenizer.name_or_path:
        tokenized_sentences = tokenizer(
            concat_entity,
            list(dataset["sentence"]),
            return_tensors="pt" if is_mlm or is_inference else None,
            padding=True
            if is_mlm or is_inference
            else PaddingStrategy.DO_NOT_PAD.value,
            truncation=True,
            max_length=256,
            add_special_tokens=True,
            return_token_type_ids=False,
        )
    else:
        tokenized_sentences = tokenizer(
            concat_entity,
            list(dataset["sentence"]),
            return_tensors="pt" if is_mlm or is_inference else None,
            padding=True
            if is_mlm or is_inference
            else PaddingStrategy.DO_NOT_PAD.value,
            truncation=True,
            max_length=256,
            add_special_tokens=True,
        )
    return tokenized_sentences
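

# A minimal usage sketch for tokenized_dataset(). The DataFrame contents and
# the "klue/roberta-large" checkpoint name below are assumptions for
# illustration only; they are not taken from this file.
def _demo_tokenized_dataset():
    import pandas as pd
    from transformers import AutoTokenizer

    df = pd.DataFrame(
        {
            "subject_entity": ["이순신"],
            "object_entity": ["조선"],
            "sentence": ["이순신은 조선 중기의 무신이다."],
        }
    )
    tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")  # assumed

    # Training path: nothing is padded here, so a DataCollatorWithPadding can
    # later pad each batch to its own longest sequence (dynamic padding).
    train_features = tokenized_dataset(df, tokenizer)
    print(len(train_features["input_ids"][0]))

    # Inference path: padded PyTorch tensors are returned directly.
    infer_features = tokenized_dataset(df, tokenizer, is_inference=True)
    print(infer_features["input_ids"].shape)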


def tokenized_mlm_dataset(dataset, tokenizer):
    """
    Tokenizes the input dataset for a Masked Language Model.
    """
    if "roberta" in tokenizer.name_or_path and "xlm" not in tokenizer.name_or_path:
        tokenized_sentences = tokenizer(
            list(dataset["sentence"]),
            return_tensors="pt",
            padding=True,  # TODO: an error occurs even when dynamic padding is used (MLM)
            truncation=True,
            max_length=256,
            add_special_tokens=True,
            return_token_type_ids=False,
        )
    else:
        tokenized_sentences = tokenizer(
            list(dataset["sentence"]),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
            add_special_tokens=True,
        )
    return tokenized_sentences
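

# A minimal usage sketch for tokenized_mlm_dataset(), under the same
# assumptions as above (illustrative sentence and checkpoint name).
def _demo_tokenized_mlm_dataset():
    import pandas as pd
    from transformers import AutoTokenizer

    mlm_df = pd.DataFrame({"sentence": ["마스킹 언어 모델 학습용 예시 문장입니다."]})
    tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")  # assumed

    # Padded PyTorch tensors come back; the random masking itself is applied
    # later, e.g. by DataCollatorForLanguageModeling.
    features = tokenized_mlm_dataset(mlm_df, tokenizer)
    print(features["input_ids"].shape)


if __name__ == "__main__":
    _demo_tokenized_dataset()
    _demo_tokenized_mlm_dataset()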