From 2507d89d9cfcf38bf43238aaf83ff74f066fe452 Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Mon, 30 Sep 2024 18:33:32 +0400 Subject: [PATCH] Handle dynamic rank in TSUnsqueezeBackward transformation (#26786) ### Details: Handle dynamic rank in TSUnsqueezeBackward transformation ### Tickets: - *CVS-152373* --- .../transpose_sinking/ts_unsqueeze.cpp | 16 ++++++-- .../transpose_sinking/ts_common_test.cpp | 41 +++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp index c46dffec0e5c97..cdeb9226ed236c 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp @@ -190,9 +190,19 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() { return false; } } else { - auto rank = main_node->get_output_partial_shape(0).rank(); - non_negative_axes = - util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node); + const auto& axes = unsqueeze_axes->cast_vector(); + if (std::all_of(axes.begin(), axes.end(), [](int64_t axis) { + return axis >= 0; + })) { + non_negative_axes = std::vector(axes.begin(), axes.end()); + } else { + auto rank = main_node->get_output_partial_shape(0).rank(); + if (rank.is_dynamic()) { + return false; + } + non_negative_axes = + util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node); + } } auto transpose_order_values = transpose_order->cast_vector(); diff --git a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp index 67433eead4f639..d71c9006edd38a 100644 --- a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp @@ -1636,6 +1636,47 @@ auto test_backward_reshape_unsqueeze = []() { INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward, TSTestFixture, test_backward_reshape_unsqueeze()); + +auto test_backward_unsqueeze_dyn_rank = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward); + test_case.num_main_ops = {1}; + test_case.inputs_to_main = { + parameter(element::f32, PartialShape::dynamic()), + constant(element::i32, {2}, {-1}), + }; + + auto dyn_transpose = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { + OutputVector result = out_vec; + for (const auto& idx : idxs) { + const auto& out = out_vec[idx]; + + // fill the order const with the stub values {-1, -2} + auto order = make_shared(element::i32, Shape{2}, vector{-1, -2}); + auto transpose = make_shared(out, order); + result[idx] = transpose; + } + return result; + }; + + // Test model description: + test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)}; + test_case.model.preprocess_outputs_of_main = {{dyn_transpose}, {{0}}}; + test_case.model.model_template = create_model; + + // Ref model description, the same as the original model, the transformation is not applied + // it's expected. + test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)}; + test_case.model_ref.preprocess_outputs_of_main = {{dyn_transpose}, {{0}}}; + test_case.model_ref.model_template = create_model; + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackwardDynRank, + TSTestFixture, + test_backward_unsqueeze_dyn_rank()); } // namespace common } // namespace testing } // namespace transpose_sinking