From 8cdbedcf9e3c26c5e2cbfe7814f0eaf2260b761e Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Fri, 28 Jun 2024 16:52:59 +0400 Subject: [PATCH] [TF FE][MOC] Fix leftovers for Keras LSTM fusion transformation (#25268) **Details:** Fix leftovers for Keras LSTM fusion transformation https://github.com/openvinotoolkit/openvino/pull/25170 **Tickets:** TBD Signed-off-by: Kazantsev, Roman --- .../op_conversions/convert_ti_to_sequences.cpp | 14 ++------------ .../helper_transforms/tensor_list_ops_resolver.cpp | 10 +++++----- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp b/src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp index f901399287738b..1888cfd22c2d0c 100644 --- a/src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp +++ b/src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp @@ -351,20 +351,10 @@ bool check_condition_true_pattern(const std::shared_ptr& cond_re const auto& condition_map = condition_matcher.get_pattern_value_map(); const auto& cond_const = ov::as_type_ptr(condition_map.at(cond_const_label).get_node_shared_ptr()); - if (!cond_const) { + bool cond_value = false; + if (!ov::op::util::get_constant_value(cond_const, cond_value) || !cond_value) { return false; } - if (ov::shape_size(cond_const->get_shape()) != 1) - return false; - const auto& type = cond_const->get_output_element_type(0); - if (type != ov::element::boolean) { - return false; - } - bool cond_value = cond_const->cast_vector()[0]; - if (!cond_value) { - return false; - } - // number of iteration is retrieve from the first input port num_iters_output = loop->input_value(0); diff --git a/src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp b/src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp index 09b7257eb68be7..a185b7abee4372 100644 --- a/src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp +++ b/src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp @@ -172,7 +172,7 @@ uint64_t get_new_param_idx(const std::vector& remove_parameter_idxs, u for (auto remove_idx : remove_parameter_idxs) { FRONT_END_GENERAL_CHECK(old_idx != remove_idx, "[TensorFlow Frontend] internal error: incorrect old_idx for " - "TensorListSliceInputAndConcatOutputReplacer transformation"); + "TensorListInLoopOptimization transformation"); if (remove_idx < old_idx) { ++num_removed; } @@ -181,7 +181,7 @@ uint64_t get_new_param_idx(const std::vector& remove_parameter_idxs, u // compute shifted index FRONT_END_GENERAL_CHECK(num_removed <= old_idx, "[TensorFlow Frontend] internal error: incorrect new parameter index computation " - "TensorListSliceInputAndConcatOutputReplacer transformation"); + "TensorListInLoopOptimization transformation"); return old_idx - num_removed; } } // namespace @@ -478,7 +478,7 @@ ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOp std::dynamic_pointer_cast(body_result->get_input_node_shared_ptr(0)); FRONT_END_GENERAL_CHECK(tensor_list_set_item, "[TensorFlow Frontend] internal error: tensor_list_set_item is nullptr in " - "TensorListSliceInputAndConcatOutputReplacer"); + "TensorListInLoopOptimization"); // unsqueeze newly generated data at this iteration // that will be concatenated auto new_data = tensor_list_set_item->input_value(2); @@ -501,13 +501,13 @@ ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOp const auto& body_param = body_params[param_idx]; FRONT_END_GENERAL_CHECK(body_param->get_output_target_inputs(0).size() == 1, "[TensorFlow Frontend] internal error: tensor list must have only consumer " - "TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer"); + "TensorListGetItem operation in TensorListInLoopOptimization"); auto target_input = *(body_param->get_output_target_inputs(0).begin()); auto tensor_list_get_item = std::dynamic_pointer_cast(target_input.get_node()->shared_from_this()); FRONT_END_GENERAL_CHECK(tensor_list_get_item, "[TensorFlow Frontend] internal error: tensor list must have only consumer " - "TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer"); + "TensorListGetItem operation in TensorListInLoopOptimization"); auto new_shape = body_param->get_output_partial_shape(0); if (new_shape.rank().is_static() && new_shape.rank().get_length() > 0) {