diff --git a/src/core/shape_inference/include/search_sorted_shape_inference.hpp b/src/core/shape_inference/include/search_sorted_shape_inference.hpp index da417f54121ee4..7ea0598cffbc87 100644 --- a/src/core/shape_inference/include/search_sorted_shape_inference.hpp +++ b/src/core/shape_inference/include/search_sorted_shape_inference.hpp @@ -16,18 +16,24 @@ std::vector shape_infer(const SearchSorted* op, const std::vectorvalidate(); const auto& sorted_shape = input_shapes[0]; const auto& values_shape = input_shapes[1]; + auto output_shape = values_shape; - TShape::merge_into(output_shape, sorted_shape); - - if (output_shape.rank().is_static()) { - auto last_it = output_shape.end() - 1; - if (values_shape.rank().is_static()) { - *last_it = *(input_shapes[1].end() - 1); - } else { - *last_it = Dimension::dynamic(); - } + + // 1. If we know that the sorted sequence is 1D, than output shape can be anything. + if (sorted_shape.rank().is_static() && sorted_shape.rank().get_length() == 1) { + return {std::move(output_shape)}; + } + + // 2. ND tensor case or rank not known. + auto sorted_shape_last_dynamic = sorted_shape; + if (sorted_shape.rank().is_static()) { + sorted_shape_last_dynamic[sorted_shape.rank().get_length() - 1] = Dimension::dynamic(); } + const bool sorted_values_merge_success = TShape::merge_into(output_shape, sorted_shape_last_dynamic); + + NODE_VALIDATION_CHECK(op, sorted_values_merge_success, "Shapes of sorted sequence and values are not compatible."); + return {std::move(output_shape)}; } } // namespace v15 diff --git a/src/core/src/op/search_sorted.cpp b/src/core/src/op/search_sorted.cpp index df179d925d054a..d3f26a674eef91 100644 --- a/src/core/src/op/search_sorted.cpp +++ b/src/core/src/op/search_sorted.cpp @@ -21,7 +21,7 @@ SearchSorted::SearchSorted(const Output& sorted_sequence, const Output(element::f32, Shape{1, 3, 6}); auto values = make_shared(element::i32, Shape{1, 3, 6});