diff --git a/backends/cadence/fusion_g3/operators/op_add.cpp b/backends/cadence/fusion_g3/operators/op_add.cpp index d51fee5338..409c4cc510 100644 --- a/backends/cadence/fusion_g3/operators/op_add.cpp +++ b/backends/cadence/fusion_g3/operators/op_add.cpp @@ -35,21 +35,7 @@ Tensor& add_out( const Tensor& b, const Scalar& alpha, Tensor& out) { - // Common Dtype - ScalarType common_type = - executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); - #ifdef OP_ARG_CHECK - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, - (canCast(common_type, out.scalar_type()) && - torch::executor::check_alpha_type( - torch::executor::native::utils::get_scalar_dtype(alpha), - common_type)), - InvalidArgument, - out); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -65,10 +51,6 @@ Tensor& add_out( out); #endif - // Compute Dtype - ScalarType compute_type = - torch::executor::native::utils::get_compute_type(common_type); - static constexpr const char op_name[] = "add.out"; int kTensorDimensionLimit = 5; @@ -77,12 +59,12 @@ Tensor& add_out( int inp2_shape[kTensorDimensionLimit]; int out_shape[kTensorDimensionLimit]; - bool broadcast = 0; + bool broadcast = false; int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); max_dim = out.dim() > max_dim ? out.dim() : max_dim; - bool optimized = 1; + bool optimized = true; /* Added change to work with input dimensions more than 5 */ for (int i = 0; i < max_dim; i++) { @@ -109,15 +91,19 @@ Tensor& add_out( for (int i = 0; i < out.dim(); i++) { if (((inp1_shape[i]) != (out_shape[i])) || ((inp2_shape[i]) != (out_shape[i]))) { - broadcast = 1; + broadcast = true; } } - if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) { - optimized = 0; + if (((broadcast) && (max_dim > kTensorDimensionLimit)) || + (!(((a.scalar_type() == ScalarType::Int) || + (a.scalar_type() == ScalarType::Float)) && + (a.scalar_type() == b.scalar_type()) && + (a.scalar_type() == out.scalar_type())))) { + optimized = false; } - if ((compute_type == ScalarType::Int) && (optimized)) { + if ((a.scalar_type() == ScalarType::Int) && (optimized)) { const int* const inp1_data = a.const_data_ptr<int>(); const int* const inp2_data = b.const_data_ptr<int>(); int* const out_data = out.mutable_data_ptr<int>(); @@ -169,7 +155,7 @@ Tensor& add_out( alpha_val, out.numel()); } - } else if ((compute_type == ScalarType::Float) && (optimized)) { + } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr<float>(); const float* const inp2_data = b.const_data_ptr<float>(); float* const out_data = out.mutable_data_ptr<float>(); @@ -222,6 +208,23 @@ Tensor& add_out( out.numel()); } } else { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + torch::executor::check_alpha_type( + torch::executor::native::utils::get_scalar_dtype(alpha), + common_type)), + InvalidArgument, + out); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_alpha = torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha); @@ -249,22 +252,7 @@ Tensor& add_scalar_out( const Scalar& b, const Scalar& alpha, Tensor& out) { - // Common Dtype - ScalarType common_type = - torch::executor::native::utils::promote_type_with_scalar( - a.scalar_type(), b); - #ifdef 
OP_ARG_CHECK - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, - (common_type == out.scalar_type() && - torch::executor::check_alpha_type( - torch::executor::native::utils::get_scalar_dtype(alpha), - common_type)), - InvalidArgument, - out); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -279,14 +267,23 @@ Tensor& add_scalar_out( InvalidArgument, out); #endif - // Compute Dtype - ScalarType compute_type = - torch::executor::native::utils::get_compute_type(common_type); // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "add.Scalar_out"; - if (compute_type == ScalarType::Int) { + bool optimized = true; + + if (!(((a.scalar_type() == ScalarType::Int) || + (a.scalar_type() == ScalarType::Float)) && + (a.scalar_type() == out.scalar_type()))) { + optimized = false; + } + + if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) { + optimized = false; + } + + if ((a.scalar_type() == ScalarType::Int) && (optimized)) { const int* const inp1_data = a.const_data_ptr<int>(); int inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); @@ -306,7 +303,7 @@ Tensor& add_scalar_out( alpha_val, out.numel()); - } else if (compute_type == ScalarType::Float) { + } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr<float>(); float inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); @@ -327,6 +324,24 @@ Tensor& add_scalar_out( out.numel()); } else { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (common_type == out.scalar_type() && + torch::executor::check_alpha_type( + torch::executor::native::utils::get_scalar_dtype(alpha), + common_type)), + InvalidArgument, + out); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { torch::executor::native::utils:: apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( diff --git a/backends/cadence/fusion_g3/operators/op_cat.cpp b/backends/cadence/fusion_g3/operators/op_cat.cpp index 74fd96a212..84224b37b0 100644 --- a/backends/cadence/fusion_g3/operators/op_cat.cpp +++ b/backends/cadence/fusion_g3/operators/op_cat.cpp @@ -46,11 +46,6 @@ Tensor& cat_out( int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; #ifdef OP_ARG_CHECK - ET_KERNEL_CHECK( - ctx, - torch::executor::check_cat_args(tensors, dim, out), - InvalidArgument, - out); Tensor::SizesType expected_out_size[kTensorDimensionLimit]; size_t expected_out_dim = 0; @@ -106,7 +101,16 @@ Tensor& cat_out( out_shapes[i] = out_size[i]; } - if ((out.scalar_type() == ScalarType::Int) || + bool optimized = true; + + for (int i = 0; i < tensors.size(); i++) { + if (out.scalar_type() != tensors[i].scalar_type()) { + optimized = false; + break; + } + } + + if ((optimized) && (out.scalar_type() == ScalarType::Int) || (out.scalar_type() == ScalarType::Short) || (out.scalar_type() == ScalarType::Char) || (out.scalar_type() == ScalarType::UInt32) || @@ -125,6 +129,12 @@ Tensor& cat_out( (int)dim, get_element_size(out.scalar_type())); } else { + ET_KERNEL_CHECK( + ctx, + torch::executor::check_cat_args(tensors, dim, out), + InvalidArgument, + out); + const size_t outer = executorch::runtime::getLeadingDims(out, dim); const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim); const size_t ninputs = 
tensors.size(); diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp index 3e0235170b..dd9d4f2a51 100644 --- a/backends/cadence/fusion_g3/operators/op_dequantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp @@ -117,7 +117,7 @@ Tensor& dequantize_impl( } } } else { - if (*zero_point_data != 0) // tesor + if (*zero_point_data != 0) // tensor { is_asym_dequant |= 1; } @@ -125,8 +125,14 @@ Tensor& dequantize_impl( } float* out_data = out.mutable_data_ptr<float>(); + bool optimized = true; + + if (out.scalar_type() != ScalarType::Float) { + optimized = false; + } + if (is_asym_dequant) { - if (input.scalar_type() == ScalarType::Byte) { + if ((input.scalar_type() == ScalarType::Byte) && (optimized)) { const uint8_t* input_data = input.const_data_ptr<uint8_t>(); XT_KERNEL_CHECK( ctx, @@ -139,7 +145,7 @@ Tensor& dequantize_impl( axis, zero_point_data, scale_data); - } else if (input.scalar_type() == ScalarType::Char) { + } else if ((input.scalar_type() == ScalarType::Char) && (optimized)) { const int8_t* input_data = input.const_data_ptr<int8_t>(); XT_KERNEL_CHECK( ctx, @@ -152,7 +158,7 @@ Tensor& dequantize_impl( axis, zero_point_data, scale_data); - } else if (input.scalar_type() == ScalarType::UInt16) { + } else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) { const uint16_t* input_data = input.const_data_ptr<uint16_t>(); XT_KERNEL_CHECK( ctx, @@ -165,7 +171,7 @@ Tensor& dequantize_impl( axis, zero_point_data, scale_data); - } else if (input.scalar_type() == ScalarType::Short) { + } else if ((input.scalar_type() == ScalarType::Short) && (optimized)) { const int16_t* input_data = input.const_data_ptr<int16_t>(); XT_KERNEL_CHECK( ctx, @@ -178,7 +184,7 @@ Tensor& dequantize_impl( axis, zero_point_data, scale_data); - } else if (input.scalar_type() == (ScalarType)Bits4u) { + } else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) { const uint8_t* input_data = input.const_data_ptr<uint8_t>(); XT_KERNEL_CHECK( ctx, @@ -191,7 +197,7 @@ Tensor& dequantize_impl( axis, zero_point_data, scale_data); - } else if (input.scalar_type() == (ScalarType)Bits4) { + } else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) { const int8_t* input_data = input.const_data_ptr<int8_t>(); XT_KERNEL_CHECK( ctx, @@ -338,7 +344,7 @@ Tensor& dequantize_impl( } } } else { - if (input.scalar_type() == ScalarType::Byte) { + if ((input.scalar_type() == ScalarType::Byte) && (optimized)) { const uint8_t* input_data = input.const_data_ptr<uint8_t>(); XT_KERNEL_CHECK( ctx, @@ -350,7 +356,7 @@ Tensor& dequantize_impl( input.dim(), axis, scale_data); - } else if (input.scalar_type() == ScalarType::Char) { + } else if ((input.scalar_type() == ScalarType::Char) && (optimized)) { const int8_t* input_data = input.const_data_ptr<int8_t>(); XT_KERNEL_CHECK( ctx, @@ -362,7 +368,7 @@ Tensor& dequantize_impl( input.dim(), axis, scale_data); - } else if (input.scalar_type() == ScalarType::UInt16) { + } else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) { const uint16_t* input_data = input.const_data_ptr<uint16_t>(); XT_KERNEL_CHECK( ctx, @@ -374,7 +380,7 @@ Tensor& dequantize_impl( input.dim(), axis, scale_data); - } else if (input.scalar_type() == ScalarType::Short) { + } else if ((input.scalar_type() == ScalarType::Short) && (optimized)) { const int16_t* input_data = input.const_data_ptr<int16_t>(); XT_KERNEL_CHECK( ctx, @@ -386,7 +392,7 @@ Tensor& dequantize_impl( input.dim(), axis, 
scale_data); - } else if (input.scalar_type() == (ScalarType)Bits4u) { + } else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) { const uint8_t* input_data = input.const_data_ptr<uint8_t>(); XT_KERNEL_CHECK( ctx, @@ -398,7 +404,7 @@ Tensor& dequantize_impl( input.dim(), axis, scale_data); - } else if (input.scalar_type() == (ScalarType)Bits4) { + } else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) { const int8_t* input_data = input.const_data_ptr<int8_t>(); XT_KERNEL_CHECK( ctx, diff --git a/backends/cadence/fusion_g3/operators/op_div.cpp b/backends/cadence/fusion_g3/operators/op_div.cpp index 1461f643a8..85e5da4276 100644 --- a/backends/cadence/fusion_g3/operators/op_div.cpp +++ b/backends/cadence/fusion_g3/operators/op_div.cpp @@ -54,10 +54,6 @@ Tensor& div_out( const Tensor& a, const Tensor& b, Tensor& out) { - // Common Dtype - ScalarType common_type = - executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); - #ifdef OP_ARG_CHECK // Check Dim Order ET_KERNEL_CHECK( @@ -73,11 +69,6 @@ Tensor& div_out( InvalidArgument, out); #endif - - // Compute Dtype - ScalarType compute_type = - torch::executor::native::utils::get_compute_type(common_type); - // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "div.out"; @@ -87,12 +78,12 @@ Tensor& div_out( int inp2_shape[kTensorDimensionLimit]; int out_shape[kTensorDimensionLimit]; - bool broadcast = 0; + bool broadcast = false; int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); max_dim = out.dim() > max_dim ? out.dim() : max_dim; - bool optimized = 1; + bool optimized = true; for (int i = 0; i < max_dim; i++) { out_shape[i] = 1; @@ -118,15 +109,19 @@ Tensor& div_out( for (int i = 0; i < out.dim(); i++) { if (((inp1_shape[i]) != (out_shape[i])) || ((inp2_shape[i]) != (out_shape[i]))) { - broadcast = 1; + broadcast = true; } } - if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) { - optimized = 0; + if (((broadcast) && (max_dim > kTensorDimensionLimit)) || + (!(((a.scalar_type() == ScalarType::Int) || + (a.scalar_type() == ScalarType::Float)) && + (a.scalar_type() == b.scalar_type()) && + (out.scalar_type() == ScalarType::Float)))) { + optimized = false; } - if ((compute_type == ScalarType::Int) && (optimized)) { + if ((a.scalar_type() == ScalarType::Int) && (optimized)) { const int* const inp1_data = a.const_data_ptr<int>(); const int* const inp2_data = b.const_data_ptr<int>(); float* const out_data = out.mutable_data_ptr<float>(); @@ -162,7 +157,7 @@ Tensor& div_out( inp2_data, out.numel()); } - } else if ((compute_type == ScalarType::Float) && (optimized)) { + } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr<float>(); const float* const inp2_data = b.const_data_ptr<float>(); float* const out_data = out.mutable_data_ptr<float>(); @@ -244,19 +239,7 @@ Tensor& div_out_mode( ET_KERNEL_CHECK( ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out); - // Common Dtype - ScalarType common_type = - executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); - #ifdef OP_ARG_CHECK - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, - (canCast(common_type, out.scalar_type()) && - common_type != ScalarType::Bool), - InvalidArgument, - out); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -271,9 +254,6 @@ Tensor& div_out_mode( InvalidArgument, out); #endif - // Compute Dtype - ScalarType compute_type = - torch::executor::native::utils::get_compute_type(common_type); // @lint-ignore 
CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "div.out_mode"; @@ -287,12 +267,12 @@ Tensor& div_out_mode( int inp2_shape[kTensorDimensionLimit]; int out_shape[kTensorDimensionLimit]; - bool broadcast = 0; + bool broadcast = false; int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); max_dim = out.dim() > max_dim ? out.dim() : max_dim; - bool optimized = 1; + bool optimized = true; for (int i = 0; i < max_dim; i++) { out_shape[i] = 1; @@ -318,17 +298,21 @@ Tensor& div_out_mode( for (int i = 0; i < out.dim(); i++) { if (((inp1_shape[i]) != (out_shape[i])) || ((inp2_shape[i]) != (out_shape[i]))) { - broadcast = 1; + broadcast = true; } } - if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) { - optimized = 0; + if (((broadcast) && (max_dim > kTensorDimensionLimit)) || + (!(((a.scalar_type() == ScalarType::Int) || + (a.scalar_type() == ScalarType::Float)) && + (a.scalar_type() == b.scalar_type()) && + (a.scalar_type() == out.scalar_type())))) { + optimized = false; } int mode_value = (mode_val == "trunc") ? 1 : 2; - if ((compute_type == ScalarType::Int) && (optimized)) { + if ((a.scalar_type() == ScalarType::Int) && (optimized)) { const int* const inp1_data = a.const_data_ptr<int>(); const int* const inp2_data = b.const_data_ptr<int>(); int* const out_data = out.mutable_data_ptr<int>(); @@ -367,7 +351,7 @@ Tensor& div_out_mode( mode_value, out.numel()); } - } else if ((compute_type == ScalarType::Float) && (optimized)) { + } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr<float>(); const float* const inp2_data = b.const_data_ptr<float>(); float* const out_data = out.mutable_data_ptr<float>(); @@ -407,6 +391,21 @@ Tensor& div_out_mode( out.numel()); } } else { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + common_type != ScalarType::Bool), + InvalidArgument, + out); + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { torch::executor::native::utils:: apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( @@ -456,15 +455,7 @@ Tensor& div_scalar_out( const Tensor& a, const Scalar& b, Tensor& out) { - // Common Dtype - ScalarType common_type = - torch::executor::native::utils::promote_type_with_scalar( - a.scalar_type(), b); - #ifdef OP_ARG_CHECK - // Check Common Dtype - ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -480,14 +471,22 @@ Tensor& div_scalar_out( out); #endif - // Compute Dtype - ScalarType compute_type = - torch::executor::native::utils::get_compute_type(common_type); + bool optimized = true; + + if (!(((a.scalar_type() == ScalarType::Int) || + (a.scalar_type() == ScalarType::Float)) && + (out.scalar_type() == ScalarType::Float))) { + optimized = false; + } + + if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) { + optimized = false; + } // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "div.Scalar_out"; - if (compute_type == ScalarType::Int) { + if ((a.scalar_type() == ScalarType::Int) && (optimized)) { const int* const inp1_data = a.const_data_ptr<int>(); int inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); @@ -502,7 +501,7 @@ Tensor& 
div_scalar_out( inp1_data, inp2_val, out.numel()); - } else if (compute_type == ScalarType::Float) { + } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr<float>(); float inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); @@ -526,6 +525,11 @@ Tensor& div_scalar_out( : ScalarType::Float; ScalarType compute_type = torch::executor::native::utils::get_compute_type(common_type); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, common_type == out.scalar_type(), InvalidArgument, out); + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b); @@ -560,29 +564,7 @@ Tensor& div_scalar_mode_out( ET_KERNEL_CHECK( ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out); - // Common Dtype - ScalarType common_type = - torch::executor::native::utils::promote_type_with_scalar( - a.scalar_type(), b); - #ifdef OP_ARG_CHECK - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, - (canCast(common_type, out.scalar_type()) && - common_type != ScalarType::Bool), - InvalidArgument, - out); - - // Check for intergral division by zero - ET_KERNEL_CHECK_MSG( - ctx, - !(executorch::runtime::isIntegralType(common_type, true) && - torch::executor::native::utils::scalar_to<double>(b) == 0), - InvalidArgument, - out, - "Div mode operation encountered integer division by zero"); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -598,18 +580,26 @@ Tensor& div_scalar_mode_out( out); #endif - // Compute Dtype - ScalarType compute_type = - torch::executor::native::utils::get_compute_type(common_type); - const bool mode_is_trunc = mode_val == "trunc"; + bool optimized = true; + + if (!(((a.scalar_type() == ScalarType::Int) || + (a.scalar_type() == ScalarType::Float)) && + (a.scalar_type() == out.scalar_type()))) { + optimized = false; + } + + if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) { + optimized = false; + } + // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "div.Scalar_mode_out"; int mode_value = (mode_val == "trunc") ? 
1 : 2; - if (compute_type == ScalarType::Int) { + if ((a.scalar_type() == ScalarType::Int) && (optimized)) { const int* const inp1_data = a.const_data_ptr<int>(); int inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); @@ -625,7 +615,7 @@ Tensor& div_scalar_mode_out( inp2_val, mode_value, out.numel()); - } else if (compute_type == ScalarType::Float) { + } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr<float>(); float inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); @@ -642,6 +632,31 @@ Tensor& div_scalar_mode_out( mode_value, out.numel()); } else { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + common_type != ScalarType::Bool), + InvalidArgument, + out); + + // Check for intergral division by zero + ET_KERNEL_CHECK_MSG( + ctx, + !(executorch::runtime::isIntegralType(common_type, true) && + torch::executor::native::utils::scalar_to<double>(b) == 0), + InvalidArgument, + out, + "Div mode operation encountered integer division by zero"); + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b); diff --git a/backends/cadence/fusion_g3/operators/op_exp.cpp b/backends/cadence/fusion_g3/operators/op_exp.cpp index 4b6b898b17..41b5d70b22 100644 --- a/backends/cadence/fusion_g3/operators/op_exp.cpp +++ b/backends/cadence/fusion_g3/operators/op_exp.cpp @@ -49,9 +49,10 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { out); #endif - if (in.scalar_type() == ScalarType::Float) { - float* __restrict__ out_data = out.mutable_data_ptr<float>(); - const float* __restrict__ in_data = in.const_data_ptr<float>(); + if ((in.scalar_type() == ScalarType::Float) && + (out.scalar_type() == ScalarType::Float)) { + float* const out_data = out.mutable_data_ptr<float>(); + const float* const in_data = in.const_data_ptr<float>(); XT_KERNEL_CHECK( ctx, out, xa_nn_elm_exp_f32_f32, out_data, in_data, out.numel()); diff --git a/backends/cadence/fusion_g3/operators/op_mean.cpp b/backends/cadence/fusion_g3/operators/op_mean.cpp index 289baceb12..cd02714113 100644 --- a/backends/cadence/fusion_g3/operators/op_mean.cpp +++ b/backends/cadence/fusion_g3/operators/op_mean.cpp @@ -44,22 +44,23 @@ int prepare_data( for (int i = 0; i < num_out_dims; i++) { out_shape[i] = out.size(i); } - int num_axis_dims = 0; - for (const auto& d : dim_list.value()) { - if (d < 0) { - p_axis[num_axis_dims] = num_inp_dims + d; - num_axis_dims++; - } else { - p_axis[num_axis_dims] = d; - num_axis_dims++; + if (dim_list.has_value()) { + for (const auto& d : dim_list.value()) { + if (d < 0) { + p_axis[num_axis_dims] = num_inp_dims + d; + num_axis_dims++; + } else { + p_axis[num_axis_dims] = d; + num_axis_dims++; + } } } return num_axis_dims; } -Tensor& mean_out( +Tensor& mean_dim_out( KernelRuntimeContext& ctx, const Tensor& in, optional<ArrayRef<int64_t>> dim_list, @@ -69,12 +70,6 @@ Tensor& mean_out( (void)ctx; #ifdef OP_ARG_CHECK - ET_KERNEL_CHECK( - ctx, - torch::executor::check_mean_dim_args(in, dim_list, keepdim, dtype, out), - InvalidArgument, - out); - ET_KERNEL_CHECK( ctx, 
executorch::runtime::tensors_have_same_dim_order(in, out), @@ -97,13 +92,14 @@ Tensor& mean_out( constexpr int kNnlibMaxDim = 5; - bool optimized = 1; + bool optimized = true; - if (out.scalar_type() != ScalarType::Float) - optimized = 0; + if (!((out.scalar_type() == ScalarType::Float) && + (in.scalar_type() == ScalarType::Float))) + optimized = false; if (in.dim() > kNnlibMaxDim) - optimized = 0; + optimized = false; if (optimized) { float* __restrict__ p_out = out.mutable_data_ptr<float>(); @@ -135,9 +131,8 @@ Tensor& mean_out( num_inp_dims, num_out_dims); - if (num_axis_dims == num_inp_dims) { + if ((num_axis_dims == num_inp_dims) || (!dim_list.has_value())) { num_out_dims = 1; - out_shape[0] = 1; } int inp_shape_max = inp_shape[p_axis[0]]; @@ -168,6 +163,12 @@ Tensor& mean_out( num_axis_dims, p_scratch_in); } else { + ET_KERNEL_CHECK( + ctx, + torch::executor::check_mean_dim_args(in, dim_list, keepdim, dtype, out), + InvalidArgument, + out); + ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] { ET_SWITCH_FLOATH_TYPES( out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] { diff --git a/backends/cadence/fusion_g3/operators/op_mul.cpp b/backends/cadence/fusion_g3/operators/op_mul.cpp index 93b4c5a992..bee6ac9cbd 100644 --- a/backends/cadence/fusion_g3/operators/op_mul.cpp +++ b/backends/cadence/fusion_g3/operators/op_mul.cpp @@ -33,15 +33,7 @@ Tensor& mul_out( const Tensor& a, const Tensor& b, Tensor& out) { - // Common Dtype - ScalarType common_type = - executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); - #ifdef OP_ARG_CHECK - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -57,10 +49,6 @@ Tensor& mul_out( out); #endif - // Compute Dtype - ScalarType compute_type = - torch::executor::native::utils::get_compute_type(common_type); - // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "mul.out"; int kTensorDimensionLimit = 5; @@ -69,12 +57,12 @@ Tensor& mul_out( int inp2_shape[kTensorDimensionLimit]; int out_shape[kTensorDimensionLimit]; - bool broadcast = 0; + bool broadcast = false; int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); max_dim = out.dim() > max_dim ? 
out.dim() : max_dim; - bool optimized = 1; + bool optimized = true; /* Added change to work with input dimensions more than 5 */ for (int i = 0; i < max_dim; i++) { @@ -101,15 +89,19 @@ Tensor& mul_out( for (int i = 0; i < out.dim(); i++) { if (((inp1_shape[i]) != (out_shape[i])) || ((inp2_shape[i]) != (out_shape[i]))) { - broadcast = 1; + broadcast = true; } } - if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) { - optimized = 0; + if (((broadcast) && (max_dim > kTensorDimensionLimit)) || + (!(((a.scalar_type() == ScalarType::Int) || + (a.scalar_type() == ScalarType::Float)) && + (a.scalar_type() == b.scalar_type()) && + (a.scalar_type() == out.scalar_type())))) { + optimized = false; } - if ((compute_type == ScalarType::Int) && (optimized)) { + if ((a.scalar_type() == ScalarType::Int) && (optimized)) { const int* const inp1_data = a.const_data_ptr<int>(); const int* const inp2_data = b.const_data_ptr<int>(); int* const out_data = out.mutable_data_ptr<int>(); @@ -154,7 +146,7 @@ Tensor& mul_out( inp2_data, out.numel()); } - } else if ((compute_type == ScalarType::Float) && (optimized)) { + } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr<float>(); const float* const inp2_data = b.const_data_ptr<float>(); float* const out_data = out.mutable_data_ptr<float>(); @@ -200,6 +192,16 @@ Tensor& mul_out( out.numel()); } } else { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { torch::executor::native::utils::apply_bitensor_elementwise_fn< CTYPE_COMPUTE, @@ -224,15 +226,7 @@ Tensor& mul_scalar_out( const Tensor& a, const Scalar& b, Tensor& out) { - // Common Dtype - ScalarType common_type = - torch::executor::native::utils::promote_type_with_scalar( - a.scalar_type(), b); - #ifdef OP_ARG_CHECK - // Check Common Dtype - ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -244,13 +238,23 @@ Tensor& mul_scalar_out( ET_KERNEL_CHECK( ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); #endif - // Compute Dtype - ScalarType compute_type = - torch::executor::native::utils::get_compute_type(common_type); // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "mul.Scalar_out"; - if (compute_type == ScalarType::Int) { + + bool optimized = true; + + if (!(((a.scalar_type() == ScalarType::Int) || + (a.scalar_type() == ScalarType::Float)) && + (a.scalar_type() == out.scalar_type()))) { + optimized = false; + } + + if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) { + optimized = false; + } + + if ((a.scalar_type() == ScalarType::Int) && (optimized)) { const int* const inp1_data = a.const_data_ptr<int>(); int inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); @@ -264,7 +268,7 @@ Tensor& mul_scalar_out( inp1_data, inp2_val, out.numel()); - } else if (compute_type == ScalarType::Float) { + } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr<float>(); float inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); @@ -279,6 +283,17 @@ Tensor& 
mul_scalar_out( inp2_val, out.numel()); } else { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, common_type == out.scalar_type(), InvalidArgument, out); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b); diff --git a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp index 9857bbce37..b4f076e810 100644 --- a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp +++ b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp @@ -123,14 +123,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out( std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out); int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; - #ifdef OP_ARG_CHECK - ET_KERNEL_CHECK( - ctx, - torch::executor::check_layer_norm_args( - input, normalized_shape, weight, bias, out, mean_out, rstd_out), - InvalidArgument, - ret_val); // Only support default dim order for now. // TODO: Support other dim orders. @@ -189,12 +182,34 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out( ret_val); #endif + bool optimized = true; + int input_shape[kTensorDimensionLimit]; for (int i = 0; i < input.dim(); i++) { input_shape[i] = input.size(i); } - if (out.scalar_type() == ScalarType::Float) { + if (!(((input.scalar_type() == ScalarType::Float) && + (input.scalar_type() == out.scalar_type()) && + (out.scalar_type() == mean_out.scalar_type()) && + (mean_out.scalar_type() == rstd_out.scalar_type())))) { + optimized = false; + } + + if (optimized) { + if (weight.has_value()) { + if (!(input.scalar_type() == weight.value().scalar_type())) { + optimized = false; + } + } + if (bias.has_value()) { + if (!(input.scalar_type() == bias.value().scalar_type())) { + optimized = false; + } + } + } + + if ((input.scalar_type() == ScalarType::Float) && (optimized)) { float* const out_data = out.mutable_data_ptr<float>(); float* const mean_data = mean_out.mutable_data_ptr<float>(); float* const rstd_data = rstd_out.mutable_data_ptr<float>(); @@ -247,6 +262,13 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out( free(weight_data); } } else { + ET_KERNEL_CHECK( + ctx, + torch::executor::check_layer_norm_args( + input, normalized_shape, weight, bias, out, mean_out, rstd_out), + InvalidArgument, + ret_val); + ET_SWITCH_FLOAT_TYPES( input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() { layer_norm<CTYPE>( diff --git a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp index 23c2d1e5fb..34def4fd1b 100644 --- a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp +++ b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp @@ -65,12 +65,6 @@ Tensor& permute_copy_out( * the checks only in operator level(As there are no checks in kernel). 
*/ #ifdef OP_ARG_CHECK - ET_KERNEL_CHECK( - ctx, - torch::executor::check_permute_copy_args(in, dims, out), - InvalidArgument, - out); - ET_KERNEL_CHECK( ctx, executorch::runtime::tensors_have_same_dim_order(in, out), @@ -112,7 +106,8 @@ Tensor& permute_copy_out( signed char* out_data = out.mutable_data_ptr<signed char>(); const signed char* const inp_data = in.const_data_ptr<signed char>(); - if (((out.scalar_type() == ScalarType::Int) || + if (((out.scalar_type() == in.scalar_type()) && + (out.scalar_type() == ScalarType::Int) || (out.scalar_type() == ScalarType::Short) || (out.scalar_type() == ScalarType::Char) || (out.scalar_type() == ScalarType::UInt32) || @@ -131,9 +126,15 @@ Tensor& permute_copy_out( in.dim(), get_element_size(out.scalar_type())); } else { + ET_KERNEL_CHECK( + ctx, + torch::executor::check_permute_copy_args(in, dims, out), + InvalidArgument, + out); + const auto in_type = out.scalar_type(); - size_t in_coord[5] = {0}; - size_t trailing_dims_memo[kTensorDimensionLimit]; + size_t in_coord[executorch::runtime::kTensorDimensionLimit] = {0}; + size_t trailing_dims_memo[executorch::runtime::kTensorDimensionLimit]; executorch::runtime::memoizeTrailingDims(in, trailing_dims_memo); // in and out must be the same dtype ET_SWITCH_ALL_TYPES(in_type, ctx, "permute_copy.out", CTYPE, [&] { diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp index 8237c3c266..2af77eca6c 100644 --- a/backends/cadence/fusion_g3/operators/op_quantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp @@ -159,6 +159,12 @@ Tensor& quantize_impl( bool is_asym_quant = 0; + bool optimized = true; + + if (input.scalar_type() != ScalarType::Float) { + optimized = false; + } + if (zero_point_data != NULL) // asymmetric quant { if (axis != NULL) // channel @@ -177,7 +183,7 @@ Tensor& quantize_impl( } if (is_asym_quant) { - if (out.scalar_type() == ScalarType::Byte) { + if ((out.scalar_type() == ScalarType::Byte) && (optimized)) { uint8_t* out_data = out.mutable_data_ptr<uint8_t>(); XT_KERNEL_CHECK( ctx, @@ -192,7 +198,7 @@ Tensor& quantize_impl( zero_point_data, quant_min, quant_max); - } else if (out.scalar_type() == ScalarType::Char) { + } else if ((out.scalar_type() == ScalarType::Char) && (optimized)) { int8_t* out_data = out.mutable_data_ptr<int8_t>(); XT_KERNEL_CHECK( @@ -208,7 +214,7 @@ Tensor& quantize_impl( zero_point_data, quant_min, quant_max); - } else if (out.scalar_type() == ScalarType::UInt16) { + } else if ((out.scalar_type() == ScalarType::UInt16) && (optimized)) { uint16_t* out_data = out.mutable_data_ptr<uint16_t>(); XT_KERNEL_CHECK( ctx, @@ -223,7 +229,7 @@ Tensor& quantize_impl( zero_point_data, quant_min, quant_max); - } else if (out.scalar_type() == ScalarType::Short) { + } else if ((out.scalar_type() == ScalarType::Short) && (optimized)) { int16_t* out_data = out.mutable_data_ptr<int16_t>(); XT_KERNEL_CHECK( ctx, @@ -238,7 +244,7 @@ Tensor& quantize_impl( zero_point_data, quant_min, quant_max); - } else if (out.scalar_type() == (ScalarType)Bits4u) { + } else if ((out.scalar_type() == (ScalarType)Bits4u) && (optimized)) { uint8_t* out_data = out.mutable_data_ptr<uint8_t>(); XT_KERNEL_CHECK( ctx, @@ -253,7 +259,7 @@ Tensor& quantize_impl( zero_point_data, quant_min, quant_max); - } else if (out.scalar_type() == (ScalarType)Bits4) { + } else if ((out.scalar_type() == (ScalarType)Bits4) && (optimized)) { int8_t* out_data = out.mutable_data_ptr<int8_t>(); XT_KERNEL_CHECK( ctx, @@ -391,7 +397,7 @@ Tensor& 
quantize_impl( #undef ASYM_QUANTIZE_IMPL_CHANNEL } } else { - if (out.scalar_type() == ScalarType::Byte) { + if ((out.scalar_type() == ScalarType::Byte) && (optimized)) { uint8_t* out_data = out.mutable_data_ptr<uint8_t>(); XT_KERNEL_CHECK( ctx, @@ -405,7 +411,7 @@ Tensor& quantize_impl( scale_data, quant_min, quant_max); - } else if (out.scalar_type() == ScalarType::Char) { + } else if ((out.scalar_type() == ScalarType::Char) && (optimized)) { int8_t* out_data = out.mutable_data_ptr<int8_t>(); XT_KERNEL_CHECK( ctx, @@ -419,7 +425,7 @@ Tensor& quantize_impl( scale_data, quant_min, quant_max); - } else if (out.scalar_type() == ScalarType::UInt16) { + } else if ((out.scalar_type() == ScalarType::UInt16) && (optimized)) { uint16_t* out_data = out.mutable_data_ptr<uint16_t>(); XT_KERNEL_CHECK( ctx, @@ -433,7 +439,7 @@ Tensor& quantize_impl( scale_data, quant_min, quant_max); - } else if (out.scalar_type() == ScalarType::Short) { + } else if ((out.scalar_type() == ScalarType::Short) && (optimized)) { int16_t* out_data = out.mutable_data_ptr<int16_t>(); XT_KERNEL_CHECK( ctx, @@ -447,7 +453,7 @@ Tensor& quantize_impl( scale_data, quant_min, quant_max); - } else if (out.scalar_type() == (ScalarType)Bits4u) { + } else if ((out.scalar_type() == (ScalarType)Bits4u) && (optimized)) { uint8_t* out_data = out.mutable_data_ptr<uint8_t>(); XT_KERNEL_CHECK( ctx, @@ -461,7 +467,7 @@ Tensor& quantize_impl( scale_data, quant_min, quant_max); - } else if (out.scalar_type() == (ScalarType)Bits4) { + } else if ((out.scalar_type() == (ScalarType)Bits4) && (optimized)) { int8_t* out_data = out.mutable_data_ptr<int8_t>(); XT_KERNEL_CHECK( ctx, diff --git a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp index c481cf726b..9158eecf13 100644 --- a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp +++ b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp @@ -58,12 +58,6 @@ Tensor& slice_copy_Tensor_out( int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; #ifdef OP_ARG_CHECK - ET_KERNEL_CHECK( - ctx, - torch::executor::check_slice_copy_args(in, dim, step, out), - InvalidArgument, - out); - ET_KERNEL_CHECK( ctx, executorch::runtime::tensors_have_same_dim_order(in, out), @@ -101,12 +95,13 @@ Tensor& slice_copy_Tensor_out( signed char* out_data = out.mutable_data_ptr<signed char>(); const signed char* const inp_data = in.const_data_ptr<signed char>(); - if ((out.scalar_type() == ScalarType::Int) || - (out.scalar_type() == ScalarType::Short) || - (out.scalar_type() == ScalarType::Char) || - (out.scalar_type() == ScalarType::UInt32) || - (out.scalar_type() == ScalarType::UInt16) || - (out.scalar_type() == ScalarType::Byte)) { + if ((out.scalar_type() == in.scalar_type()) && + ((out.scalar_type() == ScalarType::Int) || + (out.scalar_type() == ScalarType::Short) || + (out.scalar_type() == ScalarType::Char) || + (out.scalar_type() == ScalarType::UInt32) || + (out.scalar_type() == ScalarType::UInt16) || + (out.scalar_type() == ScalarType::Byte))) { XT_KERNEL_CHECK( ctx, out, @@ -122,6 +117,12 @@ Tensor& slice_copy_Tensor_out( (int)dim, get_element_size(out.scalar_type())); } else { + ET_KERNEL_CHECK( + ctx, + torch::executor::check_slice_copy_args(in, dim, step, out), + InvalidArgument, + out); + torch::executor::compute_slice(in, dim, start, length, step, out); } diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp index ee87ebaf5a..14b128e928 100644 --- 
a/backends/cadence/fusion_g3/operators/op_softmax.cpp +++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp @@ -39,14 +39,7 @@ Tensor& _softmax_out( // Adjust for negative dim dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; - #ifdef OP_ARG_CHECK - ET_KERNEL_CHECK( - ctx, - torch::executor::check_softmax_args(in, dim, half_to_float, out), - InvalidArgument, - out); - ET_KERNEL_CHECK( ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); @@ -63,7 +56,8 @@ Tensor& _softmax_out( inp_shapes[i] = in_size[i]; } - if (out.scalar_type() == ScalarType::Float) { + if ((in.scalar_type() == ScalarType::Float) && + (out.scalar_type() == ScalarType::Float)) { const float* const inp_data = in.const_data_ptr<float>(); float* const out_data = out.mutable_data_ptr<float>(); int axis = dim; @@ -77,6 +71,12 @@ Tensor& _softmax_out( in.dim(), &axis); } else { + ET_KERNEL_CHECK( + ctx, + torch::executor::check_softmax_args(in, dim, half_to_float, out), + InvalidArgument, + out); + ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { const CTYPE* const in_data = in.const_data_ptr<CTYPE>(); CTYPE* const out_data = out.mutable_data_ptr<CTYPE>(); diff --git a/backends/cadence/fusion_g3/operators/op_sub.cpp b/backends/cadence/fusion_g3/operators/op_sub.cpp index 4bae81c5b2..9bafec5df9 100644 --- a/backends/cadence/fusion_g3/operators/op_sub.cpp +++ b/backends/cadence/fusion_g3/operators/op_sub.cpp @@ -35,19 +35,6 @@ Tensor& sub_out( const Scalar& alpha, Tensor& out) { #ifdef OP_ARG_CHECK - ScalarType alpha_type = - torch::executor::native::utils::get_scalar_dtype(alpha); - // Check alpha type - ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); - - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, - (canCast(common_type, out.scalar_type()) && - canCast(alpha_type, common_type)), - InvalidArgument, - out); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -72,12 +59,12 @@ Tensor& sub_out( int inp2_shape[kTensorDimensionLimit]; int out_shape[kTensorDimensionLimit]; - bool broadcast = 0; + bool broadcast = false; int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); max_dim = out.dim() > max_dim ? 
out.dim() : max_dim; - bool optimized = 1; + bool optimized = true; for (int i = 0; i < max_dim; i++) { out_shape[i] = 1; @@ -103,16 +90,16 @@ Tensor& sub_out( for (int i = 0; i < out.dim(); i++) { if (((inp1_shape[i]) != (out_shape[i])) || ((inp2_shape[i]) != (out_shape[i]))) { - broadcast = 1; + broadcast = true; } } - if (((broadcast == 1) && (max_dim > kTensorDimensionLimit)) || + if (((broadcast) && (max_dim > kTensorDimensionLimit)) || (!(((a.scalar_type() == ScalarType::Int) || (a.scalar_type() == ScalarType::Float)) && (a.scalar_type() == b.scalar_type()) && (a.scalar_type() == out.scalar_type())))) { - optimized = 0; + optimized = false; } if ((a.scalar_type() == ScalarType::Int) && (optimized)) { @@ -207,6 +194,19 @@ Tensor& sub_out( ScalarType compute_type = torch::executor::native::utils::get_compute_type(common_type); + ScalarType alpha_type = + torch::executor::native::utils::get_scalar_dtype(alpha); + // Check alpha type + ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + canCast(alpha_type, common_type)), + InvalidArgument, + out); + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_alpha = torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha); @@ -236,18 +236,6 @@ Tensor& sub_scalar_out( const Scalar& alpha, Tensor& out) { #ifdef OP_ARG_CHECK - ScalarType alpha_type = - torch::executor::native::utils::get_scalar_dtype(alpha); - // Check alpha type - ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); - - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, - (common_type == out.scalar_type() && canCast(alpha_type, common_type)), - InvalidArgument, - out); - // Check Dim Order ET_KERNEL_CHECK( ctx, @@ -266,14 +254,16 @@ Tensor& sub_scalar_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "sub.Scalar_out"; - bool optimized = 1; - ScalarType b_type = torch::executor::native::utils::get_scalar_dtype(b); + bool optimized = true; if (!(((a.scalar_type() == ScalarType::Int) || (a.scalar_type() == ScalarType::Float)) && - (a.scalar_type() == b_type) && (a.scalar_type() == out.scalar_type()))) { - optimized = 0; + optimized = false; + } + + if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) { + optimized = false; } if ((a.scalar_type() == ScalarType::Int) && (optimized)) { @@ -322,6 +312,19 @@ Tensor& sub_scalar_out( // Compute Dtype ScalarType compute_type = torch::executor::native::utils::get_compute_type(common_type); + + ScalarType alpha_type = + torch::executor::native::utils::get_scalar_dtype(alpha); + // Check alpha type + ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (common_type == out.scalar_type() && canCast(alpha_type, common_type)), + InvalidArgument, + out); + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
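
All of the hunks above apply the same dispatch pattern: decide up front whether every tensor involved has a dtype the NNLib fast path supports, call the xa_nn_* kernel when it does, and otherwise fall through to the portable ET_SWITCH_* implementation, which is also where the common-dtype / canCast / alpha argument checks now run. Below is a minimal standalone C++ sketch of that dispatch, not the ExecuTorch API; SimpleTensor, SketchDType, fast_add_int/fast_add_float and reference_add are hypothetical stand-ins for the real Tensor type, the vendor kernels and the portable fallback, and broadcasting is left out to keep it short.

#include <cstddef>
#include <cstdint>
#include <vector>

enum class SketchDType { Int, Float, Other };

struct SimpleTensor {
  SketchDType dtype = SketchDType::Other;
  std::vector<float> fdata;    // payload when dtype == Float
  std::vector<int32_t> idata;  // payload when dtype == Int
  size_t numel() const {
    return dtype == SketchDType::Float ? fdata.size() : idata.size();
  }
};

// Stand-ins for the vendor-optimized kernels (xa_nn_elm_add_* in the patch).
static void fast_add_int(const int32_t* a, const int32_t* b, int32_t* out, size_t n) {
  for (size_t i = 0; i < n; ++i) out[i] = a[i] + b[i];
}
static void fast_add_float(const float* a, const float* b, float* out, size_t n) {
  for (size_t i = 0; i < n; ++i) out[i] = a[i] + b[i];
}

// Stand-in for the portable fallback; in the real operators this branch is
// where the promoteTypes/canCast checks and the ET_SWITCH_* reference loop
// now live. The sketch simply reports "unsupported" instead of promoting.
static bool reference_add(const SimpleTensor& a, const SimpleTensor& b, SimpleTensor& out) {
  (void)a; (void)b; (void)out;
  return false;
}

bool add_out_sketch(const SimpleTensor& a, const SimpleTensor& b, SimpleTensor& out) {
  // Take the fast path only when every tensor shares one of the dtypes the
  // NNLib kernels handle (Int or Float here); mixed or unsupported dtypes,
  // like an oversized broadcast in the real operators, clear the flag.
  const bool optimized =
      (a.dtype == SketchDType::Int || a.dtype == SketchDType::Float) &&
      a.dtype == b.dtype && a.dtype == out.dtype;

  if (optimized && a.dtype == SketchDType::Int) {
    fast_add_int(a.idata.data(), b.idata.data(), out.idata.data(), out.numel());
    return true;
  } else if (optimized && a.dtype == SketchDType::Float) {
    fast_add_float(a.fdata.data(), b.fdata.data(), out.fdata.data(), out.numel());
    return true;
  }
  // Everything else goes to the reference path, which also performs the
  // argument checks that used to run unconditionally.
  return reference_add(a, b, out);
}

The design point the sketch illustrates: the dtype promotion and canCast/alpha checks are only paid on the fallback branch, so the Int/Float fast path never touches them; keeping the checks ahead of the dispatch would retain the earlier behavior at the cost of running them on every call.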