Skip to content

Commit

Permalink
[TEMPLATE]: SearchSorted: Fixed a bug when sorted has exactly one ele…
Browse files Browse the repository at this point in the history
…ment. Added more tests.
  • Loading branch information
pkowalc1 committed Nov 14, 2024
1 parent 4dc8abf commit 117f216
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/core/reference/include/openvino/reference/search_sorted.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void search_sorted(const T* sorted,
}

const size_t size = shape_size(values_shape);
const size_t sorted_inner_dim = sorted_shape.back();

auto func = [&](size_t i) {
auto it = values_transform.begin();
Expand All @@ -44,15 +45,12 @@ void search_sorted(const T* sorted,
Coordinate sorted_coord_begin = values_coord;
sorted_coord_begin.back() = 0;

Coordinate sorted_coord_last = values_coord;
sorted_coord_last.back() = sorted_shape.back();

const auto sorted_index_begin = coordinate_index(sorted_coord_begin, sorted_shape);
const auto sorted_index_last = coordinate_index(sorted_coord_last, sorted_shape);

const T* idx_ptr = compare_func(sorted + sorted_index_begin, sorted + sorted_index_last, value);
const T* sorted_begin_ptr = sorted + sorted_index_begin;
const T* sorted_end_ptr = sorted_begin_ptr + sorted_inner_dim;
const T* idx_ptr = compare_func(sorted_begin_ptr, sorted_end_ptr, value);

const ptrdiff_t sorted_index = (idx_ptr - sorted) - sorted_index_begin;
const ptrdiff_t sorted_index = idx_ptr - sorted_begin_ptr;

out[values_index] = static_cast<TOut>(sorted_index);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@ TEST_DATA(LIST(5),
LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5),
"1d_tensor_3_right_mode");

TEST_DATA(LIST(1),
LIST(2, 2, 3),
false,
LIST(2),
LIST(0, 6, 20, 2, 6, 9, 1, 0, 0, 9, 10, 20),
LIST(0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1),
"1d_tensor_4");

TEST_DATA(LIST(1),
LIST(2, 2, 3),
true,
LIST(2),
LIST(0, 6, 20, 2, 6, 9, 1, 0, 0, 9, 10, 20),
LIST(0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1),
"1d_tensor_4_right_mode");

TEST_DATA(LIST(2, 5),
LIST(2, 3),
false,
Expand Down

0 comments on commit 117f216

Please sign in to comment.