Skip to content

Commit

Permalink
[TF FE][MOC] Fix leftovers for Keras LSTM fusion transformation (#25268)
Browse files Browse the repository at this point in the history
**Details:** Fix leftovers for Keras LSTM fusion transformation
#25170

**Tickets:** TBD

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Jun 28, 2024
1 parent 59f1d69 commit 8cdbedc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -351,20 +351,10 @@ bool check_condition_true_pattern(const std::shared_ptr<op::v0::Result>& cond_re
const auto& condition_map = condition_matcher.get_pattern_value_map();
const auto& cond_const =
ov::as_type_ptr<op::v0::Constant>(condition_map.at(cond_const_label).get_node_shared_ptr());
if (!cond_const) {
bool cond_value = false;
if (!ov::op::util::get_constant_value(cond_const, cond_value) || !cond_value) {
return false;
}
if (ov::shape_size(cond_const->get_shape()) != 1)
return false;
const auto& type = cond_const->get_output_element_type(0);
if (type != ov::element::boolean) {
return false;
}
bool cond_value = cond_const->cast_vector<bool>()[0];
if (!cond_value) {
return false;
}

// number of iteration is retrieve from the first input port
num_iters_output = loop->input_value(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ uint64_t get_new_param_idx(const std::vector<uint64_t>& remove_parameter_idxs, u
for (auto remove_idx : remove_parameter_idxs) {
FRONT_END_GENERAL_CHECK(old_idx != remove_idx,
"[TensorFlow Frontend] internal error: incorrect old_idx for "
"TensorListSliceInputAndConcatOutputReplacer transformation");
"TensorListInLoopOptimization transformation");
if (remove_idx < old_idx) {
++num_removed;
}
Expand All @@ -181,7 +181,7 @@ uint64_t get_new_param_idx(const std::vector<uint64_t>& remove_parameter_idxs, u
// compute shifted index
FRONT_END_GENERAL_CHECK(num_removed <= old_idx,
"[TensorFlow Frontend] internal error: incorrect new parameter index computation "
"TensorListSliceInputAndConcatOutputReplacer transformation");
"TensorListInLoopOptimization transformation");
return old_idx - num_removed;
}
} // namespace
Expand Down Expand Up @@ -478,7 +478,7 @@ ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOp
std::dynamic_pointer_cast<TensorListSetItem>(body_result->get_input_node_shared_ptr(0));
FRONT_END_GENERAL_CHECK(tensor_list_set_item,
"[TensorFlow Frontend] internal error: tensor_list_set_item is nullptr in "
"TensorListSliceInputAndConcatOutputReplacer");
"TensorListInLoopOptimization");
// unsqueeze newly generated data at this iteration
// that will be concatenated
auto new_data = tensor_list_set_item->input_value(2);
Expand All @@ -501,13 +501,13 @@ ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOp
const auto& body_param = body_params[param_idx];
FRONT_END_GENERAL_CHECK(body_param->get_output_target_inputs(0).size() == 1,
"[TensorFlow Frontend] internal error: tensor list must have only consumer "
"TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer");
"TensorListGetItem operation in TensorListInLoopOptimization");
auto target_input = *(body_param->get_output_target_inputs(0).begin());
auto tensor_list_get_item =
std::dynamic_pointer_cast<TensorListGetItem>(target_input.get_node()->shared_from_this());
FRONT_END_GENERAL_CHECK(tensor_list_get_item,
"[TensorFlow Frontend] internal error: tensor list must have only consumer "
"TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer");
"TensorListGetItem operation in TensorListInLoopOptimization");

auto new_shape = body_param->get_output_partial_shape(0);
if (new_shape.rank().is_static() && new_shape.rank().get_length() > 0) {
Expand Down

0 comments on commit 8cdbedc

Please sign in to comment.