[TRANSFORMATIONS] Add support for 'inputs_embeds' input in SDPAToPA #27158
base: master
File 1: SDPAToPagedAttention transformation (C++)

@@ -53,8 +53,27 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
     auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0});  // sliding_window
 
+    auto has_parameter = [=](const std::shared_ptr<ov::Model>& model, const std::string& name) -> std::shared_ptr<v0::Parameter> {
+        for (auto& param : model->inputs()) {
+            const auto& names = param.get_names();
+            if (names.find(name) != names.end()) {
+                if (auto casted_param = std::dynamic_pointer_cast<v0::Parameter>(param.get_node_shared_ptr())) {
+                    return casted_param;
+                } else {
+                    OPENVINO_THROW("The model is in the inconsistent state. Found input '",
+                                   name,
+                                   "', but couldn't cast it to v0::Parameter.");
+                }
+            }
+        }
+
+        return nullptr;
+    };
+
+    auto input_ids_name = has_parameter(model, "input_ids") ? "input_ids" : "inputs_embeds";
+
     std::shared_ptr<v0::Parameter> input_ids_node =
-        std::dynamic_pointer_cast<v0::Parameter>(model->input("input_ids").get_node_shared_ptr());
+        std::dynamic_pointer_cast<v0::Parameter>(model->input(input_ids_name).get_node_shared_ptr());
     input_ids_node->set_partial_shape(PartialShape{-1});
     auto unsqueezed_input_ids =
         std::make_shared<v0::Unsqueeze>(input_ids_node, v0::Constant::create(element::i32, Shape{}, {1}));
 

@@ -66,17 +85,6 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
     auto prev_max_seq_len =
         std::make_shared<v1::Subtract>(max_context_len, std::make_shared<v0::Convert>(cur_seq_len, element::i32));
 
-    auto has_parameter = [=](const std::shared_ptr<ov::Model>& model, const std::string& name) -> bool {
-        for (auto& t : model->inputs()) {
-            const auto& names = t.get_names();
-            if (names.find(name) != names.end()) {
-                return true;
-            }
-        }
-
-        return false;
-    };
-
     ParameterVector kv_parameters;
     ParameterVector parameters_to_remove;
     ResultVector results_to_remove;  // # used, but cannot really track all Results in stateless model

@@ -136,30 +144,22 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
     }
 
     for (auto& param_name : {"beam_idx", "attention_mask"}) {
-        if (has_parameter(model, param_name)) {
-            if (const auto& param =
-                    std::dynamic_pointer_cast<v0::Parameter>(model->input(param_name).get_node_shared_ptr())) {
-                model->remove_parameter(param);
-
-                if (param->output(0).get_target_inputs().size() == 0) {
-                    std::stringstream consumers;
-                    consumers << std::endl;
-                    for (auto& input : param->output(0).get_target_inputs()) {
-                        consumers << *input.get_node() << std::endl;
-                    }
-                    OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0,
-                                    "PagedAttention transformation failed: couldn't remove ",
-                                    param->output(0).get_target_inputs().size(),
-                                    " inputs of ",
-                                    param_name,
-                                    " input: ",
-                                    consumers.str());
-                }
-            } else {
-                OPENVINO_THROW("The model is in the inconsistent state. Found input '",
-                               param_name,
-                               "', but couldn't cast it to v0::Parameter.");
-                return false;
-            }
-        }
+        if (auto param = has_parameter(model, param_name)) {
+            model->remove_parameter(param);
+
+            if (param->output(0).get_target_inputs().size() == 0) {
+                std::stringstream consumers;
+                consumers << std::endl;
+                for (auto& input : param->output(0).get_target_inputs()) {
+                    consumers << *input.get_node() << std::endl;
+                }
+                OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0,
+                                "PagedAttention transformation failed: couldn't remove ",
+                                param->output(0).get_target_inputs().size(),
+                                " inputs of ",
+                                param_name,
+                                " input: ",
+                                consumers.str());
+            }
+        }
     }
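To make the effect of the C++ change concrete, here is a minimal Python sketch (not part of the PR) that drives the same pass through the offline-transformations wrapper used by the tests below. The file name "lm_model.xml" is a hypothetical placeholder for an exported language model that exposes `inputs_embeds` instead of `input_ids`.

# Minimal sketch: apply the PagedAttention offline transformation to a model whose
# only "token" input is inputs_embeds (e.g. the language model of a VLM).
import openvino as ov
from openvino._offline_transformations import paged_attention_transformation

core = ov.Core()
model = core.read_model("lm_model.xml")  # hypothetical path to an exported model

# Check which token input the model exposes.
input_names = {name for port in model.inputs for name in port.get_names()}
print("has input_ids:", "input_ids" in input_names,
      "| has inputs_embeds:", "inputs_embeds" in input_names)

# With this PR the pass no longer requires input_ids: when it is absent,
# it falls back to the inputs_embeds parameter.
paged_attention_transformation(model, False, False)  # use_block_indices_inputs, use_score_outputs

pa_nodes = [op for op in model.get_ordered_ops()
            if op.get_type_name() == "PagedAttentionExtension"]
print("PagedAttentionExtension nodes after the pass:", len(pa_nodes))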
File 2: new list of tiny VLM models for the precommit tests

@@ -0,0 +1,4 @@
+katuni4ka/tiny-random-llava-next,https://huggingface.co/katuni4ka/tiny-random-llava-next
+katuni4ka/tiny-random-minicpmv-2_6,https://huggingface.co/katuni4ka/tiny-random-minicpmv-2_6
+katuni4ka/tiny-random-llava,https://huggingface.co/katuni4ka/tiny-random-llava
+katuni4ka/tiny-random-nanollava,https://huggingface.co/katuni4ka/tiny-random-nanollava
File 3: PagedAttention transformation tests (Python)

@@ -4,26 +4,29 @@
 from openvino._offline_transformations import paged_attention_transformation
 from openvino._pyopenvino.op import _PagedAttentionExtension
 from optimum.intel import OVModelForCausalLM
+from optimum.intel.openvino import OVModelForVisualCausalLM
+from typing import Type, Union
 import openvino as ov
 from models_hub_common.utils import retry
 import models_hub_common.utils as utils
 from sdpa2pa_ref_diff import ref_diff_map, ref_diff_map_cache_eviction, nodes_to_compare
 import pytest
 import os
 import re
 
-@retry(3, exceptions=(OSError,), delay=1)
-def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_outputs):
-    model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
-
+def compare_diffs(ov_model: ov.Model,
+                  model_id: str,
+                  use_block_indices_inputs: bool,
+                  use_score_outputs: bool):
     before_map = {}
-    for op in model.model.get_ordered_ops():
+    for op in ov_model.get_ordered_ops():
         if op.get_type_name() in nodes_to_compare:
             before_map[op.get_type_name()] = before_map.get(op.get_type_name(), 0) + 1
 
-    paged_attention_transformation(model.model, use_block_indices_inputs, use_score_outputs)
+    paged_attention_transformation(ov_model, use_block_indices_inputs, use_score_outputs)
 
     after_map = {}
-    for op in model.model.get_ordered_ops():
+    for op in ov_model.get_ordered_ops():
         if op.get_type_name() in nodes_to_compare:
             after_map[op.get_type_name()] = after_map.get(op.get_type_name(), 0) + 1

@@ -38,7 +41,7 @@ def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_o
 
     assert reference_map == resulting_map
 
-    model_inputs = model.model.inputs
+    model_inputs = ov_model.inputs
     for input in model_inputs:
         names = list(input.get_names()) # names stored in as set (in this case usually of 1 element)
         for name in names:

@@ -53,7 +56,7 @@ def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_o
     block_indices_pattern = r'block_indices\.[0-9]+'
     block_indices_counter = 0
 
-    model_inputs = model.model.inputs
+    model_inputs = ov_model.inputs
     for input in model_inputs:
         for name in list(input.get_names()):
             if re.search(block_indices_pattern, name):

@@ -66,7 +69,7 @@ def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_o
     score_pattern = r'scores\.[0-9]+'
     score_outputs_counter = 0
 
-    model_outputs = model.model.outputs
+    model_outputs = ov_model.outputs
     for output in model_outputs:
         for name in list(output.get_names()):
             if re.search(score_pattern, name):

@@ -75,6 +78,18 @@ def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_o
     assert block_indices_counter == resulting_map["PagedAttentionExtension"], \
         f"The number of scores outputs doesn't correspond to the expected value. Expected {resulting_map['PagedAttentionExtension']}, received {block_indices_counter}"
 
+@retry(3, exceptions=(OSError,), delay=1)
+def run_pa(tmp_path,
+           model_id,
+           model_link,
+           cls: Union[Type[OVModelForCausalLM], Type[OVModelForVisualCausalLM]],
+           use_block_indices_inputs,
+           use_score_outputs):
+    model = cls.from_pretrained(model_id, export=True, trust_remote_code=True)
+    ov_model = model.model if cls is OVModelForCausalLM else model.lm_model
+
+    compare_diffs(ov_model, model_id, use_block_indices_inputs, use_score_outputs)
+
 @pytest.mark.precommit
 @pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit")))
 def test_pa_precommit(tmp_path, model_name, model_link, mark, reason, ie_device):

@@ -84,7 +99,7 @@ def test_pa_precommit(tmp_path, model_name, model_link, mark, reason, ie_device)
         pytest.skip(reason)
     elif mark == 'xfail':
         pytest.xfail(reason)
-    run_pa(tmp_path, model_name, model_link, False, False)
+    run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, False, False)
 
 @pytest.mark.precommit
 @pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit")))

@@ -95,4 +110,26 @@ def test_pa_precommit_use_cache_eviction(tmp_path, model_name, model_link, mark,
         pytest.skip(reason)
     elif mark == 'xfail':
         pytest.xfail(reason)
-    run_pa(tmp_path, model_name, model_link, True, True)
+    run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, True, True)
+
+@pytest.mark.precommit
+@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit")))
+def test_pa_vlm(tmp_path, model_name, model_link, mark, reason, ie_device):
+    assert mark is None or mark == 'skip' or mark == 'xfail', \
+        "Incorrect test case: {}, {}".format(model_name, model_link)
+    if mark == 'skip':
+        pytest.skip(reason)
+    elif mark == 'xfail':
+        pytest.xfail(reason)
+    run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, False, False)
+
+@pytest.mark.precommit
+@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit")))
+def test_pa_vlm_use_cache_eviction(tmp_path, model_name, model_link, mark, reason, ie_device):
+    assert mark is None or mark == 'skip' or mark == 'xfail', \
+        "Incorrect test case: {}, {}".format(model_name, model_link)
+    if mark == 'skip':
+        pytest.skip(reason)
+    elif mark == 'xfail':
+        pytest.xfail(reason)
+    run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, True, True)

Review comment on the new @pytest.mark.precommit VLM tests: could you manually verify that the tests are actually running in pre-commit jobs?
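For illustration (not part of the diff), the VLM path added by the new run_pa helper boils down to the following sketch. The model id is one entry from the new tiny-model list above; everything else mirrors the test code.

# Sketch of the VLM flow exercised by test_pa_vlm: export a tiny VLM via optimum-intel,
# take its language model (lm_model), and apply the PagedAttention transformation to it.
from optimum.intel.openvino import OVModelForVisualCausalLM
from openvino._offline_transformations import paged_attention_transformation

model_id = "katuni4ka/tiny-random-llava"  # from the new VLM precommit list
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
ov_model = model.lm_model  # the language-model part is what SDPAToPagedAttention rewrites

paged_attention_transformation(ov_model, False, False)

pa_count = sum(op.get_type_name() == "PagedAttentionExtension"
               for op in ov_model.get_ordered_ops())
print("PagedAttentionExtension nodes:", pa_count)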