Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TRANSFORMATIONS] Add support for 'inputs_embeds' input in SDPAToPA #27158

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,28 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode

auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0}); // sliding_window

auto has_parameter = [=](const std::shared_ptr<ov::Model>& model,
praasz marked this conversation as resolved.
Show resolved Hide resolved
const std::string& name) -> std::shared_ptr<v0::Parameter> {
for (auto& param : model->inputs()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto& param : model->inputs()) {
for (const auto& param : model->inputs()) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

const auto& names = param.get_names();
if (names.find(name) != names.end()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (names.find(name) != names.end()) {
if (names.count(name)) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

if (auto casted_param = std::dynamic_pointer_cast<v0::Parameter>(param.get_node_shared_ptr())) {
return casted_param;
} else {
OPENVINO_THROW("The model is in the inconsistent state. Found input '",
name,
"', but couldn't cast it to v0::Parameter.");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (auto casted_param = std::dynamic_pointer_cast<v0::Parameter>(param.get_node_shared_ptr())) {
return casted_param;
} else {
OPENVINO_THROW("The model is in the inconsistent state. Found input '",
name,
"', but couldn't cast it to v0::Parameter.");
}
auto casted_param = ov::as_type_ptr<v0::Parameter>(param.get_node_shared_ptr());
OPENVINO_ASSERT(casted_param, "The model is in the inconsistent state. Input is not a parameter: ", name);
return casted_param;
return casted_param;

Use ov::as_type_ptr<> instead of std::dynamic_pointer_cast for operators.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

}
}

return nullptr;
};

auto input_ids_name = has_parameter(model, "input_ids") ? "input_ids" : "inputs_embeds";

std::shared_ptr<v0::Parameter> input_ids_node =
std::dynamic_pointer_cast<v0::Parameter>(model->input("input_ids").get_node_shared_ptr());
std::dynamic_pointer_cast<v0::Parameter>(model->input(input_ids_name).get_node_shared_ptr());
input_ids_node->set_partial_shape(PartialShape{-1});
auto unsqueezed_input_ids =
std::make_shared<v0::Unsqueeze>(input_ids_node, v0::Constant::create(element::i32, Shape{}, {1}));
Expand All @@ -66,17 +86,6 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
auto prev_max_seq_len =
std::make_shared<v1::Subtract>(max_context_len, std::make_shared<v0::Convert>(cur_seq_len, element::i32));

auto has_parameter = [=](const std::shared_ptr<ov::Model>& model, const std::string& name) -> bool {
for (auto& t : model->inputs()) {
const auto& names = t.get_names();
if (names.find(name) != names.end()) {
return true;
}
}

return false;
};

ParameterVector kv_parameters;
ParameterVector parameters_to_remove;
ResultVector results_to_remove; // # used, but cannot really track all Results in stateless model
Expand Down Expand Up @@ -136,30 +145,22 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
}

for (auto& param_name : {"beam_idx", "attention_mask"}) {
if (has_parameter(model, param_name)) {
if (const auto& param =
std::dynamic_pointer_cast<v0::Parameter>(model->input(param_name).get_node_shared_ptr())) {
model->remove_parameter(param);

if (param->output(0).get_target_inputs().size() == 0) {
std::stringstream consumers;
consumers << std::endl;
for (auto& input : param->output(0).get_target_inputs()) {
consumers << *input.get_node() << std::endl;
}
OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0,
"PagedAttention transformation failed: couldn't remove ",
param->output(0).get_target_inputs().size(),
" inputs of ",
param_name,
" input: ",
consumers.str());
if (auto param = has_parameter(model, param_name)) {
model->remove_parameter(param);

if (param->output(0).get_target_inputs().size() == 0) {
std::stringstream consumers;
consumers << std::endl;
for (auto& input : param->output(0).get_target_inputs()) {
consumers << *input.get_node() << std::endl;
}
} else {
OPENVINO_THROW("The model is in the inconsistent state. Found input '",
param_name,
"', but couldn't cast it to v0::Parameter.");
return false;
OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0,
"PagedAttention transformation failed: couldn't remove ",
param->output(0).get_target_inputs().size(),
" inputs of ",
param_name,
" input: ",
consumers.str());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,36 @@
from openvino._offline_transformations import paged_attention_transformation
from openvino._pyopenvino.op import _PagedAttentionExtension, Parameter, Result
from optimum.intel import OVModelForCausalLM
from optimum.intel.openvino import OVModelForVisualCausalLM
from typing import Type, Union

nodes_to_compare = ("ScaledDotProductAttention", "PagedAttentionExtension", "Parameter", "ReadValue", "Assign")

def get_models_list_type(file_name: str, cls: Union[Type[OVModelForCausalLM], Type[OVModelForVisualCausalLM]]):
    """Read a model-list file and attach the optimum-intel model class to each entry.

    Every returned item is a uniform 5-tuple ``(model_name, model_link, mark, reason, cls)``
    so callers can unpack it as ``for model_id, _, _, _, cls in model_list``.

    :param file_name: path to the model-list file parsed by ``utils.parse_list_file``.
    :param cls: optimum-intel wrapper class to associate with every model in the file.
    :return: list of 5-tuples described above.
    """
    models = []
    for line_items in utils.parse_list_file(file_name):
        if len(line_items) == 2:
            model_name, model_link = line_items
            models.append((model_name, model_link, None, None, cls))
        elif len(line_items) == 4:
            model_name, model_link, mark, reason = line_items
            # `cls` must be appended here too, otherwise unpacking five fields
            # at the call site fails for entries that carry a mark/reason.
            models.append((model_name, model_link, mark, reason, cls))
        elif len(line_items) > 4:
            model_name, model_link, mark, reason = line_items[:4]
            # Empty strings mean "no mark"/"no reason".
            if not mark:
                mark = None
            if not reason:
                reason = None
            # NOTE(review): the original also parsed trailing 'ts_name:'/'layer:'
            # fields here, but nothing in this file consumes them and returning a
            # wider tuple breaks the uniform 5-field unpacking — so they are ignored.
            models.append((model_name, model_link, mark, reason, cls))
        else:
            items = ','.join(line_items)
            assert False, \
                f'Incorrect model info fields {items}. It must contain either 2 or 4 or more than 4 fields.'
    return models

def main():
use_cache_eviction = False
if len(sys.argv) >= 2:
Expand All @@ -55,32 +82,37 @@ def main():

if OUTPUT_FILE.exists() and OUTPUT_FILE.is_file():
OUTPUT_FILE.unlink()

with open(OUTPUT_FILE, 'w') as file:
model_list = utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit"))
model_list = get_models_list_type(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit"), OVModelForCausalLM)
model_list.extend(get_models_list_type(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit"), OVModelForVisualCausalLM))
print(OUTPUT_FILE)
print('ref_diff_map_cache_eviction = {' if use_cache_eviction else 'ref_diff_map = {', file=file)

for model_id, _, _, _ in model_list:
for model_id, _, _, _, cls in model_list:
# wrapping in try/catch block to continue printing models even if one has failed
try:
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
model = cls.from_pretrained(model_id, export=True, trust_remote_code=True)
except:
print(f"Couldn't read {model_id}.")
continue

ov_model = model.model if cls is OVModelForCausalLM else model.lm_model

before_map = {}
for op in model.model.get_ordered_ops():
for op in ov_model.get_ordered_ops():
if op.get_type_name() in nodes_to_compare:
before_map[op.get_type_name()] = before_map.get(op.get_type_name(), 0) + 1

# wrapping in try/catch block to continue printing models even if one has failed
try:
paged_attention_transformation(model.model, use_cache_eviction, use_cache_eviction)
paged_attention_transformation(ov_model, use_cache_eviction, use_cache_eviction)
except:
print(f"Couldn't run SDPAToPA transformation on {model_id} and generate diffs.")
continue

after_map = {}
for op in model.model.get_ordered_ops():
for op in ov_model.get_ordered_ops():
if op.get_type_name() in nodes_to_compare:
after_map[op.get_type_name()] = after_map.get(op.get_type_name(), 0) + 1

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
katuni4ka/tiny-random-llava-next,https://huggingface.co/katuni4ka/tiny-random-llava-next
katuni4ka/tiny-random-minicpmv-2_6,https://huggingface.co/katuni4ka/tiny-random-minicpmv-2_6
katuni4ka/tiny-random-llava,https://huggingface.co/katuni4ka/tiny-random-llava
katuni4ka/tiny-random-nanollava,https://huggingface.co/katuni4ka/tiny-random-nanollava
70 changes: 63 additions & 7 deletions tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,34 @@
"ReadValue" : -12,
"Assign" : -12,
},
"katuni4ka/tiny-random-llava-next" : {
"PagedAttentionExtension" : 2,
"Parameter" : 7,
"ReadValue" : -4,
"ScaledDotProductAttention" : -2,
"Assign" : -4,
},
"katuni4ka/tiny-random-minicpmv-2_6" : {
"PagedAttentionExtension" : 2,
"Parameter" : 7,
"ReadValue" : -4,
"ScaledDotProductAttention" : -2,
"Assign" : -4,
},
"katuni4ka/tiny-random-llava" : {
"Assign" : -4,
"Parameter" : 7,
"ReadValue" : -4,
"ScaledDotProductAttention" : -2,
"PagedAttentionExtension" : 2,
},
"katuni4ka/tiny-random-nanollava" : {
"Assign" : -4,
"Parameter" : 7,
"ReadValue" : -4,
"ScaledDotProductAttention" : -2,
"PagedAttentionExtension" : 2,
},
}

ref_diff_map_cache_eviction = {
Expand Down Expand Up @@ -532,13 +560,13 @@
"Parameter" : 14,
"Assign" : -8,
},
"katuni4ka/tiny-random-minicpm" : {
"ScaledDotProductAttention" : -4,
"Parameter" : 14,
"PagedAttentionExtension" : 4,
"ReadValue" : -8,
"Assign" : -8,
},
"katuni4ka/tiny-random-minicpm" : {
"ScaledDotProductAttention" : -4,
"Parameter" : 14,
"PagedAttentionExtension" : 4,
"ReadValue" : -8,
"Assign" : -8,
},
"katuni4ka/tiny-random-falcon-40b" : {
"ScaledDotProductAttention" : -2,
"ReadValue" : -4,
Expand Down Expand Up @@ -609,4 +637,32 @@
"Parameter" : 20,
"Assign" : -12,
},
"katuni4ka/tiny-random-llava-next" : {
"Parameter" : 8,
"Assign" : -4,
"ReadValue" : -4,
"PagedAttentionExtension" : 2,
"ScaledDotProductAttention" : -2,
},
"katuni4ka/tiny-random-minicpmv-2_6" : {
"Parameter" : 8,
"Assign" : -4,
"ReadValue" : -4,
"PagedAttentionExtension" : 2,
"ScaledDotProductAttention" : -2,
},
"katuni4ka/tiny-random-llava" : {
"ReadValue" : -4,
"Parameter" : 8,
"ScaledDotProductAttention" : -2,
"PagedAttentionExtension" : 2,
"Assign" : -4,
},
"katuni4ka/tiny-random-nanollava" : {
"ReadValue" : -4,
"Parameter" : 8,
"ScaledDotProductAttention" : -2,
"PagedAttentionExtension" : 2,
"Assign" : -4,
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,29 @@
from openvino._offline_transformations import paged_attention_transformation
from openvino._pyopenvino.op import _PagedAttentionExtension
from optimum.intel import OVModelForCausalLM
from optimum.intel.openvino import OVModelForVisualCausalLM
from typing import Type, Union
import openvino as ov
from models_hub_common.utils import retry
import models_hub_common.utils as utils
from sdpa2pa_ref_diff import ref_diff_map, ref_diff_map_cache_eviction, nodes_to_compare
import pytest
import os
import re

@retry(3, exceptions=(OSError,), delay=1)
def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_outputs):
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

def compare_diffs(ov_model: ov.Model,
model_id: str,
use_block_indices_inputs: bool,
use_score_outputs: bool):
before_map = {}
for op in model.model.get_ordered_ops():
for op in ov_model.get_ordered_ops():
if op.get_type_name() in nodes_to_compare:
before_map[op.get_type_name()] = before_map.get(op.get_type_name(), 0) + 1

paged_attention_transformation(model.model, use_block_indices_inputs, use_score_outputs)
paged_attention_transformation(ov_model, use_block_indices_inputs, use_score_outputs)

after_map = {}
for op in model.model.get_ordered_ops():
for op in ov_model.get_ordered_ops():
if op.get_type_name() in nodes_to_compare:
after_map[op.get_type_name()] = after_map.get(op.get_type_name(), 0) + 1

Expand All @@ -38,7 +41,7 @@ def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_o

assert reference_map == resulting_map

model_inputs = model.model.inputs
model_inputs = ov_model.inputs
for input in model_inputs:
names = list(input.get_names()) # names stored in as set (in this case usually of 1 element)
for name in names:
Expand All @@ -53,7 +56,7 @@ def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_o
block_indices_pattern = r'block_indices\.[0-9]+'
block_indices_counter = 0

model_inputs = model.model.inputs
model_inputs = ov_model.inputs
for input in model_inputs:
for name in list(input.get_names()):
if re.search(block_indices_pattern, name):
Expand All @@ -66,7 +69,7 @@ def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_o
score_pattern = r'scores\.[0-9]+'
score_outputs_counter = 0

model_outputs = model.model.outputs
model_outputs = ov_model.outputs
for output in model_outputs:
for name in list(output.get_names()):
if re.search(score_pattern, name):
Expand All @@ -75,6 +78,18 @@ def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_o
assert block_indices_counter == resulting_map["PagedAttentionExtension"], \
f"The number of scores outputs doesn't correspond to the expected value. Expected {resulting_map['PagedAttentionExtension']}, received {block_indices_counter}"

@retry(3, exceptions=(OSError,), delay=1)
def run_pa(tmp_path,
           model_id,
           model_link,
           cls: Union[Type[OVModelForCausalLM], Type[OVModelForVisualCausalLM]],
           use_block_indices_inputs,
           use_score_outputs):
    """Export ``model_id`` through optimum-intel and check the SDPA->PA op-count diff.

    ``cls`` selects the wrapper class: the visual-language wrapper keeps its
    language model under ``lm_model``, while the causal-LM wrapper uses ``model``.
    """
    exported = cls.from_pretrained(model_id, export=True, trust_remote_code=True)
    if cls is OVModelForCausalLM:
        ov_model = exported.model
    else:
        ov_model = exported.lm_model
    compare_diffs(ov_model, model_id, use_block_indices_inputs, use_score_outputs)

@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit")))
def test_pa_precommit(tmp_path, model_name, model_link, mark, reason, ie_device):
Expand All @@ -84,7 +99,7 @@ def test_pa_precommit(tmp_path, model_name, model_link, mark, reason, ie_device)
pytest.skip(reason)
elif mark == 'xfail':
pytest.xfail(reason)
run_pa(tmp_path, model_name, model_link, False, False)
run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, False, False)

@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit")))
Expand All @@ -95,4 +110,26 @@ def test_pa_precommit_use_cache_eviction(tmp_path, model_name, model_link, mark,
pytest.skip(reason)
elif mark == 'xfail':
pytest.xfail(reason)
run_pa(tmp_path, model_name, model_link, True, True)
run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, True, True)

@pytest.mark.precommit
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you manually verify that the tests are actually running in pre-commit jobs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the newly added models and tests are being run :)

image

@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit")))
def test_pa_vlm(tmp_path, model_name, model_link, mark, reason, ie_device):
    """Precommit SDPAToPA check for visual-language models (no cache eviction)."""
    # Only the supported pytest marks may appear in the model-list file.
    assert mark in (None, 'skip', 'xfail'), \
        "Incorrect test case: {}, {}".format(model_name, model_link)
    if mark == 'skip':
        pytest.skip(reason)
    if mark == 'xfail':
        pytest.xfail(reason)
    run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, False, False)

@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit")))
def test_pa_vlm_use_cache_eviction(tmp_path, model_name, model_link, mark, reason, ie_device):
    """Precommit SDPAToPA check for visual-language models with cache eviction enabled."""
    # Only the supported pytest marks may appear in the model-list file.
    assert mark in (None, 'skip', 'xfail'), \
        "Incorrect test case: {}, {}".format(model_name, model_link)
    if mark == 'skip':
        pytest.skip(reason)
    if mark == 'xfail':
        pytest.xfail(reason)
    run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, True, True)
Loading
Loading