Skip to content

Commit

Permalink
Support Half/Bfloat for rand() and fill().
Browse files Browse the repository at this point in the history
Summary: .

Differential Revision: D68984778
  • Loading branch information
shoumikhin authored and facebook-github-bot committed Jan 31, 2025
1 parent dd8da0f commit b132696
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
13 changes: 8 additions & 5 deletions extension/tensor/tensor_ptr_maker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ bool extract_scalar(executorch::aten::Scalar scalar, INT_T* out_val) {

template <
typename FLOAT_T,
typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
type = true>
typename std::enable_if<
std::is_floating_point_v<FLOAT_T> ||
std::is_same_v<FLOAT_T, executorch::aten::BFloat16> ||
std::is_same_v<FLOAT_T, executorch::aten::Half>,
bool>::type = true>
bool extract_scalar(executorch::aten::Scalar scalar, FLOAT_T* out_val) {
double val;
if (scalar.isFloatingPoint()) {
Expand All @@ -59,7 +62,7 @@ template <
typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
true>
bool extract_scalar(executorch::aten::Scalar scalar, BOOL_T* out_val) {
if (scalar.isIntegral(false)) {
if (scalar.isIntegral(/*includeBool=*/false)) {
*out_val = static_cast<bool>(scalar.to<int64_t>());
return true;
}
Expand All @@ -86,7 +89,7 @@ TensorPtr random_strided(
empty_strided(std::move(sizes), std::move(strides), type, dynamism);
std::default_random_engine gen{std::random_device{}()};

ET_SWITCH_REALB_TYPES(type, nullptr, "random_strided", CTYPE, [&] {
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "random_strided", CTYPE, [&] {
std::generate_n(tensor->mutable_data_ptr<CTYPE>(), tensor->numel(), [&]() {
return static_cast<CTYPE>(distribution(gen));
});
Expand Down Expand Up @@ -121,7 +124,7 @@ TensorPtr full_strided(
executorch::aten::TensorShapeDynamism dynamism) {
auto tensor =
empty_strided(std::move(sizes), std::move(strides), type, dynamism);
ET_SWITCH_REALB_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
CTYPE value;
ET_EXTRACT_SCALAR(fill_value, value);
std::fill(
Expand Down
44 changes: 44 additions & 0 deletions extension/tensor/test/tensor_ptr_maker_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ TEST_F(TensorPtrMakerTest, CreateFull) {
EXPECT_EQ(tensor4->size(1), 5);
EXPECT_EQ(tensor4->scalar_type(), executorch::aten::ScalarType::Double);
EXPECT_EQ(tensor4->const_data_ptr<double>()[0], 11);

auto tensor5 = full({4, 5}, 13, executorch::aten::ScalarType::Half);
EXPECT_EQ(tensor5->dim(), 2);
EXPECT_EQ(tensor5->size(0), 4);
EXPECT_EQ(tensor5->size(1), 5);
EXPECT_EQ(tensor5->scalar_type(), executorch::aten::ScalarType::Half);
EXPECT_EQ(tensor5->const_data_ptr<executorch::aten::Half>()[0], 13);

auto tensor6 = full({4, 5}, 15, executorch::aten::ScalarType::BFloat16);
EXPECT_EQ(tensor6->dim(), 2);
EXPECT_EQ(tensor6->size(0), 4);
EXPECT_EQ(tensor6->size(1), 5);
EXPECT_EQ(tensor6->scalar_type(), executorch::aten::ScalarType::BFloat16);
EXPECT_EQ(tensor6->const_data_ptr<executorch::aten::BFloat16>()[0], 15);
}

TEST_F(TensorPtrMakerTest, CreateScalar) {
Expand Down Expand Up @@ -363,6 +377,36 @@ TEST_F(TensorPtrMakerTest, CreateRandTensorWithDoubleType) {
}
}

TEST_F(TensorPtrMakerTest, CreateRandTensorWithHalfType) {
auto tensor = rand({4, 5}, executorch::aten::ScalarType::Half);

EXPECT_EQ(tensor->dim(), 2);
EXPECT_EQ(tensor->size(0), 4);
EXPECT_EQ(tensor->size(1), 5);
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Half);

for (auto i = 0; i < tensor->numel(); ++i) {
auto val = tensor->const_data_ptr<executorch::aten::Half>()[i];
EXPECT_GE(val, 0.0);
EXPECT_LT(val, 1.0);
}
}

TEST_F(TensorPtrMakerTest, CreateRandTensorWithBFloatType) {
auto tensor = rand({4, 5}, executorch::aten::ScalarType::BFloat16);

EXPECT_EQ(tensor->dim(), 2);
EXPECT_EQ(tensor->size(0), 4);
EXPECT_EQ(tensor->size(1), 5);
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::BFloat16);

for (auto i = 0; i < tensor->numel(); ++i) {
auto val = tensor->const_data_ptr<executorch::aten::BFloat16>()[i];
EXPECT_GE(val, 0.0);
EXPECT_LT(val, 1.0);
}
}

TEST_F(TensorPtrMakerTest, CreateRandnTensor) {
auto tensor = randn({100, 100});

Expand Down

0 comments on commit b132696

Please sign in to comment.