MultiSequence.lua
local dl = require 'dataload._env'

local MultiSequence, parent = torch.class('dl.MultiSequence', 'dl.DataLoader', dl)

-- Used by the Billion Words dataset to encapsulate unordered sentences.
-- The inputs and targets for a sequence of sequences look as follows:
-- target : [ ] E L L O [ ] C R E E N ...
-- input  : [ ] H E L L [ ] S C R E E ...
-- Note that [ ] is a zero mask used to forget between sequences.
function MultiSequence:__init(sequences, batchsize)
   assert(torch.isTensor(sequences[1]))
   assert(torch.type(batchsize) == 'number')
   -- sequences is a list of tensors where the first dimension indexes time
   self.sequences = sequences
   self.batchsize = batchsize

   -- the total length of all sequences is divided among the batch columns
   self.seqlen = 0
   for i, seq in ipairs(self.sequences) do
      self.seqlen = self.seqlen + seq:size(1)
   end
   self.seqlen = torch.ceil(self.seqlen/batchsize)

   self:reset()
end
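
-- A minimal construction sketch (hypothetical data, not part of the original file):
-- each sentence is a 1D LongTensor of word ids, and the loader interleaves them
-- over `batchsize` parallel columns.
--
--   local dl = require 'dataload'
--   local sentences = {
--      torch.LongTensor{1, 2, 3, 4, 5},      -- e.g. "H E L L O"
--      torch.LongTensor{6, 7, 8, 9, 10, 11}, -- e.g. "S C R E E N"
--   }
--   local loader = dl.MultiSequence(sentences, 2) -- batchsize of 2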

function MultiSequence:reset()
   parent.reset(self)
   -- one tracker per batch column : remembers which sequence is being read and where
   self.trackers = {nextseq=1}
end

-- inputs : seqlen x batchsize [x inputsize]
-- targets : seqlen x batchsize [x inputsize]
function MultiSequence:sub(start, stop, inputs, targets)
   local seqlen = stop - start + 1

   inputs = inputs or self.sequences[1].new()
   inputs:resize(seqlen, self.batchsize, unpack(self:isize())):zero()
   targets = targets or inputs.new()
   targets:resize(seqlen, self.batchsize, unpack(self:tsize())):zero()

   for i=1,self.batchsize do
      -- fill column i of the batch
      local input = inputs:select(2,i)
      local target = targets:select(2,i)

      local tracker = self.trackers[i] or {}
      self.trackers[i] = tracker

      local start = 1 -- position within the column (shadows the argument)
      while start <= seqlen do

         if not tracker.seqid then
            tracker.idx = 1
            -- each sequence is separated by a zero input (and a dummy target).
            -- this should make the model forget between sequences
            -- (use with AbstractRecurrent:maskZero() and LookupTableMaskZero)
            if input:dim() == 1 then
               input[start] = 0
               target[start] = 1
            else
               input[start]:fill(0)
               target[start]:fill(0)
            end
            start = start + 1

            -- choose the next sequence to read : random or round-robin
            if self.randseq then
               tracker.seqid = math.random(1,#self.sequences)
            else
               tracker.seqid = self.trackers.nextseq
               self.trackers.nextseq = self.trackers.nextseq + 1
               if self.trackers.nextseq > #self.sequences then
                  self.trackers.nextseq = 1
               end
            end
         end

         if start <= seqlen then
            -- copy as much of the current sequence as fits into the column;
            -- the target is the input shifted forward by one time-step
            local sequence = self.sequences[tracker.seqid]
            local stop = math.min(tracker.idx+seqlen-start, sequence:size(1) - 1)
            local size = stop - tracker.idx + 1

            input:narrow(1,start,size):copy(sequence:sub(tracker.idx, stop))
            target:narrow(1,start,size):copy(sequence:sub(tracker.idx+1, stop+1))

            start = start + size
            tracker.idx = stop+1

            if stop == sequence:size(1) - 1 then
               -- the current sequence is exhausted; a new one starts next iteration
               tracker.seqid = nil
            end
         end
      end

      assert(start-1 == seqlen)
   end

   return inputs, targets
end
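
-- Sketch of a direct call (assuming the hypothetical `loader` above and scalar word ids):
--
--   local inputs, targets = loader:sub(1, 10)
--   -- inputs, targets : 10 x 2 tensors; within each column, a zero entry in
--   -- `inputs` marks the boundary where one sentence ends and the next begins.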

function MultiSequence:sample()
   error"Not Implemented"
end

-- returns the number of time-steps per batch column
function MultiSequence:size()
   return self.seqlen
end

function MultiSequence:isize(excludedim)
   -- by default, the sequence dimension is excluded
   excludedim = excludedim == nil and 1 or excludedim
   local size = torchx.recursiveSize(self.sequences[1], excludedim)
   if excludedim ~= 1 then
      size[1] = self:size()
   end
   return size
end

function MultiSequence:tsize(excludedim)
   return self:isize(excludedim)
end
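
-- For example (assuming 1D word-id sequences), isize() presumably returns an empty
-- table, so the inputs and targets built by :sub() are plain seqlen x batchsize tensors.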

function MultiSequence:subiter(seqlen, epochsize, ...)
   return parent.subiter(self, seqlen, epochsize, ...)
end
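
-- Example epoch loop (a sketch only; `model`, `criterion`, `seqlen` and `epochsize`
-- are assumed to be defined elsewhere, e.g. an rnn model wrapped with maskZero as
-- suggested in the comments above):
--
--   for k, inputs, targets in loader:subiter(seqlen, epochsize) do
--      local outputs = model:forward(inputs)
--      local err = criterion:forward(outputs, targets)
--      -- backward pass and parameter update would follow here
--   end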

function MultiSequence:sampleiter(seqlen, epochsize, ...)
   error"Not Implemented. Use subiter instead."
end