From 25395b9f597796289f59d7be847ac1d2613abbc8 Mon Sep 17 00:00:00 2001 From: pncosta22 <145407099+pncosta22@users.noreply.github.com> Date: Fri, 24 Jan 2025 19:45:02 -0800 Subject: [PATCH] Back out "Support Half/BFloat16 in max_pool2d (#7829)" Differential Revision: D68647398 Pull Request resolved: https://github.com/pytorch/executorch/pull/7955 --- .../cpu/op_max_pool2d_with_indices.cpp | 2 +- .../test/op_max_pool2d_with_indices_test.cpp | 136 +++++++++--------- 2 files changed, 65 insertions(+), 73 deletions(-) diff --git a/kernels/portable/cpu/op_max_pool2d_with_indices.cpp b/kernels/portable/cpu/op_max_pool2d_with_indices.cpp index 0d6d8406ca..80c291305b 100644 --- a/kernels/portable/cpu/op_max_pool2d_with_indices.cpp +++ b/kernels/portable/cpu/op_max_pool2d_with_indices.cpp @@ -70,7 +70,7 @@ std::tuple max_pool2d_with_indices_out( ret_val); ScalarType in_type = in.scalar_type(); - ET_SWITCH_REALHBF16_TYPES( + ET_SWITCH_REAL_TYPES( in_type, ctx, "max_pool2d_with_indices.out", CTYPE, [&]() { apply_kernel_2d_reduce_then_map_fn( [](const CTYPE in_val, diff --git a/kernels/test/op_max_pool2d_with_indices_test.cpp b/kernels/test/op_max_pool2d_with_indices_test.cpp index f92f927b6a..46f232521e 100644 --- a/kernels/test/op_max_pool2d_with_indices_test.cpp +++ b/kernels/test/op_max_pool2d_with_indices_test.cpp @@ -40,81 +40,73 @@ class OpMaxPool2DWithIndicesOutTest : public OperatorTest { out, indices); } - - template - void test_4d_dtype() { - torch::executor::testing::TensorFactory tf; - torch::executor::testing::TensorFactory tfLong; - - exec_aten::Tensor self = tf.make( - {2, 3, 5, 5}, - {28.75, -38.875, -7.0, -13.5, 70.75, 53.75, 69.625, 97.375, - 25.375, 99.5, -72.125, -87.25, 79.25, 42.0, -24.75, -15.5, - 12.5, -86.0, 85.5, -0.25, 67.125, 77.0, 53.375, -61.125, - 50.0, 3.875, 42.25, -37.375, 51.0, -60.875, 87.0, 32.25, - 73.5, 68.875, -84.375, -98.75, -30.125, 94.25, 1.625, -86.25, - -56.5, -68.0, 74.25, -51.25, 8.125, 71.375, -53.125, 4.875, - 77.5, -89.875, 4.5, -46.5, -46.375, -92.625, -85.5, -23.0, - -8.875, -12.0, -46.625, -88.625, 66.75, 87.75, 90.25, -45.0, - -78.125, 63.25, 28.75, 28.125, -30.375, 17.75, -16.0, 5.0, - 11.125, 88.625, -47.625, 72.25, 32.0, -7.625, 61.625, -63.125, - -22.75, 83.125, -40.375, -78.25, 49.5, -39.125, -89.625, 47.875, - -61.375, 7.75, 16.875, -96.375, -22.5, 8.5, 74.25, 12.75, - 90.125, 73.875, -71.75, -10.0, 41.25, 1.125, 10.375, -34.625, - 29.75, -27.5, 26.625, 81.0, -8.875, 17.625, 84.375, -23.625, - -53.875, -26.0, -67.375, -90.75, 16.375, 45.625, 99.5, 56.25, - -87.625, -65.5, -79.75, 31.875, 79.75, 6.375, 44.625, -55.25, - -5.5, -68.875, -38.625, 54.125, -3.125, 5.75, 29.25, -39.5, - 26.75, 68.25, -24.625, -53.0, 51.0, 90.625, 65.375, 43.875, - 90.875, -41.625, 99.875, 6.375, -31.25, -94.0}); - ::std::vector kernel_size_vec = {2, 2}; - exec_aten::ArrayRef kernel_size = exec_aten::ArrayRef( - kernel_size_vec.data(), kernel_size_vec.size()); - ::std::vector stride_vec = {1, 1}; - exec_aten::ArrayRef stride = - exec_aten::ArrayRef(stride_vec.data(), stride_vec.size()); - ::std::vector padding_vec = {0, 0}; - exec_aten::ArrayRef padding = - exec_aten::ArrayRef(padding_vec.data(), padding_vec.size()); - ::std::vector dilation_vec = {1, 1}; - exec_aten::ArrayRef dilation = - exec_aten::ArrayRef(dilation_vec.data(), dilation_vec.size()); - bool ceil_mode = false; - exec_aten::Tensor out = tf.zeros({2, 3, 4, 4}); - exec_aten::Tensor indices = tfLong.zeros({2, 3, 4, 4}); - exec_aten::Tensor out_expected = tf.make( - {2, 3, 4, 4}, - {69.625, 97.375, 97.375, 99.5, 69.625, 97.375, 97.375, 99.5, - 12.5, 79.25, 85.5, 85.5, 77.0, 77.0, 85.5, 85.5, - 87.0, 73.5, 73.5, 68.875, 87.0, 94.25, 94.25, 68.875, - -30.125, 94.25, 94.25, 8.125, 71.375, 74.25, 77.5, 77.5, - 4.5, -8.875, -12.0, -46.625, 87.75, 90.25, 90.25, -45.0, - 87.75, 90.25, 90.25, 17.75, 63.25, 28.75, 88.625, 88.625, - 83.125, 83.125, 61.625, 61.625, 83.125, 83.125, 47.875, 49.5, - 16.875, 47.875, 47.875, 74.25, 90.125, 90.125, 73.875, 74.25, - 41.25, 81.0, 81.0, 29.75, 84.375, 81.0, 81.0, 17.625, - 84.375, 45.625, 99.5, 99.5, 16.375, 45.625, 99.5, 99.5, - 54.125, 54.125, 5.75, 29.25, 54.125, 68.25, 68.25, 29.25, - 90.625, 90.625, 68.25, 90.875, 99.875, 99.875, 65.375, 90.875}); - exec_aten::Tensor indices_expected = tfLong.make( - {2, 3, 4, 4}, - {6, 7, 7, 9, 6, 7, 7, 9, 16, 12, 18, 18, 21, 21, 18, 18, - 5, 7, 7, 8, 5, 12, 12, 8, 11, 12, 12, 19, 20, 17, 23, 23, - 0, 6, 7, 8, 11, 12, 12, 13, 11, 12, 12, 19, 15, 16, 23, 23, - 6, 6, 3, 3, 6, 6, 12, 9, 15, 12, 12, 19, 21, 21, 22, 19, - 0, 7, 7, 4, 10, 7, 7, 9, 10, 17, 18, 18, 16, 17, 18, 18, - 6, 6, 8, 9, 6, 12, 12, 9, 16, 16, 12, 19, 21, 21, 17, 19}); - op_max_pool2d_with_indices_out( - self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); - EXPECT_TENSOR_CLOSE(out, out_expected); - EXPECT_TENSOR_CLOSE(indices, indices_expected); - } }; TEST_F(OpMaxPool2DWithIndicesOutTest, SanityTest4D) { -#define TEST_ENTRY(ctype, dtype) test_4d_dtype(); - ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); -#undef TEST_ENTRY + torch::executor::testing::TensorFactory tfFloat; + torch::executor::testing::TensorFactory tfLong; + + exec_aten::Tensor self = tfFloat.make( + {2, 3, 5, 5}, + {28.75, -38.875, -7.0, -13.5, 70.75, 53.75, 69.625, 97.375, + 25.375, 99.5, -72.125, -87.25, 79.25, 42.0, -24.75, -15.5, + 12.5, -86.0, 85.5, -0.25, 67.125, 77.0, 53.375, -61.125, + 50.0, 3.875, 42.25, -37.375, 51.0, -60.875, 87.0, 32.25, + 73.5, 68.875, -84.375, -98.75, -30.125, 94.25, 1.625, -86.25, + -56.5, -68.0, 74.25, -51.25, 8.125, 71.375, -53.125, 4.875, + 77.5, -89.875, 4.5, -46.5, -46.375, -92.625, -85.5, -23.0, + -8.875, -12.0, -46.625, -88.625, 66.75, 87.75, 90.25, -45.0, + -78.125, 63.25, 28.75, 28.125, -30.375, 17.75, -16.0, 5.0, + 11.125, 88.625, -47.625, 72.25, 32.0, -7.625, 61.625, -63.125, + -22.75, 83.125, -40.375, -78.25, 49.5, -39.125, -89.625, 47.875, + -61.375, 7.75, 16.875, -96.375, -22.5, 8.5, 74.25, 12.75, + 90.125, 73.875, -71.75, -10.0, 41.25, 1.125, 10.375, -34.625, + 29.75, -27.5, 26.625, 81.0, -8.875, 17.625, 84.375, -23.625, + -53.875, -26.0, -67.375, -90.75, 16.375, 45.625, 99.5, 56.25, + -87.625, -65.5, -79.75, 31.875, 79.75, 6.375, 44.625, -55.25, + -5.5, -68.875, -38.625, 54.125, -3.125, 5.75, 29.25, -39.5, + 26.75, 68.25, -24.625, -53.0, 51.0, 90.625, 65.375, 43.875, + 90.875, -41.625, 99.875, 6.375, -31.25, -94.0}); + ::std::vector kernel_size_vec = {2, 2}; + exec_aten::ArrayRef kernel_size = exec_aten::ArrayRef( + kernel_size_vec.data(), kernel_size_vec.size()); + ::std::vector stride_vec = {1, 1}; + exec_aten::ArrayRef stride = + exec_aten::ArrayRef(stride_vec.data(), stride_vec.size()); + ::std::vector padding_vec = {0, 0}; + exec_aten::ArrayRef padding = + exec_aten::ArrayRef(padding_vec.data(), padding_vec.size()); + ::std::vector dilation_vec = {1, 1}; + exec_aten::ArrayRef dilation = + exec_aten::ArrayRef(dilation_vec.data(), dilation_vec.size()); + bool ceil_mode = false; + exec_aten::Tensor out = tfFloat.zeros({2, 3, 4, 4}); + exec_aten::Tensor indices = tfLong.zeros({2, 3, 4, 4}); + exec_aten::Tensor out_expected = tfFloat.make( + {2, 3, 4, 4}, + {69.625, 97.375, 97.375, 99.5, 69.625, 97.375, 97.375, 99.5, 12.5, + 79.25, 85.5, 85.5, 77.0, 77.0, 85.5, 85.5, 87.0, 73.5, + 73.5, 68.875, 87.0, 94.25, 94.25, 68.875, -30.125, 94.25, 94.25, + 8.125, 71.375, 74.25, 77.5, 77.5, 4.5, -8.875, -12.0, -46.625, + 87.75, 90.25, 90.25, -45.0, 87.75, 90.25, 90.25, 17.75, 63.25, + 28.75, 88.625, 88.625, 83.125, 83.125, 61.625, 61.625, 83.125, 83.125, + 47.875, 49.5, 16.875, 47.875, 47.875, 74.25, 90.125, 90.125, 73.875, + 74.25, 41.25, 81.0, 81.0, 29.75, 84.375, 81.0, 81.0, 17.625, + 84.375, 45.625, 99.5, 99.5, 16.375, 45.625, 99.5, 99.5, 54.125, + 54.125, 5.75, 29.25, 54.125, 68.25, 68.25, 29.25, 90.625, 90.625, + 68.25, 90.875, 99.875, 99.875, 65.375, 90.875}); + exec_aten::Tensor indices_expected = tfLong.make( + {2, 3, 4, 4}, + {6, 7, 7, 9, 6, 7, 7, 9, 16, 12, 18, 18, 21, 21, 18, 18, + 5, 7, 7, 8, 5, 12, 12, 8, 11, 12, 12, 19, 20, 17, 23, 23, + 0, 6, 7, 8, 11, 12, 12, 13, 11, 12, 12, 19, 15, 16, 23, 23, + 6, 6, 3, 3, 6, 6, 12, 9, 15, 12, 12, 19, 21, 21, 22, 19, + 0, 7, 7, 4, 10, 7, 7, 9, 10, 17, 18, 18, 16, 17, 18, 18, + 6, 6, 8, 9, 6, 12, 12, 9, 16, 16, 12, 19, 21, 21, 17, 19}); + op_max_pool2d_with_indices_out( + self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + EXPECT_TENSOR_CLOSE(out, out_expected); + EXPECT_TENSOR_CLOSE(indices, indices_expected); } TEST_F(OpMaxPool2DWithIndicesOutTest, SanityTest4D_2) {