Commit

Modified all the operators to support generic implementation in specific cases (#7936)

* Modified all the operators to fall back to a generic implementation when the input data types differ from each other or from the output data type (a sketch of this fallback pattern follows the change summary below).

* Boolean variables are assigned true or false instead of 1 or 0. In the Scalar operation, fixed an issue related to the data type of the second input.

---------

Co-authored-by: [email protected] <[email protected]>
Co-authored-by: JP <[email protected]>
3 people authored Jan 29, 2025
1 parent 3be1c5e commit c361431
Showing 13 changed files with 378 additions and 282 deletions.
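
The recurring pattern across these operators is sketched below as a small self-contained C++ program; the Dtype enum, FakeTensor struct, and can_use_optimized helper are illustrative stand-ins, not the actual kernel code. The vendor-optimized path is taken only when the operand data types agree and are among the supported types (Int or Float in op_add.cpp), and every other combination falls back to the generic implementation. Moving the common/compute dtype computation and the canCast check into that fallback branch, as the op_add.cpp diff below shows, keeps the fast path free of that work.

// Minimal sketch of the dtype gating added in this commit. The types and
// helper below are hypothetical stand-ins, not the ExecuTorch/NNLIB API.
#include <cstdio>

enum class Dtype { Int, Float, Double };

struct FakeTensor {
  Dtype dtype; // the only piece of tensor state this sketch needs
};

// Mirrors the check added in op_add.cpp: the fast kernel is used only when
// all operands share one supported dtype.
bool can_use_optimized(
    const FakeTensor& a, const FakeTensor& b, const FakeTensor& out) {
  const bool supported = (a.dtype == Dtype::Int) || (a.dtype == Dtype::Float);
  return supported && (a.dtype == b.dtype) && (a.dtype == out.dtype);
}

int main() {
  FakeTensor a{Dtype::Int}, b{Dtype::Float}, out{Dtype::Float};
  if (can_use_optimized(a, b, out)) {
    std::printf("optimized kernel path\n");
  } else {
    std::printf("generic fallback path\n"); // dtype promotion + ET_SWITCH
  }
  return 0;
}
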
105 changes: 60 additions & 45 deletions backends/cadence/fusion_g3/operators/op_add.cpp
@@ -35,21 +35,7 @@ Tensor& add_out(
const Tensor& b,
const Scalar& alpha,
Tensor& out) {
// Common Dtype
ScalarType common_type =
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());

#ifdef OP_ARG_CHECK
// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(canCast(common_type, out.scalar_type()) &&
torch::executor::check_alpha_type(
torch::executor::native::utils::get_scalar_dtype(alpha),
common_type)),
InvalidArgument,
out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -65,10 +65,6 @@ Tensor& add_out(
out);
#endif

// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

static constexpr const char op_name[] = "add.out";

int kTensorDimensionLimit = 5;
@@ -77,12 +59,12 @@ Tensor& add_out(
int inp2_shape[kTensorDimensionLimit];
int out_shape[kTensorDimensionLimit];

bool broadcast = 0;
bool broadcast = false;

int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
max_dim = out.dim() > max_dim ? out.dim() : max_dim;

bool optimized = 1;
bool optimized = true;

/* Added change to work with input dimensions more than 5 */
for (int i = 0; i < max_dim; i++) {
@@ -109,15 +91,19 @@ Tensor& add_out(
for (int i = 0; i < out.dim(); i++) {
if (((inp1_shape[i]) != (out_shape[i])) ||
((inp2_shape[i]) != (out_shape[i]))) {
broadcast = 1;
broadcast = true;
}
}

if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
optimized = 0;
if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
(!(((a.scalar_type() == ScalarType::Int) ||
(a.scalar_type() == ScalarType::Float)) &&
(a.scalar_type() == b.scalar_type()) &&
(a.scalar_type() == out.scalar_type())))) {
optimized = false;
}

if ((compute_type == ScalarType::Int) && (optimized)) {
if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
const int* const inp2_data = b.const_data_ptr<int>();
int* const out_data = out.mutable_data_ptr<int>();
@@ -169,7 +155,7 @@ Tensor& add_out(
alpha_val,
out.numel());
}
} else if ((compute_type == ScalarType::Float) && (optimized)) {
} else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
const float* const inp2_data = b.const_data_ptr<float>();
float* const out_data = out.mutable_data_ptr<float>();
@@ -222,6 +208,23 @@ Tensor& add_out(
out.numel());
}
} else {
// Common Dtype
ScalarType common_type =
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(canCast(common_type, out.scalar_type()) &&
torch::executor::check_alpha_type(
torch::executor::native::utils::get_scalar_dtype(alpha),
common_type)),
InvalidArgument,
out);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -249,22 +252,7 @@ Tensor& add_scalar_out(
const Scalar& b,
const Scalar& alpha,
Tensor& out) {
// Common Dtype
ScalarType common_type =
torch::executor::native::utils::promote_type_with_scalar(
a.scalar_type(), b);

#ifdef OP_ARG_CHECK
// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(common_type == out.scalar_type() &&
torch::executor::check_alpha_type(
torch::executor::native::utils::get_scalar_dtype(alpha),
common_type)),
InvalidArgument,
out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -279,14 +267,23 @@ Tensor& add_scalar_out(
InvalidArgument,
out);
#endif
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "add.Scalar_out";

if (compute_type == ScalarType::Int) {
bool optimized = true;

if (!(((a.scalar_type() == ScalarType::Int) ||
(a.scalar_type() == ScalarType::Float)) &&
(a.scalar_type() == out.scalar_type()))) {
optimized = false;
}

if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
optimized = false;
}

if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
int inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -306,7 +303,7 @@ Tensor& add_scalar_out(
alpha_val,
out.numel());

} else if (compute_type == ScalarType::Float) {
} else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
float inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -327,6 +324,24 @@ Tensor& add_scalar_out(
out.numel());

} else {
// Common Dtype
ScalarType common_type =
torch::executor::native::utils::promote_type_with_scalar(
a.scalar_type(), b);
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(common_type == out.scalar_type() &&
torch::executor::check_alpha_type(
torch::executor::native::utils::get_scalar_dtype(alpha),
common_type)),
InvalidArgument,
out);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::
apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
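
The Scalar fix called out in the commit message corresponds to the b.isFloatingPoint() guard above: an Int tensor combined with a floating-point scalar promotes to a floating-point common type, so the integer fast kernel would truncate. A standalone illustration of that promotion rule follows; the Dtype enum, ScalarValue alias, and promote_with_scalar helper are hypothetical stand-ins, not the promote_type_with_scalar utility itself.

// Why an Int tensor plus a floating-point Scalar must take the generic path.
// Hypothetical stand-in types; not the ExecuTorch promotion API.
#include <cassert>
#include <cstdint>
#include <variant>

enum class Dtype { Int, Float };

using ScalarValue = std::variant<int64_t, double>;

// A floating-point scalar promotes an integral tensor dtype to Float,
// loosely mirroring promote_type_with_scalar() semantics.
Dtype promote_with_scalar(Dtype tensor_dtype, const ScalarValue& s) {
  return std::holds_alternative<double>(s) ? Dtype::Float : tensor_dtype;
}

int main() {
  // Int tensor + float scalar: the common type is Float, so the int-only
  // fast kernel would be wrong; the operator now sets optimized = false.
  assert(promote_with_scalar(Dtype::Int, ScalarValue{1.5}) == Dtype::Float);
  // Int tensor + int scalar keeps the Int fast path.
  assert(promote_with_scalar(Dtype::Int, ScalarValue{int64_t{2}}) == Dtype::Int);
  return 0;
}
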
22 changes: 16 additions & 6 deletions backends/cadence/fusion_g3/operators/op_cat.cpp
@@ -46,11 +46,6 @@ Tensor& cat_out(
int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;

#ifdef OP_ARG_CHECK
ET_KERNEL_CHECK(
ctx,
torch::executor::check_cat_args(tensors, dim, out),
InvalidArgument,
out);

Tensor::SizesType expected_out_size[kTensorDimensionLimit];
size_t expected_out_dim = 0;
@@ -106,7 +101,16 @@ Tensor& cat_out(
out_shapes[i] = out_size[i];
}

if ((out.scalar_type() == ScalarType::Int) ||
bool optimized = true;

for (int i = 0; i < tensors.size(); i++) {
if (out.scalar_type() != tensors[i].scalar_type()) {
optimized = false;
break;
}
}

if ((optimized) && (out.scalar_type() == ScalarType::Int) ||
(out.scalar_type() == ScalarType::Short) ||
(out.scalar_type() == ScalarType::Char) ||
(out.scalar_type() == ScalarType::UInt32) ||
@@ -125,6 +129,12 @@ Tensor& cat_out(
(int)dim,
get_element_size(out.scalar_type()));
} else {
ET_KERNEL_CHECK(
ctx,
torch::executor::check_cat_args(tensors, dim, out),
InvalidArgument,
out);

const size_t outer = executorch::runtime::getLeadingDims(out, dim);
const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
const size_t ninputs = tensors.size();
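
The loop added in this file rejects the byte-copy fast path unless every input tensor shares the output's dtype. (As written, && binds more tightly than ||, so the optimized flag gates only the ScalarType::Int comparison in the combined condition.) The sketch below shows why the dtype check matters: concatenation that copies raw bytes using a single element size is only correct when all operands have the same element type. The cat_bytes helper is illustrative only, not the NNLIB kernel invoked in the diff.

// Why byte-wise concatenation needs one shared dtype (illustrative helper,
// not the NNLIB kernel called by op_cat.cpp).
#include <cstddef>
#include <cstring>
#include <vector>

// Concatenates 1-D buffers of identical element size by copying raw bytes.
// With mixed element types the per-input byte counts and the bit patterns
// written to `out` would both be wrong, hence the dtype check in the diff.
void cat_bytes(
    const std::vector<const void*>& inputs,
    const std::vector<std::size_t>& num_elements,
    std::size_t element_size,
    void* out) {
  char* dst = static_cast<char*>(out);
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    const std::size_t nbytes = num_elements[i] * element_size;
    std::memcpy(dst, inputs[i], nbytes);
    dst += nbytes;
  }
}

int main() {
  const int a[] = {1, 2, 3};
  const int b[] = {4, 5};
  int out[5] = {};
  cat_bytes({a, b}, {3, 2}, sizeof(int), out); // out = {1, 2, 3, 4, 5}
  return 0;
}
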
32 changes: 19 additions & 13 deletions backends/cadence/fusion_g3/operators/op_dequantize.cpp
@@ -117,16 +117,22 @@ Tensor& dequantize_impl(
}
}
} else {
if (*zero_point_data != 0) // tesor
if (*zero_point_data != 0) // tensor
{
is_asym_dequant |= 1;
}
}
}
float* out_data = out.mutable_data_ptr<float>();

bool optimized = true;

if (out.scalar_type() != ScalarType::Float) {
optimized = false;
}

if (is_asym_dequant) {
if (input.scalar_type() == ScalarType::Byte) {
if ((input.scalar_type() == ScalarType::Byte) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -139,7 +145,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == ScalarType::Char) {
} else if ((input.scalar_type() == ScalarType::Char) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -152,7 +158,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == ScalarType::UInt16) {
} else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) {
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
XT_KERNEL_CHECK(
ctx,
Expand All @@ -165,7 +171,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == ScalarType::Short) {
} else if ((input.scalar_type() == ScalarType::Short) && (optimized)) {
const int16_t* input_data = input.const_data_ptr<int16_t>();
XT_KERNEL_CHECK(
ctx,
Expand All @@ -178,7 +184,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == (ScalarType)Bits4u) {
} else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
Expand All @@ -191,7 +197,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == (ScalarType)Bits4) {
} else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -338,7 +344,7 @@ Tensor& dequantize_impl(
}
}
} else {
if (input.scalar_type() == ScalarType::Byte) {
if ((input.scalar_type() == ScalarType::Byte) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
Expand All @@ -350,7 +356,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == ScalarType::Char) {
} else if ((input.scalar_type() == ScalarType::Char) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
Expand All @@ -362,7 +368,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == ScalarType::UInt16) {
} else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) {
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
XT_KERNEL_CHECK(
ctx,
Expand All @@ -374,7 +380,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == ScalarType::Short) {
} else if ((input.scalar_type() == ScalarType::Short) && (optimized)) {
const int16_t* input_data = input.const_data_ptr<int16_t>();
XT_KERNEL_CHECK(
ctx,
Expand All @@ -386,7 +392,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == (ScalarType)Bits4u) {
} else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
Expand All @@ -398,7 +404,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == (ScalarType)Bits4) {
} else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
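
In this file the new optimized flag additionally requires a Float output before any of the per-dtype NNLIB branches run. The per-element math those branches implement is standard affine dequantization, out[i] = (in[i] - zero_point) * scale, shown here as a plain reference loop; the dequantize_per_tensor helper is a hypothetical stand-in, not one of the kernel entry points in the diff. When every zero point is zero the operator takes the symmetric branch, which is the same formula with zero_point fixed at 0.

// Reference affine (asymmetric) dequantization: out[i] = (in[i] - zp) * scale.
// Plain illustrative loop; the operator dispatches to NNLIB kernels instead.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> dequantize_per_tensor(
    const std::vector<int8_t>& input, float scale, int32_t zero_point) {
  std::vector<float> out(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    out[i] =
        static_cast<float>(static_cast<int32_t>(input[i]) - zero_point) * scale;
  }
  return out;
}

int main() {
  // With scale = 0.5 and zero_point = -2, the quantized value 4 maps to 3.0f.
  std::vector<float> out = dequantize_per_tensor({4, -2, 0}, 0.5f, -2);
  return (out[0] == 3.0f) ? 0 : 1;
}
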
(Diffs for the remaining 10 changed files are not shown.)
