-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_collate_fn.py
83 lines (57 loc) · 2.65 KB
/
gen_collate_fn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
def gen_collate_fn(data):
    """
    Create mini-batch tensors from a list of (src_sent, tgt_sent, src_seq, tgt_seq).

    A custom collate_fn is required because the default collate_fn cannot
    merge variable-length sequences; here sequences are zero-padded to the
    maximum length within the mini-batch (dynamic padding).

    Args:
        data: list of tuples (src_sents, tgt_sents, src_seqs, tgt_seqs)
            - src_sents, tgt_sents: original tokenized sentences
            - src_seqs, tgt_seqs: the corresponding token-id sequences

    Returns:
        - src_sents, tgt_sents (tuple): batch of original tokenized sentences
        - src_seqs, tgt_seqs (LongTensor): (batch_size, max_len), zero-padded
        - src_lens, tgt_lens (list of int): true sequence lengths, length batch_size
    """
    def _pad_sequences(seqs):
        # Zero-pad every id sequence up to the longest one in the batch.
        lens = [len(seq) for seq in seqs]
        padded_seqs = torch.zeros(len(seqs), max(lens)).long()
        for i, seq in enumerate(seqs):
            end = lens[i]
            padded_seqs[i, :end] = torch.LongTensor(seq[:end])
        return padded_seqs, lens

    # Sort by *source* sequence length (descending) so the encoder can feed the
    # batch to `pack_padded_sequence`. The *target* side need not be sorted:
    # only the source goes through `pack_padded_sequence` (in the EncoderRNN).
    # BUG FIX: the key previously used x[1] (the target sentence), which
    # contradicts the intent above and breaks pack_padded_sequence's
    # sorted-lengths requirement; sort on the source id sequence x[2], the
    # same field src_lens is computed from.
    data.sort(key=lambda x: len(x[2]), reverse=True)

    # Separate source and target fields.
    src_sents, tgt_sents, src_seqs, tgt_seqs = zip(*data)

    # Merge sequences (tuple of 1D lists -> padded 2D tensor).
    src_seqs, src_lens = _pad_sequences(src_seqs)  # (batch, max_src_len), list of int
    tgt_seqs, tgt_lens = _pad_sequences(tgt_seqs)  # (batch, max_tgt_len), list of int

    # NOTE: tensors are returned batch-first; the (seq_len, batch) transpose
    # is deliberately not applied.
    return src_sents, tgt_sents, src_seqs, tgt_seqs, src_lens, tgt_lens
"""
In[0]:
src_seqs = [np.random.randint(0,100,(len)).tolist() for len in [20, 12, 16]]
src_seqs
Out[0]:
[[79, 22, 10, 38, 67, 73, 33, 60, 78, 94, 35, 49, 30, 33, 85, 71, 72, 75, 19, 46],
[22, 57, 97, 19, 95, 30, 67, 3, 47, 21, 39, 25],
[97, 28, 47, 49, 55, 73, 94, 69, 35, 51, 10, 27, 12, 85, 42, 69]]
In[1]:
src_seqs, src_lens = _pad_sequences(src_seqs)
In[2]:
src_seqs
Out[2]:
tensor([[79, 22, 10, 38, 67, 73, 33, 60, 78, 94, 35, 49, 30, 33, 85, 71, 72, 75,19, 46],
[22, 57, 97, 19, 95, 30, 67, 3, 47, 21, 39, 25, 0, 0, 0, 0, 0, 0, 0, 0],
[97, 28, 47, 49, 55, 73, 94, 69, 35, 51, 10, 27, 12, 85, 42, 69, 0, 0, 0, 0]])
In[3]:
src_lens
Out[3]:
[20, 12, 16]
"""