Skip to content

Commit

Permalink
[SearchSortedDef]: Bugfixes in shape_infer. (openvinotoolkit#27013)
Browse files Browse the repository at this point in the history
### Details:
 - Fixed bugs with dynamic type and 1d input
 - Added more tests
  • Loading branch information
pkowalc1 authored Oct 14, 2024
1 parent 7250c1e commit b7d1f1d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
24 changes: 15 additions & 9 deletions src/core/shape_inference/include/search_sorted_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,24 @@ std::vector<TRShape> shape_infer(const SearchSorted* op, const std::vector<TShap
op->validate();
const auto& sorted_shape = input_shapes[0];
const auto& values_shape = input_shapes[1];

auto output_shape = values_shape;
TShape::merge_into(output_shape, sorted_shape);

if (output_shape.rank().is_static()) {
auto last_it = output_shape.end() - 1;
if (values_shape.rank().is_static()) {
*last_it = *(input_shapes[1].end() - 1);
} else {
*last_it = Dimension::dynamic();
}

// 1. If we know that the sorted sequence is 1D, than output shape can be anything.
if (sorted_shape.rank().is_static() && sorted_shape.rank().get_length() == 1) {
return {std::move(output_shape)};
}

// 2. ND tensor case or rank not known.
auto sorted_shape_last_dynamic = sorted_shape;
if (sorted_shape.rank().is_static()) {
sorted_shape_last_dynamic[sorted_shape.rank().get_length() - 1] = Dimension::dynamic();
}

const bool sorted_values_merge_success = TShape::merge_into(output_shape, sorted_shape_last_dynamic);

NODE_VALIDATION_CHECK(op, sorted_values_merge_success, "Shapes of sorted sequence and values are not compatible.");

return {std::move(output_shape)};
}
} // namespace v15
Expand Down
2 changes: 1 addition & 1 deletion src/core/src/op/search_sorted.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ SearchSorted::SearchSorted(const Output<Node>& sorted_sequence, const Output<Nod
bool SearchSorted::validate() const {
NODE_VALIDATION_CHECK(this, get_input_size() == 2);
NODE_VALIDATION_CHECK(this,
get_input_element_type(0) == get_input_element_type(1),
get_input_element_type(0).compatible(get_input_element_type(1)),
"Sorted sequence and values must have the same element type.");

const auto& sorted_shape = get_input_partial_shape(0);
Expand Down
22 changes: 21 additions & 1 deletion src/core/tests/type_prop/search_sorted.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ TEST(type_prop, search_sorted_shape_infer_sorted_dynamic) {
PerformShapeTest(PartialShape::dynamic(), {1, 3, 6}, {1, 3, 6});
}

TEST(type_prop, search_sorted_shape_infer_values_dynamic) {
TEST(type_prop, search_sorted_shape_infer_values_dynamic_1) {
PerformShapeTest({1, 3, 7, 5}, PartialShape::dynamic(), {1, 3, 7, -1});
}

TEST(type_prop, search_sorted_shape_infer_values_dynamic_2) {
PerformShapeTest({1666}, PartialShape::dynamic(), PartialShape::dynamic());
}

TEST(type_prop, search_sorted_shape_infer_different_last_dim) {
PerformShapeTest({1, 3, 7, 100}, {1, 3, 7, 10}, {1, 3, 7, 10});
}
Expand Down Expand Up @@ -73,6 +77,22 @@ TEST(type_prop, search_sorted_shape_infer_both_dynamic_5) {
PerformShapeTest({-1}, {-1, -1, 3}, {-1, -1, 3});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_6) {
PerformShapeTest({-1}, PartialShape::dynamic(), PartialShape::dynamic());
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_7) {
PerformShapeTest({20, 30, 40, -1}, PartialShape::dynamic(), {20, 30, 40, -1});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_8) {
PerformShapeTest({10, 20, 30, 40, -1}, {-1, -1, 30, -1, 100}, {10, 20, 30, 40, 100});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_9) {
PerformShapeTest({-1, -1}, PartialShape::dynamic(), {-1, -1});
}

TEST(type_prop, search_sorted_shape_infer_different_types) {
auto sorted = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 6});
auto values = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 3, 6});
Expand Down

0 comments on commit b7d1f1d

Please sign in to comment.