Merge pull request #632 from dcslin/astype-fix
make astype return new tensor
nudles authored Mar 23, 2020
2 parents 07ff7d9 + 205a283, commit aac1d88
Showing 10 changed files with 138 additions and 62 deletions.
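
The user-visible effect of the change: Tensor.as_type() in Python and Tensor::AsType() in C++ no longer convert a tensor in place; they leave the source tensor untouched and return a newly allocated tensor of the requested type, so callers must bind the return value. A minimal sketch of the new contract at the Python level (names follow the tests in this diff; get_default_device() is only used here to obtain a host device):

import numpy as np
from singa import device, tensor

dev = device.get_default_device()
x = tensor.Tensor(device=dev, data=np.array([1.5, -2.5], dtype=np.float32))

y = x.as_type('int')        # returns a NEW int32 tensor; x itself is unchanged
print(tensor.to_numpy(y))   # [ 1 -2]   (C-style truncation toward zero)
print(tensor.to_numpy(x))   # [ 1.5 -2.5]   still float32

x.as_type('int')            # the old in-place idiom now silently discards the result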
2 changes: 1 addition & 1 deletion include/singa/core/tensor.h
@@ -254,7 +254,7 @@ class Tensor {
   Tensor &ResetLike(const Tensor &t);
 
   /// Reset the data type, it would reallocate block if type changes.
-  Tensor &AsType(const DataType type);
+  Tensor AsType(const DataType type);
 
   /// Reset the device.
   /// If the target device is a diff device, then do deep data copy.
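Since the C++ signature now returns a Tensor by value rather than a reference to *this, callers at the C++ or binding level have to capture the result. A short sketch through the SWIG binding, mirroring the updated tests (singa_api follows the test files' alias for SINGA's singa_wrap module; kInt, kFloat32 and CopyIntDataFromHostPtr appear in test_api.py below):

import numpy as np
from singa import device
from singa import singa_wrap as singa_api

dev = device.get_default_device()
t1 = singa_api.Tensor([2, 3], dev, singa_api.kInt)
t1.CopyIntDataFromHostPtr(np.arange(6, dtype=np.int32).flatten())

t2 = t1.AsType(singa_api.kFloat32)           # a new float32 tensor is returned
assert t2.data_type() == singa_api.kFloat32
assert t1.data_type() == singa_api.kInt      # the source tensor keeps its own type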
4 changes: 3 additions & 1 deletion python/singa/tensor.py
@@ -249,7 +249,9 @@ def as_type(self, dtype):
             dtype = singa.kFloat32
         else:
             raise TypeError("invalid data type %s" % dtype)
-        self.data.AsType(dtype)
+        t = Tensor(self.shape, self.device, dtype)
+        t.data = self.data.AsType(dtype)
+        return t
 
     def to_device(self, device):
         '''Move the tensor data onto a given device.
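The Python wrapper now allocates a fresh Tensor of the target dtype and attaches the cast C++ tensor to it, so a float-to-int-to-float round trip reads as below (a condensed version of the updated test_astype further down; shapes and values are illustrative):

import numpy as np
from singa import device, tensor

dev = device.get_default_device()
np_flt = np.random.random([2, 3]).astype(np.float32) * 10 - 5

t = tensor.Tensor(device=dev, data=np_flt)
ti = t.as_type('int')                        # float32 -> int32
tf = ti.as_type('float')                     # int32 -> float32

np.testing.assert_array_almost_equal(tensor.to_numpy(ti), np_flt.astype(np.int32))
np.testing.assert_array_almost_equal(
    tensor.to_numpy(tf), np_flt.astype(np.int32).astype(np.float32))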
2 changes: 1 addition & 1 deletion src/api/core_tensor.i
@@ -107,7 +107,7 @@ namespace singa{
   size_t MemSize() const;
 
   void ResetLike(const Tensor &t);
-  void AsType(DataType type);
+  Tensor AsType(DataType type);
   void ToDevice(std::shared_ptr<singa::Device> dev);
   void ToHost();
   float L2() const;
37 changes: 16 additions & 21 deletions src/core/tensor/tensor.cc
@@ -152,29 +152,24 @@ Tensor Resize(const Tensor &in, const Shape &shape) {
   }                                                        \
   } while (0)
 
-Tensor &Tensor::AsType(const DataType type) {
+// return new tensor
+Tensor Tensor::AsType(const DataType type) {
   if (data_type_ != type) {
-    if (block_ != nullptr && block_->DecRefCount() == 0) {
-      auto offset = Product(shape_);
-      auto new_block_ =
-          device_->NewBlock((int)(Product(shape_) * SizeOf(type)));
-      TYPE_TYPE_LANG_SWITCH(
-          data_type_, LDType, type, RDType, device_->lang(), Lang, {
-            device_->Exec(
-                [this, new_block_, offset, type](Context *ctx) {
-                  CastAsType<LDType, RDType, Lang>(this, new_block_, offset,
-                                                   ctx);
-                },
-                {}, {});
-          });
-      device_->FreeBlock(block_);
-      block_ = new_block_;
-    } else {
-      block_ = device_->NewBlock((int)(Product(shape_) * SizeOf(type)));
-    }
-    data_type_ = type;
+    Tensor ret(shape_, device_, type);
+    auto *retptr = &ret;
+    TYPE_TYPE_LANG_SWITCH(
+        data_type_, LDType, type, RDType, device_->lang(), Lang, {
+          retptr->device()->Exec(
+              [this, retptr](Context *ctx) {
+                CastCopy<LDType, RDType, Lang>(this, retptr, ctx);
+              },
+              {this->block()}, {retptr->block()});
+        });
+    return ret;
+  } else {
+    Tensor t = this->Clone();
+    return t;
   }
-  return *this;
 }
 
 Tensor &Tensor::ToDevice(std::shared_ptr<Device> dst) {
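Two details of the rewritten method are worth noting: the cast is queued on the device with explicit data dependencies (the source block is declared as read and the result block as written, instead of the empty lists the old code passed to Exec), and a cast to the tensor's current type returns a Clone() rather than the tensor itself. The second point means even a no-op conversion yields an independent copy, as in this small sketch against the binding (paralleling the new FloatAsTypeFloatCPU test below):

from singa import device
from singa import singa_wrap as singa_api

dev = device.get_default_device()
t1 = singa_api.Tensor([3], dev, singa_api.kFloat32)

t2 = t1.AsType(singa_api.kFloat32)   # same dtype: AsType still returns a Clone, not t1 itself
assert t2.data_type() == singa_api.kFloat32
assert t2 is not t1                  # distinct tensor object, per the Clone() branch above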
4 changes: 2 additions & 2 deletions src/core/tensor/tensor_math.h
@@ -87,8 +87,8 @@ void Abs(const Tensor &in, Tensor *out, Context *ctx) {
 }
 
 template <typename DTypeSrc, typename DTypeDst, typename Lang>
-void CastAsType(const Tensor *src, Block *dst, int offset, Context *ctx) {
-  LOG(FATAL) << "CastAsType Not Implemented";
+void CastCopy(const Tensor *src, Tensor *dst, Context *ctx) {
+  LOG(FATAL) << "CastCopy Not Implemented";
 }
 
 template <typename DType, typename Lang>
16 changes: 8 additions & 8 deletions src/core/tensor/tensor_math_cpp.h
@@ -241,19 +241,19 @@ void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
 }
 
 template <>
-void CastAsType<float, int, lang::Cpp>(const Tensor *src, Block *dst,
-                                       int offset, Context *ctx) {
-  int *dst_array = static_cast<int *>(dst->mutable_data());
+void CastCopy<float, int, lang::Cpp>(const Tensor *src, Tensor *dst,
+                                     Context *ctx) {
+  int *dst_array = static_cast<int *>(dst->block()->mutable_data());
   const float *src_array = static_cast<const float *>(src->block()->data());
-  for (int i = 0; i < offset; ++i) dst_array[i] = (int)src_array[i];
+  for (int i = 0; i < dst->Size(); ++i) dst_array[i] = (int)src_array[i];
 }
 
 template <>
-void CastAsType<int, float, lang::Cpp>(const Tensor *src, Block *dst,
-                                       int offset, Context *ctx) {
-  float *dst_array = static_cast<float *>(dst->mutable_data());
+void CastCopy<int, float, lang::Cpp>(const Tensor *src, Tensor *dst,
+                                     Context *ctx) {
+  float *dst_array = static_cast<float *>(dst->block()->mutable_data());
   const int *src_array = static_cast<const int *>(src->block()->data());
-  for (int i = 0; i < offset; ++i) dst_array[i] = (float)src_array[i];
+  for (int i = 0; i < dst->Size(); ++i) dst_array[i] = (float)src_array[i];
 }
 
 template <>
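The CPU kernels use plain C casts, so the float-to-int path truncates toward zero. That is the same rule numpy's astype(np.int32) applies, which is why the Python tests can assert exact agreement with a numpy reference even for negative, non-integral inputs. A quick numpy-only check (not SINGA code):

import numpy as np

x = np.array([4.7, -4.7, 0.9, -0.9], dtype=np.float32)
print(x.astype(np.int32))   # [ 4 -4  0  0]  -- truncation toward zero, like (int)x in C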
17 changes: 8 additions & 9 deletions src/core/tensor/tensor_math_cuda.h
@@ -170,20 +170,19 @@ void Abs<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
 }
 
 template <>
-void CastAsType<float, int, lang::Cuda>(const Tensor* src, Block* dst,
-                                        int offset, Context* ctx) {
+void CastCopy<float, int, lang::Cuda>(const Tensor* src, Tensor* dst,
+                                      Context* ctx) {
   const float* srcPtr = static_cast<const float*>(src->block()->data());
-  int* dstPtr = static_cast<int*>(dst->mutable_data());
-  const size_t num = src->Size();
-  cuda::cast_float_2_int(num, srcPtr, dstPtr, ctx->stream);
+  int* dstPtr = static_cast<int*>(dst->block()->mutable_data());
+  cuda::cast_float_2_int(dst->Size(), srcPtr, dstPtr, ctx->stream);
 }
 
 template <>
-void CastAsType<int, float, lang::Cuda>(const Tensor* src, Block* dst,
-                                        int offset, Context* ctx) {
+void CastCopy<int, float, lang::Cuda>(const Tensor* src, Tensor* dst,
+                                      Context* ctx) {
   const int* srcPtr = static_cast<const int*>(src->block()->data());
-  float* dstPtr = static_cast<float*>(dst->mutable_data());
-  cuda::cast_int_2_float(offset, srcPtr, dstPtr, ctx->stream);
+  float* dstPtr = static_cast<float*>(dst->block()->mutable_data());
+  cuda::cast_int_2_float(dst->Size(), srcPtr, dstPtr, ctx->stream);
 }
 
 template <>
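On a CUDA device the same as_type() call ends up in these cuda::cast_* kernels; nothing changes at the Python level except the device handle. A hedged sketch, assuming a build with USE_CUDA (device.create_cuda_gpu() is the usual way to obtain a GPU device and is not part of this diff):

import numpy as np
from singa import device, tensor

gpu_dev = device.create_cuda_gpu()   # assumes SINGA was built with USE_CUDA
x = tensor.Tensor(device=gpu_dev,
                  data=np.arange(6, dtype=np.float32).reshape(2, 3))

y = x.as_type('int')          # dispatches to cuda::cast_float_2_int on the GPU
print(tensor.to_numpy(y))     # [[0 1 2] [3 4 5]] as int32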
55 changes: 53 additions & 2 deletions test/python/test_api.py
@@ -72,6 +72,35 @@ def _cTensor_to_pyTensor(cTensor):
     return new_t
 
 
+def _ctensor_eq_ndarray(t1, np1):
+    d = t1.device()
+    t1.ToHost()
+    if t1.data_type() == singa_api.kInt:
+        np.testing.assert_array_almost_equal(t1.GetIntValue(t1.Size()),
+                                             np1.flatten())
+    elif t1.data_type() == singa_api.kFloat32:
+        np.testing.assert_array_almost_equal(t1.GetFloatValue(t1.Size()),
+                                             np1.flatten())
+
+    if np1.dtype == np.float32:
+        np.testing.assert_equal(t1.data_type(), singa_api.kFloat32)
+    elif np1.dtype == np.int32:
+        np.testing.assert_equal(t1.data_type(), singa_api.kInt)
+
+    np.testing.assert_array_almost_equal(t1.shape(), np1.shape)
+    t1.ToDevice(d)
+
+
+def print_t(t1):
+    d = t1.device()
+    t1.ToHost()
+    if t1.data_type() == singa_api.kInt:
+        print(t1.GetIntValue(t1.Size()))
+    elif t1.data_type() == singa_api.kFloat32:
+        print(t1.GetFloatValue(t1.Size()))
+    t1.ToDevice(d)
+
+
 class TestAPI(unittest.TestCase):
 
     def test_batchnorm_training_gpu(self):
@@ -618,20 +647,42 @@ def test_as_type(self):
 
         self.assertEqual(t1_ct.data_type(), singa_api.kFloat32)
 
-        t1_ct.AsType(singa_api.kInt)
+        t1_ct = t1_ct.AsType(singa_api.kInt)
 
         self.assertEqual(t1_ct.data_type(), singa_api.kInt)
 
         np.testing.assert_array_almost_equal(
            tensor.to_numpy(_cTensor_to_pyTensor(t1_ct)), np2)
 
-        t1_ct.AsType(singa_api.kFloat32)
+        t1_ct = t1_ct.AsType(singa_api.kFloat32)
 
         self.assertEqual(t1_ct.data_type(), singa_api.kFloat32)
 
         np.testing.assert_array_almost_equal(
            tensor.to_numpy(_cTensor_to_pyTensor(t1_ct)), np3)
 
+    def test_as_type2(self):
+        for dev in [cpu_dev, gpu_dev]:
+            shape1 = [1, 2, 3, 4]
+            shape2 = [4, 3, 2, 1]
+            np_int = np.random.randint(0, 10, shape1).astype(np.int32)
+            np_flt = np_int.astype(np.float32)
+
+            t1 = singa_api.Tensor(shape1, dev, singa_api.kInt)
+            t1.CopyIntDataFromHostPtr(np_int.flatten())
+            _ctensor_eq_ndarray(t1, np_int)
+
+            t1 = singa_api.Reshape(t1, shape2)
+            t2 = t1.AsType(singa_api.kFloat32)
+            _ctensor_eq_ndarray(t2, np_flt.reshape(shape2))
+
+            t3 = t2.AsType(singa_api.kInt)
+            _ctensor_eq_ndarray(t3, np_int.reshape(shape2))
+
+            t1 = singa_api.Reshape(t1, shape1)
+            t4 = t1.AsType(singa_api.kFloat32)
+            _ctensor_eq_ndarray(t4, np_flt.reshape(shape1))
+
 
 if __name__ == '__main__':
     unittest.main()
27 changes: 18 additions & 9 deletions test/python/test_tensor.py
@@ -340,18 +340,27 @@ def test_ceil(self):
 
     def test_astype(self):
         for dev in [cpu_dev, gpu_dev]:
+            shape1 = [2, 3]
+            shape2 = [3, 2]
+
-            np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
-            np1 = np1 * 10 - 5
+            np_flt = np.random.random(shape1).astype(np.float32)
+            np_flt = np_flt * 10 - 5
 
-            np2 = np1.astype(np.int32)
-            np3 = np2.astype(np.float32)
+            np_int = np_flt.astype(np.int32)
+            np_flt2 = np_int.astype(np.float32)
 
+            t2 = tensor.Tensor(device=dev, data=np_flt)
+            t2 = t2.as_type('int')
+            np.testing.assert_array_almost_equal(tensor.to_numpy(t2), np_int)
+
+            t1 = t2.reshape(shape2)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(t1),
+                                                 np_int.reshape(shape2))
+
+            t1 = t1.as_type('float')
+            np.testing.assert_array_almost_equal(tensor.to_numpy(t1),
+                                                 np_flt2.reshape(shape2))
+
-            t1 = tensor.Tensor(device=dev, data=np1)
-            t1.as_type('int')
-            np.testing.assert_array_almost_equal(tensor.to_numpy(t1), np2)
-            t1.as_type('float')
-            np.testing.assert_array_almost_equal(tensor.to_numpy(t1), np3)
 
 
 if __name__ == '__main__':
36 changes: 28 additions & 8 deletions test/singa/test_tensor.cc
@@ -78,7 +78,7 @@ TEST(TensorClass, FloatAsTypeIntCuda) {
   t.CopyDataFromHostPtr(data, 3);
   EXPECT_EQ(singa::kFloat32, t.data_type());
 
-  t.AsType(singa::kInt);
+  t = t.AsType(singa::kInt);
 
   EXPECT_EQ(singa::kInt, t.data_type());
 
@@ -97,7 +97,7 @@ TEST(TensorClass, IntAsTypeFloatCuda) {
   t.CopyDataFromHostPtr(data, 3);
   EXPECT_EQ(singa::kInt, t.data_type());
 
-  t.AsType(singa::kFloat32);
+  t = t.AsType(singa::kFloat32);
 
   EXPECT_EQ(singa::kFloat32, t.data_type());
 
@@ -110,6 +110,26 @@
 
 #endif  // USE_CUDA
 
+TEST(TensorClass, FloatAsTypeFloatCPU) {
+  Tensor t(Shape{3});
+  float data[] = {1.0f, 2.0f, 3.0f};
+  t.CopyDataFromHostPtr(data, 3);
+  EXPECT_EQ(singa::kFloat32, t.data_type());
+  const float* dptr = static_cast<const float*>(t.block()->data());
+  EXPECT_FLOAT_EQ(1.0f, dptr[0]);
+  EXPECT_FLOAT_EQ(2.0f, dptr[1]);
+  EXPECT_FLOAT_EQ(3.0f, dptr[2]);
+
+  Tensor t2 = t.AsType(singa::kFloat32);
+
+  EXPECT_EQ(singa::kFloat32, t2.data_type());
+
+  const float* dptr2 = static_cast<const float*>(t2.block()->data());
+  EXPECT_EQ(1.0f, dptr2[0]);
+  EXPECT_EQ(2.0f, dptr2[1]);
+  EXPECT_EQ(3.0f, dptr2[2]);
+}
+
 TEST(TensorClass, FloatAsTypeIntCPU) {
   Tensor t(Shape{3});
   float data[] = {1.0f, 2.0f, 3.0f};
@@ -120,10 +140,10 @@ TEST(TensorClass, FloatAsTypeIntCPU) {
   EXPECT_FLOAT_EQ(1.0f, dptr[0]);
   EXPECT_FLOAT_EQ(2.0f, dptr[1]);
   EXPECT_FLOAT_EQ(3.0f, dptr[2]);
-  t.AsType(singa::kInt);
+  Tensor t2 = t.AsType(singa::kInt);
 
-  EXPECT_EQ(singa::kInt, t.data_type());
-  const int* dptr2 = static_cast<const int*>(t.block()->data());
+  EXPECT_EQ(singa::kInt, t2.data_type());
+  const int* dptr2 = static_cast<const int*>(t2.block()->data());
   EXPECT_EQ(1, dptr2[0]);
   EXPECT_EQ(2, dptr2[1]);
   EXPECT_EQ(3, dptr2[2]);
@@ -135,11 +155,11 @@ TEST(TensorClass, IntAsTypeFloatCPU) {
   t.CopyDataFromHostPtr(data, 3);
   EXPECT_EQ(singa::kInt, t.data_type());
 
-  t.AsType(singa::kFloat32);
+  auto t2 = t.AsType(singa::kFloat32);
 
-  EXPECT_EQ(singa::kFloat32, t.data_type());
+  EXPECT_EQ(singa::kFloat32, t2.data_type());
 
-  const float* dptr2 = static_cast<const float*>(t.block()->data());
+  const float* dptr2 = static_cast<const float*>(t2.block()->data());
   EXPECT_EQ(1.0f, dptr2[0]);
   EXPECT_EQ(2.0f, dptr2[1]);
   EXPECT_EQ(3.0f, dptr2[2]);