# -*- coding: utf-8 -*-
from __future__ import print_function

import os
from collections import Counter

import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm

from params import Params as pm

def build_vocab(path, fname):
    """
    Builds a vocabulary file (word/count TSV) from a corpus.
    Args:
    :param path: [String], input corpus file path
    :param fname: [String], output vocabulary file name
    """
    with open(path, 'r', encoding='utf-8') as f:
        words = f.read().split()
    word_count = Counter(words)
    if not os.path.exists(pm.vocab_path):
        os.makedirs(pm.vocab_path)
    with open(pm.vocab_path + fname, 'w', encoding='utf-8') as f:
        # Reserve the special tokens with an inflated count so they survive frequency filtering.
        f.write("{}\t1000000000\n{}\t1000000000\n{}\t1000000000\n{}\t1000000000\n".format(
            "<PAD>", "<UNK>", "<SOS>", "<EOS>"))
        for word, count in word_count.most_common(len(word_count)):
            f.write(u"{}\t{}\n".format(word, count))

def load_vocab(vocab):
    """
    Loads a vocabulary file and builds word/index lookup dictionaries.
    Args:
    :param vocab: [String], vocabulary file name
    :return: word-to-index and index-to-word dictionaries
    """
    vocab = [line.split()[0] for line in open(
        '{}{}'.format(pm.vocab_path, vocab), 'r', encoding='utf-8').read().splitlines()
        if int(line.split()[1]) >= pm.word_limit_size]
    word2idx_dic = {word: idx for idx, word in enumerate(vocab)}
    idx2word_dic = {idx: word for idx, word in enumerate(vocab)}
    return word2idx_dic, idx2word_dic

# Build the vocabulary files on first run (or when a rebuild is requested), then load the lookup tables.
if not os.path.exists(pm.vocab_path) or pm.rebuild_vocabulary:
    build_vocab(pm.src_train, "en.vocab.tsv")
    build_vocab(pm.tgt_train, "de.vocab.tsv")
en2idx, idx2en = load_vocab("en.vocab.tsv")
de2idx, idx2de = load_vocab("de.vocab.tsv")
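
# Sanity-check sketch (an assumption: pm.word_limit_size does not filter out the special tokens,
# which build_vocab writes first with an inflated count):
#
#     assert en2idx["<PAD>"] == 0 and en2idx["<UNK>"] == 1
#     assert en2idx["<SOS>"] == 2 and en2idx["<EOS>"] == 3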

def tokenize_sequences(source_sent, target_sent):
    """
    Converts a source/target sentence pair into padded token id sequences.
    Wraps each sentence with <SOS>/<EOS>, maps words to ids (unknown words map to <UNK>),
    and pads with <PAD> up to pm.maxlen.
    Args:
    :param source_sent: [Tensor], an encoding sentence from the src-train file
    :param target_sent: [Tensor], a decoding sentence from the tgt-train file
    :return: source and target token id sequences
    """
    source_sent = source_sent.numpy().decode('utf-8')
    target_sent = target_sent.numpy().decode('utf-8')
    inpt = [en2idx.get(word, 1) for word in (u"<SOS> " + source_sent + u" <EOS>").split()]
    outpt = [de2idx.get(word, 1) for word in (u"<SOS> " + target_sent + u" <EOS>").split()]
    if len(inpt) < pm.maxlen:
        inpt += [0 for _ in range(pm.maxlen - len(inpt))]
    if len(outpt) < pm.maxlen:
        outpt += [0 for _ in range(pm.maxlen - len(outpt))]
    return inpt, outpt

def jit_tokenize_sequences(source_sent, target_sent):
    # Wrap the eager tokenizer so it can run inside a tf.data pipeline.
    return tf.py_function(tokenize_sequences, [source_sent, target_sent], [tf.int64, tf.int64])

def filter_single_word(source_sent, target_sent):
    # Keep only pairs whose token sequences fit within pm.maxlen.
    return tf.logical_and(tf.size(source_sent) <= pm.maxlen, tf.size(target_sent) <= pm.maxlen)

def _byte_features(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def dump2record(filename, corpus1, corpus2):
    """
    Writes a parallel corpus into a TFRecord file.
    Args:
    :param filename: [String], output TFRecord path
    :param corpus1: [List], source sentences
    :param corpus2: [List], target sentences
    """
    assert len(corpus1) == len(corpus2)
    writer = tf.io.TFRecordWriter(filename)
    for sent1, sent2 in tqdm(zip(corpus1, corpus2)):
        features = {
            'src_sent': _byte_features(sent1.encode('utf-8')),
            'tgt_sent': _byte_features(sent2.encode('utf-8'))
        }
        tf_features = tf.train.Features(feature=features)
        tf_examples = tf.train.Example(features=tf_features)
        tf_serialized = tf_examples.SerializeToString()
        writer.write(tf_serialized)
    writer.close()
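
# Usage sketch (pm.src_train, pm.tgt_train and pm.train_record come from params.py; the
# line-per-sentence reading below is an assumption about the corpus format):
#
#     src_sentences = open(pm.src_train, 'r', encoding='utf-8').read().splitlines()
#     tgt_sentences = open(pm.tgt_train, 'r', encoding='utf-8').read().splitlines()
#     dump2record(pm.train_record, src_sentences, tgt_sentences)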

def build_dataset(mode, batch_size, cache_name, filename=None, corpus=None, is_training=True):
    """
    Builds a tf.data pipeline of tokenized, padded sentence pairs.
    Args:
    :param mode: [String], the load mode, either 'array' (load from in-memory arrays) or 'file' (load from a TFRecord file)
    :param batch_size: [Int], batch size used when is_training is True
    :param cache_name: [String], file name used to cache the mapped dataset
    :param filename: [String], if mode == 'file', path of the TFRecord file
    :param corpus: [Tuple], if mode == 'array', a (source, target) pair of sentence lists
    :param is_training: [Boolean], whether to shuffle, batch and prefetch for training
    :return: dataset
    """
    dataset_root = "/".join(pm.train_record.split('/')[:-1])
    if mode == 'array':
        assert corpus is not None

        def _parse(example):
            return example[0], example[1]

        src, tgt = corpus
        real_data = [(inp.encode('utf-8'), tar.encode('utf-8')) for inp, tar in zip(src, tgt)]
        dataset = tf.data.Dataset.from_tensor_slices(real_data)
        dataset = dataset.map(_parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.map(jit_tokenize_sequences, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.filter(filter_single_word).cache(
            filename='{}/{}'.format(dataset_root, cache_name)).shuffle(pm.buffer_size) if is_training else dataset
        dataset = dataset.padded_batch(batch_size, padded_shapes=([-1], [-1])) if is_training else \
            dataset.padded_batch(1, padded_shapes=([-1], [-1]))
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) if is_training else dataset
        return dataset
    elif mode == 'file':
        def _parse(example):
            dics = {
                'src_sent': tf.io.FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
                'tgt_sent': tf.io.FixedLenFeature(shape=(), dtype=tf.string, default_value=None)
            }
            parsed_data = tf.io.parse_single_example(example, dics)
            return parsed_data['src_sent'], parsed_data['tgt_sent']

        assert filename is not None
        dataset = tf.data.TFRecordDataset(filename)
        dataset = dataset.map(_parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.map(jit_tokenize_sequences, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.filter(filter_single_word).cache(
            filename='{}/{}'.format(dataset_root, cache_name)).shuffle(pm.buffer_size) if is_training else dataset
        dataset = dataset.padded_batch(batch_size, padded_shapes=([-1], [-1])) if is_training else \
            dataset.padded_batch(1, padded_shapes=([-1], [-1]))
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) if is_training else dataset
        return dataset
    else:
        raise ValueError("Unknown dataset mode '{}', expected 'array' or 'file'.".format(mode))
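
# Usage sketch (pm.batch_size and the cache name 'train.cache' are assumptions, adjust to your params.py):
#
#     train_dataset = build_dataset(mode='file',
#                                   batch_size=pm.batch_size,
#                                   cache_name='train.cache',
#                                   filename=pm.train_record,
#                                   is_training=True)
#     for src_batch, tgt_batch in train_dataset.take(1):
#         print(src_batch.shape, tgt_batch.shape)  # both padded to pm.maxlen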

class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Transformer learning-rate schedule: linear warmup followed by inverse-sqrt decay."""

    def __init__(self, d_model, warmup_steps=4000):
        super(LRSchedule, self).__init__()
        # d_model must be a tensor, otherwise a "Could not find valid device for node." error is raised.
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.minimum(arg1, arg2)
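
# Usage sketch (pm.d_model is an assumed entry in params.py; the Adam hyper-parameters below
# follow the original Transformer paper and are not taken from this repository):
#
#     learning_rate = LRSchedule(d_model=pm.d_model, warmup_steps=4000)
#     optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)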

class polynomialLR(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Cyclic polynomial decay from a start learning rate to an end learning rate."""

    def __init__(self, sl, el, decay_steps, power):
        super(polynomialLR, self).__init__()
        # Values must be tensors, otherwise a "Could not find valid device for node." error is raised.
        self.sl = sl  # start learning rate
        self.el = el  # end learning rate
        self.decay_steps = decay_steps
        self.power = power

    def __call__(self, step):
        # Round the decay horizon up to the next multiple of decay_steps so the decay restarts each cycle.
        arg1 = self.decay_steps * tf.math.ceil(step / self.decay_steps)
        return (self.sl - self.el) * (1 - step / arg1) ** self.power + self.el

def masking(sequence, task='padding'):
    """
    Masking operation.
    Args:
    :param sequence: [Tensor], token ids to be masked, shape = [batch_size, seq_len]
    :param task: [String], 'padding' or 'look_ahead', 'padding' by default
    :return: [Tensor], mask matrix
    """
    if task == 'padding':
        # 1.0 marks <PAD> positions; broadcastable against attention logits of shape
        # [batch_size, num_heads, seq_len, seq_len].
        return tf.cast(tf.math.equal(sequence, 0), tf.float32)[:, tf.newaxis, tf.newaxis, :]
    elif task == 'look_ahead':
        # Upper-triangular mask that hides future positions from the decoder.
        size = tf.shape(sequence)[1]
        return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    else:
        raise ValueError('Unknown masking task "{}", expected "padding" or "look_ahead".'.format(task))

def create_masks(inp, tar):
    # Encoder padding mask, decoder-to-encoder padding mask, and the combined
    # look-ahead/padding mask used for decoder self-attention.
    enc_padding_mask = masking(inp, task='padding')
    dec_padding_mask = masking(inp, task='padding')
    look_ahead_mask = masking(tar, task='look_ahead')
    dec_tar_padding_mask = masking(tar, task='padding')
    combined_mask = tf.maximum(dec_tar_padding_mask, look_ahead_mask)
    return enc_padding_mask, combined_mask, dec_padding_mask
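
# Shape sketch (the tensors below are toy examples, not from the training pipeline):
#
#     inp = tf.constant([[5, 7, 0, 0]])   # one source sentence, padded with <PAD>=0
#     tar = tf.constant([[2, 9, 4, 0]])   # one target sentence
#     enc_mask, combined_mask, dec_mask = create_masks(inp, tar)
#     # enc_mask and dec_mask have shape [1, 1, 1, 4]; combined_mask broadcasts to [1, 1, 4, 4]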

def plot_attention_weights(attention, sentence, result, layer):
    """Plots the attention map of each head in the given decoder layer."""
    fig = plt.figure(figsize=(16, 8))
    sentence = [en2idx.get(word, 1) for word in sentence.split()]
    attention = tf.squeeze(attention[layer], axis=0)
    for head in range(attention.shape[0]):
        ax = fig.add_subplot(2, 4, head + 1)
        ax.matshow(attention[head][:-1, :], cmap='viridis')
        fontdict = {'fontsize': 10}
        ax.set_xticks(range(len(sentence) + 2))
        ax.set_yticks(range(len(result)))
        ax.set_ylim(len(result) - 1.5, -0.5)
        ax.set_xticklabels(['<SOS>'] + [idx2en.get(i, '<UNK>') for i in sentence] + ['<EOS>'],
                           fontdict=fontdict, rotation=90)
        ax.set_yticklabels([idx2de.get(i, '<UNK>') for i in result.numpy() if i < len(idx2de) and i not in [0, 2, 3]],
                           fontdict=fontdict)
        ax.set_xlabel('Head {}'.format(head + 1))
    plt.tight_layout()
    plt.show()