Skip to content

Commit

Permalink
Correct handling of kv-cache position range (with limitations and assumptions mentioned in the code).
Browse files Browse the repository at this point in the history
  • Loading branch information
slyalin committed Nov 13, 2024
1 parent 7eb07a6 commit 6452b3d
Showing 1 changed file with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,15 @@ ov::Output<ov::Node> get_dimensions(const ov::Output<ov::Node>& node, const std:
return get_elements(std::make_shared<v3::ShapeOf>(node), dims);
}

ov::Output<ov::Node> rope(const Output<ov::Node>& x, Output<ov::Node> cos, Output<ov::Node> sin, bool interleaved, const Output<ov::Node>& head_size, const Output<ov::Node>& seq_len) {
ov::Output<ov::Node> rope(
const Output<ov::Node>& x,
Output<ov::Node> cos,
Output<ov::Node> sin,
bool interleaved,
const Output<ov::Node>& head_size,
const Output<ov::Node>& pos_id_begin,
const Output<ov::Node>& pos_id_end
) {
OPENVINO_ASSERT(!interleaved, "rotary_interleaved is not supported"); // TODO: Support interleaved mode

using v1::Split;
Expand All @@ -85,8 +93,8 @@ ov::Output<ov::Node> rope(const Output<ov::Node>& x, Output<ov::Node> cos, Outpu
Output step = Constant::create(element::i32, Shape{1}, {1});

// cut for the current sequence length
cos = make_shared<Slice>(cos, zero, seq_len, step, zero);
sin = make_shared<Slice>(sin, zero, seq_len, step, zero);
cos = make_shared<Slice>(cos, pos_id_begin, pos_id_end, step, zero);
sin = make_shared<Slice>(sin, pos_id_begin, pos_id_end, step, zero);

OutputVector x_split = make_shared<Split>(x, Constant::create(element::i32, Shape{}, {-1}), 2)->outputs();

Expand Down Expand Up @@ -161,14 +169,22 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {

const auto& past_K = nodes[3];
const auto& past_V = nodes[4];

const auto& seqlens_k = nodes[5];
const auto& total_sequence_length = nodes[6];
const auto& cos = nodes[7];
const auto& sin = nodes[8];
const bool rope_interleaved = node.get_attribute_value<int64_t>("rotary_interleaved", 0);

Q = detail::rope(Q, cos, sin, rope_interleaved, head_size, total_sequence_length);
K = detail::rope(K, cos, sin, rope_interleaved, head_size, total_sequence_length);
// FIXME: This works only when the KV cache grows dynamically and has no unused space inside, so it is not compatible with a statically-shaped KV cache.
// const auto past_seq_len = detail::get_dimensions(past_K, {0});
// TODO: The GQA spec is not compatible with the test model: the spec expects a 1D tensor, but the test model provides a 2D tensor, so we flatten it to work in both cases.

// FIXME: Unaligned elements in the KV cache are not supported.
// We just take one of the sequence lengths as the common value for all past sequences.
const auto& past_seq_len = detail::get_elements(std::make_shared<v1::Reshape>(seqlens_k, v0::Constant::create(element::i32, Shape{1}, {-1}), false), {0});

Q = detail::rope(Q, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
K = detail::rope(K, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);

K = std::make_shared<v0::Concat>(ov::OutputVector{past_K, K}, 2);
V = std::make_shared<v0::Concat>(ov::OutputVector{past_V, V}, 2);
Expand All @@ -178,7 +194,7 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
K = detail::broadcast_groups(K, num_kv_heads, num_heads);
V = detail::broadcast_groups(V, num_kv_heads, num_heads);

// FIXME: Unaligned batch of sequences is not supported.
// FIXME: Unaligned batch of sequences is not supported. All past keys and values are assumed to have the same length.
// That means all input sequence lengths should be the same and match input.shape[2]
// We do not check that here because it depends on runtime values.
// If we want to implement not aligned batch of dimensions we have to form not uniform causal mask for attention that
Expand Down

0 comments on commit 6452b3d

Please sign in to comment.