Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cache rotation inputs and CPU kernel implementation for cache rotation #27088

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from

Conversation

vshampor
Copy link
Contributor

Tickets:
153783

@github-actions github-actions bot added category: Core OpenVINO Core (aka ngraph) category: GPU OpenVINO GPU plugin category: CPU OpenVINO CPU plugin category: transformations OpenVINO Runtime library - Transformations category: CPP API OpenVINO CPP API bindings labels Oct 16, 2024
Comment on lines 22 to 23
get_input_size() == 15,
"PagedAttensionExtension expects 15 inputs, but it has ",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't look as optional inputs. According to the spec they could be omitted. If you replace by get_input_size() == 13 || get_input_size() == 15 it wouldn't be a big code modification but unlock a bit of flexibility in the transition period where various mixes of main ov and genai may happen. As we keep PA op internal and not very particular on op version numbering, then a bit of backward compatibility care would be nice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, but the alibi parameter doesn't seem to follow that approach

Comment on lines 418 to 419
pa_arguments.insert(pa_arguments.begin() + 13, v0::Constant::create(element::f32, Shape{0}, {}));
pa_arguments.insert(pa_arguments.begin() + 14, v0::Constant::create(element::i32, Shape{0}, {}));
Copy link
Contributor

@slyalin slyalin Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you make these inputs really optional, these two lines are not required.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

get_input_partial_shape(13).rank().is_dynamic() ||
get_input_partial_shape(13).rank().get_length() == 0 ||
get_input_partial_shape(13).rank().get_length() == 1,
"Input `rotation_coefficients` should either have an empty shape or rank 1, but it has rank ",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Input `rotation_coefficients` should either have an empty shape or rank 1, but it has rank ",
"Input `rotation_coefficients` should either have rank 1 or omitted, but it has rank ",

"Empty" shape means [0] here, which have rank 1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(13).rank().is_dynamic() ||
get_input_partial_shape(13).rank().get_length() == 0 ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
get_input_partial_shape(13).rank().get_length() == 0 ||

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 167 to 169
get_input_partial_shape(14).rank().get_length() == 0 ||
get_input_partial_shape(14).rank().get_length() == 1,
"Input `rotated_block_indices` should either have an empty shape or rank 1 but it has rank ",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same comment are applicable here as for input 13 above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -1576,6 +1591,11 @@ struct AttentionExecutor : public PagedAttentionExecutor {
if (alibi_slopes) {
alibi_slopes.assert_dims({H});
}

if (rotated_block_indices) {
// Rotation, and cache eviction, is limited to cases when Q, K and V embedding sizes are equal, e.g. S == Sv
Copy link
Contributor

@slyalin slyalin Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have cases where they are not: minicpm-3

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed - realized that we don't need that limitation for cache rotation since we only rotate the K values

@@ -58,6 +59,10 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
OPENVINO_ASSERT(alibi_const != nullptr);
prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0;

std::shared_ptr<ov::op::v0::Constant> rotation_coefficients_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(rotation_coefficients_idx));
OPENVINO_ASSERT(rotation_coefficients_const != nullptr);
prim.has_rotation_coefficients = ov::shape_size(alibi_const->get_output_shape(0)) > 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alibi_const shouldn't be used here -- bad copy&paste?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, thanks.

@github-actions github-actions bot added the category: build OpenVINO cmake script / infra label Oct 30, 2024
@vshampor vshampor changed the title Add cache rotation inputs Add cache rotation inputs and CPU kernel implementation for cache rotation Nov 12, 2024
@dmitry-gorokhov
Copy link
Contributor

@luo-cheng2021 Please review CPU PA changes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category: build OpenVINO cmake script / infra category: Core OpenVINO Core (aka ngraph) category: CPP API OpenVINO CPP API bindings category: CPU OpenVINO CPU plugin category: GPU OpenVINO GPU plugin category: transformations OpenVINO Runtime library - Transformations
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants