Skip to content

Commit

Permalink
[TRANSFORMATIONS] Make TotalSequenceLengthPattern pattern stricter (#…
Browse files Browse the repository at this point in the history
…25434)

[TRANSFORMATIONS] Make TotalSequenceLengthPattern pattern stricter

Make TotalSequenceLengthPattern pattern stricter to match one of the
cases when 'scale' is calculated from shape.

### Tickets:
 - CVS-138933

Signed-off-by: Andrii Staikov <[email protected]>

---------

Signed-off-by: Andrii Staikov <[email protected]>
  • Loading branch information
CuriousPanCake authored Jul 18, 2024
1 parent 19a5b95 commit 2aea2e0
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@

#pragma once

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class PositionIDsReplacer;
class TRANSFORMATIONS_API PositionIDsReplacer;

} // namespace pass
} // namespace ov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@

#pragma once

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class PrevSequenceLengthPattern;
class TRANSFORMATIONS_API PrevSequenceLengthPattern;

} // namespace pass
} // namespace ov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace ov {
namespace pass {

class StateManagementPattern;
class TRANSFORMATIONS_API StateManagementPattern;

} // namespace pass
} // namespace ov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,13 @@

#pragma once

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TotalSequenceLengthPattern;
class TRANSFORMATIONS_API TotalSequenceLengthPattern;

} // namespace pass
} // namespace ov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"

#include "openvino/cc/pass/itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
Expand All @@ -23,21 +23,74 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
auto kv_current = pattern::any_input();
auto kv_concat = pattern::wrap_type<v0::Concat>({kv_gather, kv_current});
auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_concat});
auto seq = pattern::wrap_type<v8::Gather>({kv_shape, pattern::any_input(), pattern::any_input()});
auto gather_idx_label = pattern::wrap_type<v0::Constant>();
auto seq = pattern::wrap_type<v8::Gather>({kv_shape, gather_idx_label, pattern::any_input()});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
// TODO: Check that seq has axis that really takes sequence len but not any other dimension --
// use symbolic infra or look at the constant input
auto gather = m.get_match_root();
auto target_type = gather->get_output_element_type(0);
const auto& pattern_map = m.get_pattern_value_map();

auto concat = std::dynamic_pointer_cast<v0::Concat>(pattern_map.at(kv_concat).get_node_shared_ptr());
auto gather = pattern_map.at(seq).get_node_shared_ptr();
auto gather_idx =
std::dynamic_pointer_cast<v0::Constant>(pattern_map.at(gather_idx_label).get_node_shared_ptr());

if (!concat || !gather || !gather_idx || !gather_idx) {
return false;
}

auto gather_idx_data = gather_idx->cast_vector<int64_t>();

if (gather_idx_data.size() != 1) {
return false;
}

int64_t gather_idx_to_compare = gather_idx_data[0];

if (gather_idx_data[0] < 0) {
if (gather->input(0).get_partial_shape().is_static()) {
const auto& gather_data_shape = gather->input(0).get_shape();
gather_idx_to_compare = ov::util::normalize(gather_idx_data[0], gather_data_shape[0]);
} else {
return false;
}
}

std::shared_ptr<Node> replacement = max_context_len;
if (replacement->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(replacement, target_type);

int64_t concat_axis_to_compare = concat->get_axis();
if (concat_axis_to_compare < 0) {
// If it's dynamic, leave it negative as we cannot take dynamic
// dimension here so the next comparison would fail
if (concat->get_output_partial_shape(0).is_static()) {
const auto& concat_output_shape = concat->output(0).get_partial_shape();
concat_axis_to_compare =
ov::util::normalize(concat_axis_to_compare, concat_output_shape.rank().get_length());
}
}
auto required_shape = gather->get_output_partial_shape(0);
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));

if (concat_axis_to_compare == gather_idx_to_compare) {
auto target_type = gather->get_output_element_type(0);

if (replacement->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(replacement, target_type);
}

auto required_shape = gather->get_output_partial_shape(0);

if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
}
} else {
// TODO: change in the future when we start supporting dynamic shapes here
replacement = ov::util::get_constant_from_source(gather->output(0));
OPENVINO_ASSERT(replacement,
"TotalSequenceLengthPattern transformation failed to determine the dimension value after ",
"the Gather operation. Most probably, the required dimension is dynamic: ",
concat);
}

replace_node(gather, replacement);
return true;
};
Expand Down
127 changes: 127 additions & 0 deletions src/common/transformations/tests/sdpa_to_paged_attention_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,17 @@

#include "common_test_utils/test_common.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/power.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"

using namespace ov;

Expand All @@ -24,4 +33,122 @@ TEST(SDPATOPATest, SDPANotPresent) {
ov::pass::Manager manager;
manager.register_pass<pass::SDPAToPagedAttention>();
EXPECT_THROW(manager.run_passes(model), ov::Exception);
}

// Case: the Gather index applied to ShapeOf(Concat) equals the Concat axis.
// The test expects TotalSequenceLengthPattern to rewire the Result so that it is fed by
// a Convert whose input is the very max_context_len Parameter handed to the pass
// (max_context_len is i32 here, while the ShapeOf/Gather chain produces i64).
TEST(SDPATOPATest, GatherIdx_ConcatAxis_EQ) {
    // Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
    const int CONCAT_AXIS = 1;
    const int GATHER_IDX = 1;

    // ReadValue -> Gather(beam_idx) branch feeding the first Concat input.
    const auto input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
    auto variable = std::make_shared<ov::op::util::Variable>(
        ov::op::util::VariableInfo{PartialShape::dynamic(), element::i32, "variable"});
    const auto read_value = std::make_shared<op::v6::ReadValue>(input, variable);

    const auto beam_idx = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
    const auto gather_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    const auto gather = std::make_shared<op::v8::Gather>(read_value, beam_idx, gather_axis);

    // Second Concat input with a fully static shape.
    const auto concat_input = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1, 2, 3});
    const auto concat = std::make_shared<op::v0::Concat>(NodeVector{gather, concat_input}, CONCAT_AXIS);

    const auto shape_of = std::make_shared<op::v3::ShapeOf>(concat, element::i64);

    // Gather picking a single dimension out of the Concat output shape.
    const auto gather_indices = op::v0::Constant::create(element::i64, Shape{}, {GATHER_IDX});
    const auto gather_axis2 = op::v0::Constant::create(element::i64, Shape{}, {0});
    const auto gather1 = std::make_shared<op::v8::Gather>(shape_of, gather_indices, gather_axis2);

    const auto result = std::make_shared<op::v0::Result>(gather1);
    auto model = std::make_shared<Model>(ResultVector{result}, ParameterVector{input, beam_idx, concat_input});

    const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});

    ov::pass::Manager manager;
    manager.set_per_pass_validation(false);
    manager.register_pass<ov::pass::TotalSequenceLengthPattern>(max_context_len);
    bool transformation_run = manager.run_passes(model);

    EXPECT_TRUE(transformation_run);
    const auto new_convert =
        std::dynamic_pointer_cast<op::v0::Convert>(result->input(0).get_source_output().get_node_shared_ptr());
    // Fatal assertion on purpose: new_convert is dereferenced right below, so a failed cast
    // must stop the test here instead of crashing the whole test binary on a null pointer.
    ASSERT_TRUE(new_convert);
    const auto new_max_context_len =
        std::dynamic_pointer_cast<op::v0::Parameter>(new_convert->input(0).get_source_output().get_node_shared_ptr());
    EXPECT_TRUE(new_max_context_len);
    EXPECT_TRUE(new_max_context_len == max_context_len);
}

