Skip to content

Commit

Permalink
Fix TransposeSinking for Gather (openvinotoolkit#19202)
Browse files Browse the repository at this point in the history
* Fix TS gather

* enable pytest

* revert auto replaced comment
  • Loading branch information
itikhono authored Aug 15, 2023
1 parent 13f8ff4 commit 8509737
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,14 @@ TSGatherBackward::TSGatherBackward() {
if (success) {
size_t j = 0;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != new_shape[j] && shape[i] == 1) {
axes_val.push_back(i);
continue;
} else if (shape[i] != new_shape[j]) {
success = false;
break;
if (j >= new_shape.size() || shape[i] != new_shape[j]) {
if (shape[i] == 1) {
axes_val.push_back(i);
continue;
} else {
success = false;
break;
}
}
j++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ INSTANTIATE_TEST_SUITE_P(TSCommonGatherForward_3, TSTestFixture, test_forward_ga

struct GatherBackwardArguments {
OutputVector inputs_to_main;
Output<Node> new_Gather_first_input;
AxisVector new_transpose_order;
Output<Node> ref_Gather_axis_input;
AxisVector ref_transpose_order;
AxisVector ref_unsqueeze_axes;
};

auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
Expand All @@ -147,14 +148,14 @@ auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] = out_vec[1];
new_out_vec[2] = test_arguments.new_Gather_first_input;
new_out_vec[2] = test_arguments.ref_Gather_axis_input;
return new_out_vec;
};
auto new_transpose = [&test_arguments](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec = out_vec;
auto order = make_shared<Constant>(i32,
Shape{test_arguments.new_transpose_order.size()},
test_arguments.new_transpose_order);
Shape{test_arguments.ref_transpose_order.size()},
test_arguments.ref_transpose_order);
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
return new_out_vec;
};
Expand Down Expand Up @@ -197,13 +198,14 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] = make_shared<Squeeze>(out_vec[1]);
new_out_vec[2] = test_arguments.new_Gather_first_input;
new_out_vec[2] = test_arguments.ref_Gather_axis_input;
return new_out_vec;
};

auto unsqueeze_for = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
auto axis = constant<int>(i32, {1}, {0});
return {make_shared<Unsqueeze>(out_vec[0], axis)};
const auto& axes_val = test_arguments.ref_unsqueeze_axes;
auto axes = constant<size_t>(i32, {axes_val.size()}, axes_val);
return {make_shared<Unsqueeze>(out_vec[0], axes)};
};

test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, update_gather_inputs}, {{0}, {1, 2}}};
Expand All @@ -215,13 +217,29 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
};

vector<GatherBackwardArguments> tests_arguments_bw_optimization{
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
constant<int>(i32, {1}, {1}),
AxisVector{}}}};
{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
constant<int>(i32, {1}, {1}),
AxisVector{},
AxisVector{0}},
{{parameter(f32, {4}), constant<int>(i32, {1}, {0}), constant<int>(i32, {1}, {0})},
constant<int>(i32, {1}, {0}),
AxisVector{},
AxisVector{0}},
{{parameter(f32, {4}), constant<int>(i32, {1, 1, 1}, {0}), constant<int>(i32, {1}, {0})},
constant<int>(i32, {1}, {0}),
AxisVector{},
AxisVector{0, 1, 2}},
};

INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0,
TSTestFixture,
test_backward_gather_optimization(tests_arguments_bw_optimization[0]));
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_1,
TSTestFixture,
test_backward_gather_optimization(tests_arguments_bw_optimization[1]));
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_2,
TSTestFixture,
test_backward_gather_optimization(tests_arguments_bw_optimization[2]));
} // namespace gather
} // namespace testing
} // namespace transpose_sinking
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,7 @@ def create_max_pool_with_argmax_net(self, input_shape, ksize, strides, input_typ
True, False
])
@pytest.mark.parametrize("with_second_output", [
pytest.param(
True,
marks=pytest.mark.skip(reason="117415: TransposeSinking crash")
),
False
True, False
])
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
Expand Down

0 comments on commit 8509737

Please sign in to comment.