From b2128158c337bd79b789220d18d604a2e9e68025 Mon Sep 17 00:00:00 2001 From: Piotr Kowalczyk Date: Fri, 11 Oct 2024 10:02:23 +0200 Subject: [PATCH] [ref]: Added SearchSorted ref impl (#26958) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Details: - Added ref implementation of search sorted op ### Tickets: - *CVS-154061* Depends on: https://github.com/openvinotoolkit/openvino/pull/26904 --------- Signed-off-by: Kazantsev, Roman Signed-off-by: dependabot[bot] Co-authored-by: Michal Lukaszewski Co-authored-by: Pawel Raasz Co-authored-by: Andrey Babushkin Co-authored-by: Alicja Miloszewska Co-authored-by: Bogdan Pereanu Co-authored-by: Karol Blaszczak Co-authored-by: Tatiana Savina Co-authored-by: Anastasiya(Asya) Pronina Co-authored-by: Dmitry Matveev Co-authored-by: Andrei Beleiu Co-authored-by: Andrew Kwangwoong Park Co-authored-by: Roman Kazantsev Co-authored-by: Pavel Durandin Co-authored-by: Alexey Smirnov Co-authored-by: Hubert Błaszczyk <56601011+hub-bla@users.noreply.github.com> Co-authored-by: Vladimir Paramuzov Co-authored-by: Sergey Shlyapnikov Co-authored-by: Ivan Tikhonov Co-authored-by: Andrzej Kopytko Co-authored-by: Sebastian Golebiewski Co-authored-by: Alina Kladieva Co-authored-by: Ilya Lavrenov Co-authored-by: Maxim Vafin Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../include/openvino/opsets/opset15_tbl.hpp | 1 + .../openvino/reference/search_sorted.hpp | 55 ++++++++ src/core/tests/opset.cpp | 2 +- .../template/backend/ops/ops_evaluates.hpp | 4 + .../template/backend/ops/search_sorted.cpp | 50 +++++++ .../template/backend/opset_int_tbl.hpp | 1 + .../functional/op_reference/search_sorted.cpp | 123 ++++++++++++++++++ .../src/op_impl_check/single_op_graph.cpp | 9 ++ .../tests_data/search_sorted_data.h | 86 ++++++++++++ 9 files changed, 330 insertions(+), 1 deletion(-) create mode 100644 src/core/reference/include/openvino/reference/search_sorted.hpp create mode 100644 src/plugins/template/backend/ops/search_sorted.cpp create mode 100644 src/plugins/template/tests/functional/op_reference/search_sorted.cpp create mode 100644 src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h diff --git a/src/core/include/openvino/opsets/opset15_tbl.hpp b/src/core/include/openvino/opsets/opset15_tbl.hpp index 1b46724ea67c31..a18093c4ef3f5c 100644 --- a/src/core/include/openvino/opsets/opset15_tbl.hpp +++ b/src/core/include/openvino/opsets/opset15_tbl.hpp @@ -24,3 +24,4 @@ _OPENVINO_OP_REG(StringTensorPack, ov::op::v15) _OPENVINO_OP_REG(BitwiseLeftShift, ov::op::v15) _OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15) _OPENVINO_OP_REG(SliceScatter, ov::op::v15) +_OPENVINO_OP_REG(SearchSorted, ov::op::v15) diff --git a/src/core/reference/include/openvino/reference/search_sorted.hpp b/src/core/reference/include/openvino/reference/search_sorted.hpp new file mode 100644 index 00000000000000..ca5361c388c621 --- /dev/null +++ b/src/core/reference/include/openvino/reference/search_sorted.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/shape.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" +#include "openvino/reference/utils/coordinate_transform.hpp" + +namespace ov { +namespace reference { +template +void search_sorted(const T* sorted, + const T* values, + TOut* out, + const Shape& sorted_shape, + const Shape& values_shape, + bool right_mode) { + const CoordinateTransformBasic values_transform{values_shape}; + + std::function compare_func = nullptr; + if (right_mode) { + compare_func = [](const T* begin, const T* end, T value) { + return std::lower_bound(begin, end, value, std::less_equal()); + }; + } else { + compare_func = [](const T* begin, const T* end, T value) { + return std::lower_bound(begin, end, value, std::less()); + }; + } + + for (const Coordinate& values_coord : values_transform) { + const auto values_index = coordinate_index(values_coord, values_shape); + const T value = values[values_index]; + + Coordinate sorted_coord_begin = values_coord; + sorted_coord_begin.back() = 0; + + Coordinate sorted_coord_last = values_coord; + sorted_coord_last.back() = sorted_shape.back(); + + const auto sorted_index_begin = coordinate_index(sorted_coord_begin, sorted_shape); + const auto sorted_index_last = coordinate_index(sorted_coord_last, sorted_shape); + + const T* idx_ptr = compare_func(sorted + sorted_index_begin, sorted + sorted_index_last, value); + + const ptrdiff_t sorted_index = (idx_ptr - sorted) - sorted_index_begin; + + out[values_index] = static_cast(sorted_index); + } +} + +} // namespace reference +} // namespace ov \ No newline at end of file diff --git a/src/core/tests/opset.cpp b/src/core/tests/opset.cpp index dfdc785421b295..2df8bade6a2f2c 100644 --- a/src/core/tests/opset.cpp +++ b/src/core/tests/opset.cpp @@ -75,7 +75,7 @@ INSTANTIATE_TEST_SUITE_P(opset, OpsetTestParams{ov::get_opset12, 178}, OpsetTestParams{ov::get_opset13, 186}, OpsetTestParams{ov::get_opset14, 188}, - OpsetTestParams{ov::get_opset15, 14}), + OpsetTestParams{ov::get_opset15, 15}), OpsetTestNameGenerator{}); class MyOpOld : public ov::op::Op { diff --git a/src/plugins/template/backend/ops/ops_evaluates.hpp b/src/plugins/template/backend/ops/ops_evaluates.hpp index 81a7976a56f63d..54d2e0cb7c8c63 100644 --- a/src/plugins/template/backend/ops/ops_evaluates.hpp +++ b/src/plugins/template/backend/ops/ops_evaluates.hpp @@ -552,3 +552,7 @@ extern template bool evaluate_node(std::shared_ extern template bool evaluate_node(std::shared_ptr node, ov::TensorVector& outputs, const ov::TensorVector& inputs); + +extern template bool evaluate_node(std::shared_ptr node, + ov::TensorVector& outputs, + const ov::TensorVector& inputs); \ No newline at end of file diff --git a/src/plugins/template/backend/ops/search_sorted.cpp b/src/plugins/template/backend/ops/search_sorted.cpp new file mode 100644 index 00000000000000..a9e0169f83ce0b --- /dev/null +++ b/src/plugins/template/backend/ops/search_sorted.cpp @@ -0,0 +1,50 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/reference/search_sorted.hpp" + +#include "evaluate_node.hpp" + +template +bool evaluate(const std::shared_ptr& op, + ov::TensorVector& outputs, + const ov::TensorVector& inputs) { + using T = typename ov::element_type_traits::value_type; + ov::reference::search_sorted(inputs[0].data(), + inputs[1].data(), + outputs[0].data(), + op->get_input_shape(0), + op->get_input_shape(1), + op->get_right_mode()); + return true; +} + +template <> +bool evaluate_node(std::shared_ptr node, + ov::TensorVector& outputs, + const ov::TensorVector& inputs) { + const auto& element_type = node->get_input_element_type(0); + +#define CASE(type) \ + case ov::element::type: \ + return evaluate(ov::as_type_ptr(node), outputs, inputs); + + switch (element_type) { + CASE(bf16); + CASE(f16); + CASE(f32); + CASE(f64); + CASE(i8); + CASE(i16); + CASE(i32); + CASE(i64); + CASE(u8); + CASE(u16); + CASE(u32); + CASE(u64); + default: + OPENVINO_THROW("Unhandled data type ", element_type, " in evaluate_node()"); + } +#undef CASE +} diff --git a/src/plugins/template/backend/opset_int_tbl.hpp b/src/plugins/template/backend/opset_int_tbl.hpp index 6e83f974005c0c..5f4e5737fcb567 100644 --- a/src/plugins/template/backend/opset_int_tbl.hpp +++ b/src/plugins/template/backend/opset_int_tbl.hpp @@ -175,6 +175,7 @@ _OPENVINO_OP_REG(StringTensorPack, ov::op::v15) _OPENVINO_OP_REG(BitwiseLeftShift, ov::op::v15) _OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15) _OPENVINO_OP_REG(SliceScatter, ov::op::v15) +_OPENVINO_OP_REG(SearchSorted, ov::op::v15) _OPENVINO_OP_REG(AUGRUCell, ov::op::internal) _OPENVINO_OP_REG(AUGRUSequence, ov::op::internal) diff --git a/src/plugins/template/tests/functional/op_reference/search_sorted.cpp b/src/plugins/template/tests/functional/op_reference/search_sorted.cpp new file mode 100644 index 00000000000000..59868fa63ff8de --- /dev/null +++ b/src/plugins/template/tests/functional/op_reference/search_sorted.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/search_sorted.hpp" + +#include + +#include "base_reference_test.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/parameter.hpp" + +using namespace reference_tests; +using namespace ov; + +namespace { + +struct SearchSortedParams { + PartialShape sortedShape; + PartialShape valuesShape; + bool rightMode; + std::string testcaseName; + reference_tests::Tensor sorted; + reference_tests::Tensor values; + reference_tests::Tensor expectedOutput; +}; + +template +SearchSortedParams PrepareTestCaseParams(const PartialShape& sortedShape, + const PartialShape& valuesShape, + bool rightMode, + const std::vector& sortedData, + const std::vector& valuesData, + const std::vector& expectedData, + const std::string& testcaseName) { + SearchSortedParams ret; + const auto elementType = element::from(); + + ret.sortedShape = sortedShape; + ret.valuesShape = valuesShape; + ret.rightMode = rightMode; + ret.testcaseName = testcaseName; + ret.sorted = reference_tests::Tensor(elementType, sortedShape.to_shape(), sortedData); + ret.values = reference_tests::Tensor(elementType, valuesShape.to_shape(), valuesData); + ret.expectedOutput = reference_tests::Tensor(element::Type_t::i64, valuesShape.to_shape(), expectedData); + + return ret; +} + +class ReferenceSearchSortedTest : public testing::TestWithParam, public CommonReferenceTest { +public: + void SetUp() override { + const auto& params = GetParam(); + function = CreateFunction(params); + inputData = {params.sorted.data, params.values.data}; + refOutData = {params.expectedOutput.data}; + } + + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + auto param = obj.param; + std::ostringstream result; + result << "type=" << param.sorted.data.get_element_type(); + result << "_sortedShape=" << param.sortedShape; + result << "_valuesShape=" << param.valuesShape; + result << "_rightMode=" << param.rightMode; + result << "_=" << param.testcaseName; + + return result.str(); + } + +private: + static std::shared_ptr CreateFunction(const SearchSortedParams& params) { + const auto sorted = + std::make_shared(params.sorted.data.get_element_type(), params.sortedShape); + const auto values = + std::make_shared(params.values.data.get_element_type(), params.valuesShape); + + const auto op = std::make_shared(sorted, values, params.rightMode); + + return std::make_shared(NodeVector{op}, ParameterVector{sorted, values}); + } +}; + +TEST_P(ReferenceSearchSortedTest, CompareWithRefs) { + Exec(); +} + +template +std::vector generateParams() { + using T = typename element_type_traits::value_type; + std::vector params; + +#define TEST_DATA(sorted_shape, values_shape, right_mode, sorted_data, values_data, expected_output_data, description) \ + params.push_back(PrepareTestCaseParams(sorted_shape, \ + values_shape, \ + right_mode, \ + sorted_data, \ + values_data, \ + expected_output_data, \ + description)); + +#include "unit_test_utils/tests_data/search_sorted_data.h" +#undef TEST_DATA + + return params; +} + +std::vector generateCombinedParams() { + const std::vector> generatedParams{generateParams(), + generateParams()}; + std::vector combinedParams; + + for (const auto& params : generatedParams) { + combinedParams.insert(combinedParams.end(), params.begin(), params.end()); + } + return combinedParams; +} + +INSTANTIATE_TEST_SUITE_P(smoke_SearchSorted_With_Hardcoded_Refs, + ReferenceSearchSortedTest, + testing::ValuesIn(generateCombinedParams()), + ReferenceSearchSortedTest::getTestCaseName); +} // namespace diff --git a/src/tests/functional/plugin/conformance/test_runner/op_conformance_runner/src/op_impl_check/single_op_graph.cpp b/src/tests/functional/plugin/conformance/test_runner/op_conformance_runner/src/op_impl_check/single_op_graph.cpp index d80057b270c00c..f38427b7b192ed 100644 --- a/src/tests/functional/plugin/conformance/test_runner/op_conformance_runner/src/op_impl_check/single_op_graph.cpp +++ b/src/tests/functional/plugin/conformance/test_runner/op_conformance_runner/src/op_impl_check/single_op_graph.cpp @@ -2066,6 +2066,15 @@ std::shared_ptr generateRNNCellBase(const std::shared_ptr } } +std::shared_ptr generate(const std::shared_ptr& node) { + ov::ParameterVector params{std::make_shared(ov::element::f32, ov::Shape{16})}; + const auto values = + std::make_shared(ov::element::f32, ov::Shape{2, 3}, std::vector(6, 0)); + auto new_node = std::make_shared(params.at(0), values); + ov::ResultVector results{std::make_shared(new_node)}; + return std::make_shared(results, params, "SearchSortedGraph"); +} + std::shared_ptr generateSubGraphOp(const std::shared_ptr &node) { ov::ParameterVector params{std::make_shared(ov::element::f32, ov::Shape{{2, 2}}), std::make_shared(ov::element::f32, ov::Shape{{2, 2}}), diff --git a/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h b/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h new file mode 100644 index 00000000000000..ee355c2daee15e --- /dev/null +++ b/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h @@ -0,0 +1,86 @@ +#pragma once + +#define LIST(...) \ + { __VA_ARGS__ } + +// TEST_DATA(sorted_shape, +// values_shape, +// right_mode, +// sorted_data, +// values_data, +// expected_output_data, +// description) + +// NOTE: expected output were generated using pyTorch.searchsorted implementation. + +TEST_DATA(LIST(5), + LIST(2, 3), + false, + LIST(1, 3, 5, 7, 9), + LIST(3, 6, 9, 3, 6, 9), + LIST(1, 3, 4, 1, 3, 4), + "1d_tensor_1"); + +TEST_DATA(LIST(5), + LIST(4, 3), + false, + LIST(1, 3, 5, 7, 9), + LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 3, 5, 0, 3, 4, 0, 0, 0, 4, 5, 5), + "1d_tensor_2"); + +TEST_DATA(LIST(5), + LIST(4, 3), + true, + LIST(1, 3, 5, 7, 9), + LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5), + "1d_tensor_2_right_mode"); + +TEST_DATA(LIST(5), + LIST(2, 2, 3), + false, + LIST(1, 3, 5, 7, 9), + LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 3, 5, 0, 3, 4, 0, 0, 0, 4, 5, 5), + "1d_tensor_3"); + +TEST_DATA(LIST(5), + LIST(2, 2, 3), + true, + LIST(1, 3, 5, 7, 9), + LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5), + "1d_tensor_3_right_mode"); + +TEST_DATA(LIST(2, 5), + LIST(2, 3), + false, + LIST(1, 3, 5, 7, 9, 2, 4, 6, 8, 10), + LIST(3, 6, 9, 3, 6, 9), + LIST(1, 3, 4, 1, 2, 4), + "nd_tensor_1"); + +TEST_DATA(LIST(2, 5), + LIST(2, 3), + true, + LIST(1, 3, 5, 7, 9, 2, 4, 6, 8, 10), + LIST(3, 6, 9, 3, 6, 9), + LIST(2, 3, 5, 1, 3, 4), + "nd_tensor_1_right_mode"); + +TEST_DATA(LIST(2, 2, 5), + LIST(2, 2, 3), + false, + LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, -20, 5, 10, 23, 41, 100, 125, 130, 132, 139), + LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 3, 5, 1, 3, 5, 1, 1, 1, 0, 0, 0), + "nd_tensor_2"); + +TEST_DATA(LIST(2, 2, 5), + LIST(2, 2, 3), + true, + LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, -20, 5, 10, 23, 41, 100, 125, 130, 132, 139), + LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 3, 5, 1, 4, 5, 1, 1, 1, 0, 0, 0), + "nd_tensor_2"); \ No newline at end of file