Skip to content

Commit

Permalink
Group broadcast is implemented via commonly used UBR-pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
slyalin committed Nov 12, 2024
1 parent 9269453 commit 7eb07a6
Showing 1 changed file with 42 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "core/operator_set.hpp"
#include "openvino/frontend/exception.hpp"

// TODO: Filter out unused headers

#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
Expand Down Expand Up @@ -51,7 +53,7 @@ namespace com_microsoft {
namespace detail {
namespace {

// FIXME: Reuse the same function from file attention.cpp
// FIXME: Reuse the same function from file attention.cpp, but it requires a bit of adaptation -- I have redesigned part of the inputs a bit here and in the helper functions below
ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
int64_t num_heads,
const std::vector<int64_t>& qkv_hidden_sizes);
Expand Down Expand Up @@ -86,7 +88,6 @@ ov::Output<ov::Node> rope(const Output<ov::Node>& x, Output<ov::Node> cos, Outpu
cos = make_shared<Slice>(cos, zero, seq_len, step, zero);
sin = make_shared<Slice>(sin, zero, seq_len, step, zero);

Output cos_multiplier = make_shared<Concat>(OutputVector{cos, cos}, 1);
OutputVector x_split = make_shared<Split>(x, Constant::create(element::i32, Shape{}, {-1}), 2)->outputs();

Output res_0 = make_shared<Subtract>(
Expand All @@ -102,6 +103,36 @@ ov::Output<ov::Node> rope(const Output<ov::Node>& x, Output<ov::Node> cos, Outpu
return make_shared<Concat>(OutputVector{res_0, res_1}, -1);
}

ov::Output<ov::Node> broadcast_groups(const Output<ov::Node>& cache, const int num_kv_heads, const int num_heads) {
if(num_kv_heads == 1 || num_kv_heads == num_heads) {
// No broadcast or there is the broadcast that SDPA broadcastability can handle
return cache;
}

OPENVINO_ASSERT(num_heads % num_kv_heads == 0);
const auto broadcast_multiplier = num_heads/num_kv_heads;

auto unsqueeze = std::make_shared<v0::Unsqueeze>(cache, v0::Constant::create(element::i32, Shape{}, {2}));
auto shapeof = std::make_shared<v3::ShapeOf>(cache, element::i32);

auto broadcast_shape = std::make_shared<v0::Concat>(OutputVector{
get_elements(shapeof, {0, 1}),
v0::Constant::create(element::i32, Shape{1}, {broadcast_multiplier}),
get_elements(shapeof, {2, 3})
}, 0);

auto broadcast = std::make_shared<v3::Broadcast>(unsqueeze, broadcast_shape);

auto reshape_shape = std::make_shared<v0::Concat>(OutputVector{
v0::Constant::create(element::i32, Shape{3}, {0, num_heads, -1}),
get_elements(shapeof, {3})
}, 0);

auto reshape = std::make_shared<v1::Reshape>(broadcast, reshape_shape, true);

return reshape;
}


} // namespace
} // namespace detail
Expand Down Expand Up @@ -142,12 +173,20 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
K = std::make_shared<v0::Concat>(ov::OutputVector{past_K, K}, 2);
V = std::make_shared<v0::Concat>(ov::OutputVector{past_V, V}, 2);

const auto num_kv_heads = node.get_attribute_value<int64_t>("kv_num_heads");

K = detail::broadcast_groups(K, num_kv_heads, num_heads);
V = detail::broadcast_groups(V, num_kv_heads, num_heads);

// FIXME: Unaligned batch of sequences is not supported.
// That means all input sequence length should be the same and match input.shape[1]
// That means all input sequence lengths should be the same and match input.shape[2]
// We do not check that here because it depends on runtime values.
// If we want to implement not aligned batch of dimensions we have to form not uniform causal mask for attention that
// adds a significant porition of the code.

// FIXME: The same tensor at input/output of past/preset K and V are not supported.
// It requires more complex tensor manipulations that are introduce overhead into pure tensor-value data flow and should be implemented if we really have demand for that.
// Also inplace KV-cache modification logic is not supported efficiently in any plugins (CPU, GPU and NPU).

auto output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);

Expand Down

0 comments on commit 7eb07a6

Please sign in to comment.