// Case: the Gather index differs from the Concat axis and the second Concat input has a
// static shape. The test expects the pass to report a change and the Result to end up
// fed by a Constant.
TEST(SDPATOPATest, GatherIdx_ConcatAxis_NOTEQ_STATIC) {
    // Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
    constexpr int kConcatAxis = 1;
    constexpr int kGatherIdx = 0;

    // State branch: ReadValue reordered by beam indices along axis 0.
    const auto state_init = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
    auto state_var = std::make_shared<ov::op::util::Variable>(
        ov::op::util::VariableInfo{PartialShape::dynamic(), element::i32, "variable"});
    const auto past_kv = std::make_shared<op::v6::ReadValue>(state_init, state_var);

    const auto beam_indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
    const auto beam_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    const auto kv_gather = std::make_shared<op::v8::Gather>(past_kv, beam_indices, beam_axis);

    // Current chunk with a fully static shape, concatenated to the state branch.
    const auto kv_current = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1, 2, 3});
    const auto kv_concat = std::make_shared<op::v0::Concat>(NodeVector{kv_gather, kv_current}, kConcatAxis);

    const auto kv_shape = std::make_shared<op::v3::ShapeOf>(kv_concat, element::i64);

    // Single dimension taken from the concatenated shape.
    const auto dim_index = op::v0::Constant::create(element::i64, Shape{}, {kGatherIdx});
    const auto dim_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    const auto dim_gather = std::make_shared<op::v8::Gather>(kv_shape, dim_index, dim_axis);

    const auto model_result = std::make_shared<op::v0::Result>(dim_gather);
    auto model =
        std::make_shared<Model>(ResultVector{model_result}, ParameterVector{state_init, beam_indices, kv_current});

    const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});

    ov::pass::Manager manager;
    manager.set_per_pass_validation(false);
    manager.register_pass<ov::pass::TotalSequenceLengthPattern>(max_context_len);
    const bool model_changed = manager.run_passes(model);
    EXPECT_TRUE(model_changed);

    // The node now feeding the Result must be a Constant.
    const auto result_producer = model_result->input(0).get_source_output().get_node_shared_ptr();
    const auto folded_dim = std::dynamic_pointer_cast<op::v0::Constant>(result_producer);
    EXPECT_TRUE(folded_dim);
}

// Case: the Gather index differs from the Concat axis and the second Concat input has
// interval (non-static) dimensions. The test expects running the pass to throw
// ov::Exception.
TEST(SDPATOPATest, GatherIdx_ConcatAxis_NOTEQ_DYNAMIC) {
    // Almost replicating the pattern from the TotalSequenceLengthPattern transformation.
    constexpr int kConcatAxis = 1;
    constexpr int kGatherIdx = 0;

    // State branch: ReadValue reordered by beam indices along axis 0.
    const auto state_init = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
    auto state_var = std::make_shared<ov::op::util::Variable>(
        ov::op::util::VariableInfo{PartialShape::dynamic(), element::i32, "variable"});
    const auto past_kv = std::make_shared<op::v6::ReadValue>(state_init, state_var);

    const auto beam_indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
    const auto beam_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    const auto kv_gather = std::make_shared<op::v8::Gather>(past_kv, beam_indices, beam_axis);

    // Current chunk whose every dimension is a bounded interval rather than a fixed value.
    const PartialShape interval_shape{Dimension(1, 2), Dimension(1, 3), Dimension(1, 4)};
    const auto kv_current = std::make_shared<op::v0::Parameter>(element::i32, interval_shape);
    const auto kv_concat = std::make_shared<op::v0::Concat>(NodeVector{kv_gather, kv_current}, kConcatAxis);

    const auto kv_shape = std::make_shared<op::v3::ShapeOf>(kv_concat, element::i64);

    // Single dimension taken from the concatenated shape.
    const auto dim_index = op::v0::Constant::create(element::i64, Shape{}, {kGatherIdx});
    const auto dim_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    const auto dim_gather = std::make_shared<op::v8::Gather>(kv_shape, dim_index, dim_axis);

    const auto model_result = std::make_shared<op::v0::Result>(dim_gather);
    auto model =
        std::make_shared<Model>(ResultVector{model_result}, ParameterVector{state_init, beam_indices, kv_current});

    const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});

    ov::pass::Manager manager;
    manager.set_per_pass_validation(false);
    manager.register_pass<ov::pass::TotalSequenceLengthPattern>(max_context_len);

    // Running the registered pass on this model is expected to raise.
    EXPECT_THROW(manager.run_passes(model), ov::Exception);
}
Loading

0 comments on commit 2aea2e0

Please sign in to comment.