Skip to content

Commit

Permalink
SINGA-386 Implement RNN operation for autograd
Browse files Browse the repository at this point in the history
- redesign some APIs to adapt to autograd
  • Loading branch information
xuewanqi committed Jul 19, 2018
1 parent b176cb4 commit 95b4377
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 70 deletions.
125 changes: 113 additions & 12 deletions python/singa/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,34 +943,135 @@ class _RNN(Operation):
def __init__(self, handle):
self.handle = handle

def forward(self, X, W):
def forward(self, X, h0, c0, W):
# X of shape (seq_len, batch, input_size)
# h0_c0: (h0, c0) if lstm, else (h0,)
# h0, c0 of shape (num_layers * num_directions, batch, hidden_size)
if c0 is None:
assert self.rnn_mode != 'lstm'
c0= CTensor([]) # CTensor([]) and Tensor cx are the same?

if self.handle.device_id == -1:
raise NotImplementedError
else:
if training:
out, self.cache = singa.GpuRNNForwardTraining(
self.handle, X, W)
Y, hout, cout = singa.GpuRNNForwardTraining(
self.handle, X, h0, c0, W)
self.cache=(X, Y, h0, c0, W)
else:
out = singa.GpuRNNForwardInference(self.handle, X, W)
return out
Y, hout, cout = singa.GpuRNNForwardInference(
self.handle, X, h0, c0, W)

# Y of shape (seq_len, batch, hidden_size * num_directions)
# hout_cout: (hout, cout) if lstm, else (hout,)
# hout, cout of shape (num_layers * num_directions, batch,
# hidden_size)
oututs= 1dTo3d(Y)

if self.rnn_mode != 'lstm':
return outputs, hout
else:
return outputs, hout, cout

def backward(self, dY):
def backward(self, dY, dh, dc=CTensor([])):
assert training is True and hasattr(
self, 'cache'), 'Please set training as True before do BP. '

if dY.device().id() != self.handle.device_id:
dY.ToDevice(self.inputs[0].device())
dY_1d= 3dTo1d(dY)

if dY_1d.device().id() != self.handle.device_id:
dY_1d.ToDevice(self.cache[0].device())

if self.handle.device_id == -1:
raise NotImplementedError
else:
dX, dW = singa.GpuRNNBackward(self.handle, dY, self.cache)
return dX, dW
dX_1d, dhout, dcout, dW = singa.GpuRNNBackward(
self.handle, dY_1d, dh, dc, self.cache)

dX = 1dTo3d(dX_1d)

def rnn():
pass
if self.rnn_mode != 'lstm':
return dX, dhout, dW
else:
return dX, dhout, dcout, dW


def rnn(handle, x, h0, c0, W):
return _RNN(handle)(x, h0, c0, W)


class RNN(Layer):

def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False, rnn_mode='tanh'):
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = dropout
self.bidirectional = bidirectional
self.rnn_mode = rnn_mode

if bias is not True or batch_first is not False:
raise NotImplementedError

mult = 1
if self.rnn_mode == 'tanh' or self.rnn_mode == 'relu':
mult *= 1
elif self.rnn_mode == 'lstm':
mult *= 4
elif self.rnn_mode == 'gru':
mult *= 3
else:
raise ValueError

if self.bidirectional:
mult *= 2

for k in range(num_layers):
if k == 1:
w_size = self.hidden_size * \
(self.input_size + self.hidden_size + 2)
else:
w_size = self.hidden_size * \
(self.hidden_size + self.hidden_size + 2)
W_Size *= mult * w_size

self.W_Size = W_Size
self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True)
self.W.uniform(0.0, 1.0)

def __call__(self, inputs, h0, c0=None):
# inputs of shape (seq_len, batch, input_size)
# h0_c0: (h0, c0) if lstm, else (h0,)
# h0, c0 of shape (num_layers * num_directions, batch, hidden_size)

self.device_check(inputs, h0, self.W)

if self.rnn_mode == 'lstm':
assert c0 is not None, 'Please input c0.'
self.device_check(h0, c0)

self.handle = signa.CudnnRNNHandle(inputs.data, *SOME_PARAMETERS*)
self.handle.device_id = inputs.device.id()

X= 3dTo1d(inputs)
outputs = rnn(self.handle, X, h0, c0, self.W)
return outputs

def 3dTo1d(self, inputs):
pass

def 1dTo3d(self, *args):
pass

class LSTM(RNN):

def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
super(LSTM, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectionalrnn_mode='lstm')


class GRU(RNN):

def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
super(GRU, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectionalrnn_mode='gru')
79 changes: 25 additions & 54 deletions src/model/operation/rnn.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -263,24 +263,21 @@ vector<Tensor> SplitOutput(size_t num, size_t dim,
return outputs;
};

