Skip to content

Commit

Permalink
do negative indices normalization if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
antonvor committed Oct 4, 2023
1 parent becb2b6 commit 407d4b3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,27 @@ pass::ConvertGather8ToGather7::ConvertGather8ToGather7() {
auto axis_dim = data.get_partial_shape()[axis_value].get_length();

auto indices = indices_constant->cast_vector<int64_t>();
// check all the indices are not negative and not out of bound
// Check all the indices are not out of bound and check whether normalization is possible for negative values
bool do_indices_normalization = false;
for (size_t i = 0; i < indices.size(); i++) {
if (indices[i] < 0 || indices[i] >= axis_dim) {
if (indices[i] < -axis_dim || indices[i] >= axis_dim) {
return false;
}
if (indices[i] < 0) {
do_indices_normalization = true;
indices[i] = indices[i] + axis_dim;
}
}

std::shared_ptr<ov::Node> new_indices_constant;
if (do_indices_normalization) {
new_indices_constant = std::make_shared<ov::op::v0::Constant>(indices_constant->get_element_type(), indices_constant->get_shape(), indices);
} else {
new_indices_constant = indices_constant;
}

auto gather_v7_node = make_shared<ov::op::v7::Gather>(gather_v8_node->input_value(0),
gather_v8_node->input_value(1),
new_indices_constant,
gather_v8_node->input_value(2),
gather_v8_node->get_batch_dims());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ TEST_F(TransformationTestsF, ConvertGather8toGather7_negative_indices) {
model = std::make_shared<ov::Model>(NodeVector{gather_v8}, ParameterVector{data});

manager.register_pass<ov::pass::ConvertGather8ToGather7>();
comparator.enable(FunctionsComparator::CONST_VALUES);
}

{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 3});
auto indices = opset8::Constant::create(element::i32, Shape{2, 2}, {2, 1, 0, 2});
auto axis = opset1::Constant::create(element::i32, Shape{1}, {1});
int64_t batch_dims = 1;

auto gather_v7 = std::make_shared<opset7::Gather>(data, indices, axis, batch_dims);

model_ref = std::make_shared<ov::Model>(NodeVector{gather_v7}, ParameterVector{data});
}
}

Expand Down

0 comments on commit 407d4b3

Please sign in to comment.