SINGA-386 Implement RNN operation for autograd
- Redesign some RNN-related functions and their APIs.

- The RNN operation is now designed for mini-batch training.

- The related files build without errors.
xuewanqi committed Jul 20, 2018
1 parent 95b4377 commit c48c6d6
Showing 4 changed files with 123 additions and 99 deletions.
24 changes: 15 additions & 9 deletions python/singa/autograd.py
@@ -966,8 +966,8 @@ def forward(self, X, h0, c0, W):
# hout_cout: (hout, cout) if lstm, else (hout,)
# hout, cout of shape (num_layers * num_directions, batch,
# hidden_size)
oututs= 1dTo3d(Y)
oututs= _1dTo3d(Y)

if self.rnn_mode != 'lstm':
return outputs, hout
else:
@@ -977,7 +977,7 @@ def backward(self, dY, dh, dc=CTensor([])):
assert training is True and hasattr(
self, 'cache'), 'Please set training as True before do BP. '

dY_1d= 3dTo1d(dY)
dY_1d= _3dTo1d(dY)

if dY_1d.device().id() != self.handle.device_id:
dY_1d.ToDevice(self.cache[0].device())
@@ -988,7 +988,7 @@ def backward(self, dY, dh, dc=CTensor([])):
dX_1d, dhout, dcout, dW = singa.GpuRNNBackward(
self.handle, dY_1d, dh, dc, self.cache)

dX = 1dTo3d(dX_1d)
dX = _1dTo3d(dX_1d)

if self.rnn_mode != 'lstm':
return dX, dhout, dW
@@ -1038,7 +1038,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first
W_Size *= mult * w_size

self.W_Size = W_Size
self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True)
self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True) # TODO: assign value of Wi separately
self.W.uniform(0.0, 1.0)

def __call__(self, inputs, h0, c0=None):
@@ -1052,17 +1052,23 @@ def __call__(self, inputs, h0, c0=None):
assert c0 is not None, 'Please input c0.'
self.device_check(h0, c0)

self.handle = signa.CudnnRNNHandle(inputs.data, *SOME_PARAMETERS*)
if not hasattr(self, 'handle'):
self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)
elif inputs.shape[0] != self.handle.seq_length_ or inputs.shape[1] != self.handle.batch_size_:
self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)

self.handle.device_id = inputs.device.id()

X= 3dTo1d(inputs)
X= _3dTo1d(inputs)
outputs = rnn(self.handle, X, h0, c0, self.W)
return outputs

def 3dTo1d(self, inputs):
def _3dTo1d(self, inputs):
pass

def 1dTo3d(self, *args):
def _1dTo3d(self, *args):
pass

class LSTM(RNN):
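For context, a minimal usage sketch of the redesigned module-level API (assuming a CUDA device and the SINGA helpers device.create_cuda_gpu and tensor.Tensor; note that _3dTo1d/_1dTo3d are still stubs in this commit, so the sketch is illustrative only):

from singa import autograd, tensor, device

dev = device.create_cuda_gpu()               # the cuDNN path requires a GPU device
x = tensor.Tensor((4, 2, 3), dev)            # (seq_length, batch_size, input_size)
x.gaussian(0.0, 1.0)
h0 = tensor.Tensor((1, 2, 8), dev)           # (num_layers * num_directions, batch_size, hidden_size)
h0.set_value(0.0)

rnn = autograd.RNN(input_size=3, hidden_size=8)   # LSTM(...) would additionally need c0
outputs, hout = rnn(x, h0)                   # the handle is rebuilt whenever seq_length or batch_size changes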
28 changes: 27 additions & 1 deletion src/api/model_operation.i
@@ -7,7 +7,7 @@
#include "../src/model/operation/convolution.h"
#include "../src/model/operation/batchnorm.h"
#include "../src/model/operation/pooling.h"

#include "../src/model/operation/rnn.h"
%}

namespace singa {
@@ -51,6 +51,14 @@ class PoolingHandle {
int pooled_width;
};

class RNNHandle {
public:
RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size);

size_t batch_size_;
size_t seq_length_;
};

#if USE_CUDNN
class CudnnConvHandle: public ConvHandle {
@@ -106,6 +114,24 @@ Tensor GpuPoolingForward(const CudnnPoolingHandle &cph, const Tensor &x);

Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy, const Tensor& x, const Tensor& y);


class CudnnRNNHandle: public RNNHandle {
public:
CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size);

size_t batch_size_;
size_t seq_length_;

};

std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) ;

std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);

std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector<Tensor> &cache);


#endif // USE_CUDNN

} //namespace singa
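The declarations above map onto the calls made from autograd.py. A rough sketch of driving them directly through the generated wrapper (the module name and the cache layout [x, y, hx, cx, W] are assumptions inferred from rnn.cc and autograd.py; the tensors are CTensor placeholders):

from singa import singa_wrap as singa

handle = singa.CudnnRNNHandle(x_3d, input_size, hidden_size, num_stacks,
                              "lstm", 0.0, False, weight_size)    # x_3d: (seq, batch, feature)
y, hy, cy = singa.GpuRNNForwardTraining(handle, x_1d, hx, cx, W)  # inference variant skips the reserve space
# ... compute dY, dhy, dcy from the loss ...
dx, dhx, dcx, dW = singa.GpuRNNBackward(handle, dY, dhy, dcy, [x_1d, y, hx, cx, W])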
147 changes: 69 additions & 78 deletions src/model/operation/rnn.cc
@@ -2,8 +2,12 @@

namespace singa {

RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
const std::string Rnn_mode, const float Dropout, const bool bidirectional) {
RNNHandle::RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size) {

CHECK_EQ(input.shape(2), Input_size);
batch_size_ = input.shape(1);
seq_length_= input.shape(0);

input_size_ = Input_size;
CHECK_GT(input_size_, 0u);
@@ -28,68 +32,62 @@ RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const si
}
// the first constant (4) is the size of float
// the second constant (2, 8, 6) is the number of sets of params
int mult = 1;
if (rnn_mode_ == "relu" || rnn_mode_ == "tanh")
mult *= 1;
else if (rnn_mode_ == "lstm")
mult *= 4;
else if (rnn_mode_ == "gru")
mult *= 3;
if (bidirectional)
mult *= 2;

weight_size = 0;
for (size_t i = 0; i < num_stacks_; i++) {
size_t dim = hidden_size_ * (input_size_ + hidden_size_ + 2);
if (i > 0)
dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2);
weight_size += mult * dim;
}
weight_size= Weight_size;

};

#ifdef USE_CUDNN

CudnnRNNHandle::CudnnRNNHandle(const vector<Tensor> &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
const std::string Rnn_mode, const float Dropout, const bool bidirectional):
RNNHandle(Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional) {
CudnnRNNHandle::CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size):
RNNHandle(input, Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional, Weight_size) {

CHECK_GT(inputs.size(), 1u + has_cell_);
size_t num_x = inputs.size() - has_cell_ - 1;

DataType dtype = inputs.at(0).data_type();
if (rnn_desc_ != nullptr)
CHECK_EQ(dtype_, GetCudnnDataType(dtype))
<< "Cannot change cudnn data type during training from " << dtype_
<< " to " << GetCudnnDataType(dtype);
else
dtype_ = GetCudnnDataType(dtype);
DataType dtype = input.data_type();
dtype_ = GetCudnnDataType(dtype);

UpdateStates(num_x, inputs);
UpdateIODescriptors(input);
ResetHiddenAndCellDescriptors();
SetRNNDescriptor(input.device());
UpdateSpaces(seq_length_, input.device());
};

void CudnnRNNHandle::UpdateStates(size_t num_x, const vector<Tensor> &inputs) {
UpdateIODescriptors(num_x, inputs);
size_t new_batch_size = inputs.at(0).shape(0);
if (batch_size_ != new_batch_size)
ResetHiddenAndCellDescriptors(new_batch_size);
if (rnn_desc_ == nullptr)
SetRNNDescriptor(inputs.at(0).device());
UpdateSpaces(num_x, inputs.at(0).device());
batch_size_ = new_batch_size;
seq_length_ = num_x;
CudnnRNNHandle::~CudnnRNNHandle() {
if (weight_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_));
if (dropout_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
if (rnn_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_));
if (hx_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_));
if (hy_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_));
if (cx_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_));
if (cy_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_));
if (dhx_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhx_desc_));
if (dhy_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhy_desc_));
if (dcx_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcx_desc_));
if (dcy_desc_ != nullptr)
CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcy_desc_));
DestroyIODescriptors();
};

void CudnnRNNHandle::DestroyIODescriptors() {
if (x_descs_ != nullptr) {
for (size_t i = 0; i < max_length_; i++) {
for (size_t i = 0; i < seq_length_; i++) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i]));
}
delete [] x_descs_;
delete [] dx_descs_;
}
if (y_descs_ != nullptr) {
for (size_t i = 0; i < max_length_; i++) {
for (size_t i = 0; i < seq_length_; i++) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i]));
}
@@ -98,61 +96,60 @@ void CudnnRNNHandle::DestroyIODescriptors() {
}
};

void CudnnRNNHandle::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
bool reset = false;
if (max_length_ < len) {
DestroyIODescriptors();
max_length_ = len;
x_descs_ = new cudnnTensorDescriptor_t[len];
dx_descs_ = new cudnnTensorDescriptor_t[len];
y_descs_ = new cudnnTensorDescriptor_t[len];
dy_descs_ = new cudnnTensorDescriptor_t[len];
for (size_t i = 0; i < len; i++) {

void CudnnRNNHandle::UpdateIODescriptors(const Tensor &input) {
x_descs_ = new cudnnTensorDescriptor_t[seq_length_];
dx_descs_ = new cudnnTensorDescriptor_t[seq_length_];
y_descs_ = new cudnnTensorDescriptor_t[seq_length_];
dy_descs_ = new cudnnTensorDescriptor_t[seq_length_];
for (size_t i = 0; i < seq_length_; i++) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i]));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i]));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i]));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i]));
}
reset = true;
}

for (size_t i = 0; i < len; i++) {
CHECK_EQ(inputs[i].shape(1), input_size_);
if (inputs[i].shape(0) != batch_size_ || reset) {
for (size_t i = 0; i < seq_length_; i++) {
CHECK_EQ(input.shape(2), input_size_);
int d[3] = {1, 1, 1}, s[3] = {1, 1, 1};
d[0] = static_cast<int>(inputs[i].shape(0));
d[0] = static_cast<int>(batch_size_);
CHECK_GT(d[0], 0);
d[1] = static_cast<int>(inputs[i].shape(1));
d[1] = static_cast<int>(input_size_);
s[0] = d[1] * d[2];
s[1] = d[2];
CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s));
CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s));

d[0] = static_cast<int>(inputs[i].shape(0));
d[0] = static_cast<int>(batch_size_);
d[1] = static_cast<int>(hidden_size_ * num_directions_);
s[0] = d[1] * d[2];
s[1] = d[2];
CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s));
CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s));
}
}
};

void CudnnRNNHandle::ResetHiddenAndCellDescriptors(size_t batch_size) {
if (batch_size_ == 0) {
void CudnnRNNHandle::ResetHiddenAndCellDescriptors() {
if (cx_desc_ == nullptr)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
if (dcx_desc_ == nullptr)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_));
if (cy_desc_ == nullptr)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_));
if (dcy_desc_ == nullptr)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_));
if (hx_desc_ == nullptr)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_));
if (dhx_desc_ == nullptr)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_));
if (hy_desc_ == nullptr)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_));
if (dhy_desc_ == nullptr)
CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_));
}

int dim[3] = {1, 1, 1};
dim[0] = static_cast<int>(num_stacks_ * num_directions_);
dim[1] = static_cast<int>(batch_size);
dim[1] = static_cast<int>(batch_size_);
dim[2] = static_cast<int>(hidden_size_);
int stride[3] = {1, 1, 1};
stride[0] = dim[1] * dim[2];
@@ -229,7 +226,7 @@ void CudnnRNNHandle::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
reserve_space_ = Tensor(Shape{count}, dev, kChar);
// reserve_space_.SetValue(0);
}
}
};

Tensor MergeInputs(size_t num, const vector<Tensor> &in) {
if (num == 1)
@@ -265,15 +262,14 @@ vector<Tensor> SplitOutput(size_t num, size_t dim,

std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
DataType dtype = input.data_type();
auto dev = input.at(0).device();
auto dev = input.device();


Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
Tensor output(outshape, dev, dtype);
// LOG(INFO) << "output size " << output.Size();

Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
CHECK_EQ(hx.shape(), state_shape);
Tensor hy(state_shape, dev, dtype);

Tensor cy;
@@ -339,7 +335,6 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
// LOG(INFO) << "output size " << output.Size();

Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
CHECK_EQ(hx.shape(), state_shape);
Tensor hy(state_shape, dev, dtype);

Tensor cy;
@@ -389,7 +384,7 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
return {output, hy, cy};
};

std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache) {
std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector<Tensor> &cache) {
const Tensor x = cache[0];
const Tensor y = cache[1];
const Tensor hx = cache[2];
@@ -399,26 +394,22 @@ std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tenso
auto dev = y.device();
auto dtype = y.data_type();


CHECK_EQ(dY.Size(), y.Size());


Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_};
CHECK_EQ(x.shape(), xshape)
Tensor dx(xshape, dev, dtype);

Tensor dw(W.shape(), dev, dtype);

Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
CHECK_EQ(hx.shape(), state_shape)
Tensor dhx(state_shape, dev, dtype);

Tensor dcx;
if (crh.has_cell_)
dcx.ResetLike(dhx);

dw.SetValue(0.0f);
Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
Block *yb = y.block(), *dyb = dY.block(), *dhyb = dhy.block(),
*dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
*wb = W.block(), *dwb = dw.block(), *hxb = hx.block(),
*dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
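The constructor above no longer derives weight_size itself; it simply stores the Weight_size argument, so the caller (RNN.__init__ in autograd.py) is responsible for that computation. A Python transcription of the removed C++ loop, for reference (the function name is hypothetical, and whether __init__ matches it exactly is not visible in this hunk):

def rnn_weight_size(input_size, hidden_size, num_stacks, rnn_mode, bidirectional):
    # parameter-set multiplier: 1 for relu/tanh, 4 for lstm, 3 for gru; doubled if bidirectional
    mult = {"relu": 1, "tanh": 1, "lstm": 4, "gru": 3}[rnn_mode]
    if bidirectional:
        mult *= 2
    weight_size = 0
    for i in range(num_stacks):
        if i == 0:
            dim = hidden_size * (input_size + hidden_size + 2)
        else:
            dim = hidden_size * (hidden_size + hidden_size + 2)
        weight_size += mult * dim
    return weight_size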