fra_eng_dataset.py
from torch.utils.data import Dataset
import pickle
from nltk.tokenize import word_tokenize
import os
import torch
import numpy as np


class FraEngDataset(Dataset):
    def __init__(self, data_source_path='fra-eng/fra.txt'):
        super().__init__()
        data_file_path = "sentences.pkl"
        self.sentence_list = []
        # Vocabulary: token -> index. Indices 0-2 are reserved for the special
        # tokens; *_token_count tracks the highest index assigned so far.
        self.eng_token_dict = dict()
        self.eng_token_dict['<PAD>'] = 0
        self.eng_token_dict['<EOS>'] = 1
        self.eng_token_dict['<START>'] = 2
        self.eng_token_count = 2
        self.eng_token_to_text = ['<PAD>', '<EOS>', '<START>']
        self.fra_token_dict = dict()
        self.fra_token_dict['<PAD>'] = 0
        self.fra_token_dict['<EOS>'] = 1
        self.fra_token_dict['<START>'] = 2
        self.fra_token_count = 2
        self.fra_token_to_text = ['<PAD>', '<EOS>', '<START>']
        if os.path.exists(data_file_path):
            # Reuse the cached, already-tokenized sentence pairs.
            with open(data_file_path, 'rb') as f:
                pickle_data = pickle.load(f)
                self.sentence_list = pickle_data['sentence_list']
                self.eng_token_count = pickle_data['eng_token_count']
                self.eng_token_to_text = pickle_data['eng_token_to_text']
                self.fra_token_count = pickle_data['fra_token_count']
                self.fra_token_to_text = pickle_data['fra_token_to_text']
        else:
            # Build the vocabularies and tokenized sentence pairs from the
            # tab-separated English/French corpus file.
            with open(data_source_path, "r", encoding='utf-8') as f:
                for idx, line in enumerate(f.readlines()):
                    eng_token_sentence = []
                    fra_token_sentence = []
                    snt = line.split('\t')
                    eng_sentence = snt[0]
                    fra_sentence = snt[1]
                    eng_token_list = word_tokenize(eng_sentence)
                    for token in eng_token_list:
                        if token not in self.eng_token_dict:
                            self.eng_token_count += 1
                            self.eng_token_dict[token] = self.eng_token_count
                            self.eng_token_to_text.append(token)
                        token_idx = self.eng_token_dict[token]
                        eng_token_sentence.append(token_idx)
                    eng_token_sentence = [self.eng_token_dict['<START>']] + eng_token_sentence
                    eng_token_sentence.append(self.eng_token_dict['<EOS>'])
                    fra_token_list = word_tokenize(fra_sentence)
                    for token in fra_token_list:
                        if token not in self.fra_token_dict:
                            self.fra_token_count += 1
                            self.fra_token_dict[token] = self.fra_token_count
                            self.fra_token_to_text.append(token)
                        token_idx = self.fra_token_dict[token]
                        fra_token_sentence.append(token_idx)
                    # Wrap the French sentence with its own <START>/<EOS> markers
                    # (the original line mistakenly reused the English sentence here).
                    fra_token_sentence = [self.fra_token_dict['<START>']] + fra_token_sentence
                    fra_token_sentence.append(self.fra_token_dict['<EOS>'])
                    self.sentence_list.append(
                        dict(
                            eng=eng_token_sentence,
                            fra=fra_token_sentence
                        ))
            # Cache the processed data so later runs can skip tokenization.
            with open(data_file_path, "wb") as f:
                pickle_data = dict(
                    sentence_list=self.sentence_list,
                    eng_token_count=self.eng_token_count,
                    fra_token_count=self.fra_token_count,
                    eng_token_to_text=self.eng_token_to_text,
                    fra_token_to_text=self.fra_token_to_text
                )
                pickle.dump(pickle_data, f)
        print(len(self.sentence_list))
    def get_eng_dict_size(self):
        return self.eng_token_count + 1

    def get_fra_dict_size(self):
        return self.fra_token_count + 1

    def get_fra_eos_code(self):
        return self.fra_token_dict['<EOS>']

    def get_eng_eos_code(self):
        return self.eng_token_dict['<EOS>']

    def get_fra_start_code(self):
        return self.fra_token_dict['<START>']

    def get_eng_start_code(self):
        return self.eng_token_dict['<START>']

    def get_eng_pad_code(self):
        return self.eng_token_dict['<PAD>']

    def __len__(self):
        return len(self.sentence_list)

    def __getitem__(self, item):
        ret = dict()
        for key in self.sentence_list[item]:
            ret[key] = torch.tensor(self.sentence_list[item][key])
        return ret


def fra_eng_dataset_collate(data):
    # Batch collate: truncate each sentence to MAXMAX_SENTENCE_LEN tokens,
    # reshape it to (seq_len, 1), and sort the batch by descending English
    # length while keeping each French sentence aligned with its English pair.
    MAXMAX_SENTENCE_LEN = 40
    eng_sentences = []
    eng_sentence_lens = []
    fra_sentences = []
    fra_sentence_lens = []
    eng_sentences_sorted = []
    eng_sentence_lens_sorted = []
    fra_sentences_sorted = []
    fra_sentence_lens_sorted = []
    for s in data:
        sent = s['eng']
        if len(sent) > MAXMAX_SENTENCE_LEN:
            sent = sent[0:MAXMAX_SENTENCE_LEN]
        eng_sentences.append(sent.unsqueeze(dim=1))
        eng_sentence_lens.append(len(sent))
        sent = s['fra']
        if len(sent) > MAXMAX_SENTENCE_LEN:
            sent = sent[0:MAXMAX_SENTENCE_LEN]
        fra_sentences.append(sent.unsqueeze(dim=1))
        fra_sentence_lens.append(len(sent))
    # Rearrange everything by descending English sentence length.
    sort_idxes = np.argsort(np.array(eng_sentence_lens))[::-1]
    for idx in sort_idxes:
        eng_sentences_sorted.append(eng_sentences[idx])
        eng_sentence_lens_sorted.append(eng_sentence_lens[idx])
        fra_sentences_sorted.append(fra_sentences[idx])
        fra_sentence_lens_sorted.append(fra_sentence_lens[idx])
    return dict(
        eng_sentences=eng_sentences_sorted,
        eng_lens=eng_sentence_lens_sorted,
        fra_sentences=fra_sentences_sorted,
        fra_lens=fra_sentence_lens_sorted
    )
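

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal example of wiring FraEngDataset and fra_eng_dataset_collate into a
# torch DataLoader. The batch_size and shuffle values are arbitrary choices;
# it assumes the English-French pairs file exists at 'fra-eng/fra.txt' and that
# NLTK's 'punkt' tokenizer data has been downloaded (nltk.download('punkt')).
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = FraEngDataset()
    loader = DataLoader(dataset,
                        batch_size=8,
                        shuffle=True,
                        collate_fn=fra_eng_dataset_collate)
    batch = next(iter(loader))
    # Each batch entry is a list of (seq_len, 1) index tensors sorted by
    # descending English length, plus the matching length lists.
    print(len(batch['eng_sentences']), batch['eng_lens'])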