std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W) {
DataType dtype = inputs.at(0).data_type();
auto dev = inputs.at(0).device();
std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
DataType dtype = input.data_type();
auto dev = input.at(0).device();

CHECK_GT(inputs.size(), 1u + crh.has_cell_);
size_t num_x = inputs.size() - crh.has_cell_ - 1;
Tensor input = MergeInputs(num_x, inputs);

Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
Tensor output(outshape, dev, dtype);
// LOG(INFO) << "output size " << output.Size();
Tensor hx = inputs.at(num_x);

Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
CHECK_EQ(hx.shape(), state_shape);
Tensor hy(state_shape, dev, dtype);

Tensor cy, cx;
Tensor cy;
if (crh.has_cell_) {
cx = inputs.at(num_x + 1);
cy.ResetLike(hy);
}

Expand Down Expand Up @@ -330,39 +327,23 @@ std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh
},
{inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});

auto outputs =
SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output);
outputs.push_back(hy);
if (crh.has_cell_) outputs.push_back(cy);

std::vector<Tensor> cache;
cache.push_back(input);
cache.push_back(output);
cache.push_back(hx);
cache.push_back(cx);
cache.push_back(W);

return {outputs, cache};
return {output, hy, cy};
};

std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W) {
DataType dtype = inputs.at(0).data_type();
auto dev = inputs.at(0).device();

CHECK_GT(inputs.size(), 1u + crh.has_cell_);
size_t num_x = inputs.size() - crh.has_cell_ - 1;
Tensor input = MergeInputs(num_x, inputs);
std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
DataType dtype = input.data_type();
auto dev = input.device();

Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
Tensor output(outshape, dev, dtype);
// LOG(INFO) << "output size " << output.Size();
Tensor hx = inputs.at(num_x);

Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
CHECK_EQ(hx.shape(), state_shape);
Tensor hy(state_shape, dev, dtype);

Tensor cy, cx;
Tensor cy;
if (crh.has_cell_) {
cx = inputs.at(num_x + 1);
cy.ResetLike(hy);
}

Expand Down Expand Up @@ -405,15 +386,10 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vect
// clang-format on
}, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});

auto outputs =
SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output);
outputs.push_back(hy);
if (crh.has_cell_) outputs.push_back(cy);

return outputs;
return {output, hy, cy};
};

std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache) {
std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache) {
const Tensor x = cache[0];
const Tensor y = cache[1];
const Tensor hx = cache[2];
Expand All @@ -423,24 +399,24 @@ std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, cons
auto dev = y.device();
auto dtype = y.data_type();

CHECK_GT(grads.size(), 1u + crh.has_cell_);
size_t num_dy = grads.size() - crh.has_cell_ - 1;
CHECK_EQ(num_dy, crh.seq_length_);
const Tensor dy = MergeInputs(num_dy, grads);
CHECK_EQ(dy.Size(), y.Size());
const Tensor dhy = grads.at(num_dy);
Tensor dcy;
if (crh.has_cell_)
dcy = grads.at(num_dy + 1);

CHECK_EQ(dY.Size(), y.Size());


Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_};
CHECK_EQ(x.shape(), xshape)
Tensor dx(xshape, dev, dtype);

Tensor dw(W.shape(), dev, dtype);

Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
CHECK_EQ(hx.shape(), state_shape)
Tensor dhx(state_shape, dev, dtype);

Tensor dcx;
if (crh.has_cell_)
dcx.ResetLike(dhx);

dw.SetValue(0.0f);
Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
*dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
Expand Down Expand Up @@ -483,12 +459,7 @@ std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, cons
{yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
{dxb, dwb, dhxb, dcxb, wspace, rspace});

auto data_grads = SplitOutput(num_dy, crh.input_size_, grads, dx);
data_grads.push_back(dhx);
if (crh.has_cell_)
data_grads.push_back(dcx);

return std::make_pair(data_grads, dw);
return {dx, dhx, dcx, dw};
};

#endif // USE_CUDNN
Expand Down
7 changes: 3 additions & 4 deletions src/model/operation/rnn.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,17 @@ class CudnnRNNHandle: public RNNHandle {
Tensor reserve_space_;
Tensor dropout_state_;
};

Tensor MergeInputs(size_t num, const vector<Tensor> &in);

vector<Tensor> SplitOutput(size_t num, size_t dim,
const vector<Tensor> &in,
const Tensor output);

std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W);
std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) ;

std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W);
std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);

std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache);
std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache);

#endif // USE_CUDNN

Expand Down

0 comments on commit 95b4377

Please sign in to comment.