diff --git a/src/core/reference/include/openvino/reference/search_sorted.hpp b/src/core/reference/include/openvino/reference/search_sorted.hpp index 7ea8ec1078a2a1..629509b28ef78d 100644 --- a/src/core/reference/include/openvino/reference/search_sorted.hpp +++ b/src/core/reference/include/openvino/reference/search_sorted.hpp @@ -32,6 +32,7 @@ void search_sorted(const T* sorted, } const size_t size = shape_size(values_shape); + const size_t sorted_inner_dim = sorted_shape.back(); auto func = [&](size_t i) { auto it = values_transform.begin(); @@ -44,15 +45,12 @@ void search_sorted(const T* sorted, Coordinate sorted_coord_begin = values_coord; sorted_coord_begin.back() = 0; - Coordinate sorted_coord_last = values_coord; - sorted_coord_last.back() = sorted_shape.back(); - const auto sorted_index_begin = coordinate_index(sorted_coord_begin, sorted_shape); - const auto sorted_index_last = coordinate_index(sorted_coord_last, sorted_shape); - - const T* idx_ptr = compare_func(sorted + sorted_index_begin, sorted + sorted_index_last, value); + const T* sorted_begin_ptr = sorted + sorted_index_begin; + const T* sorted_end_ptr = sorted_begin_ptr + sorted_inner_dim; + const T* idx_ptr = compare_func(sorted_begin_ptr, sorted_end_ptr, value); - const ptrdiff_t sorted_index = (idx_ptr - sorted) - sorted_index_begin; + const ptrdiff_t sorted_index = idx_ptr - sorted_begin_ptr; out[values_index] = static_cast(sorted_index); }; diff --git a/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h b/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h index affb0ba7defff2..43e680aa080686 100644 --- a/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h +++ b/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h @@ -69,6 +69,22 @@ TEST_DATA(LIST(5), LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5), "1d_tensor_3_right_mode"); +TEST_DATA(LIST(1), + LIST(2, 2, 3), + false, + LIST(2), + LIST(0, 6, 20, 2, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1), + "1d_tensor_4"); + +TEST_DATA(LIST(1), + LIST(2, 2, 3), + true, + LIST(2), + LIST(0, 6, 20, 2, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1), + "1d_tensor_4_right_mode"); + TEST_DATA(LIST(2, 5), LIST(2, 3), false,