Skip to content

Commit

Permalink
[GPU] Fix incorrect broadcast axis for static SDPA case (openvinotoolkit#28763)
Browse files Browse the repository at this point in the history

### Details:
- This change fixes the incorrect broadcast axis calculation for static SDPA
operations. The dynamic case wasn't affected because its only static
dimensions were head_size and heads_num, which allowed the broadcast axis
to be set properly.

### Tickets:
 - [CVS-160643](https://jira.devtools.intel.com/browse/CVS-160643)
  • Loading branch information
sshlyapn authored Feb 3, 2025
1 parent 0213116 commit 398bde6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,12 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), desc->input_v_transpose_order);

OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal");
for (size_t i = 0; i < query_shape.size(); ++i) {
if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) {
if (query_shape[i].get_length() > key_shape[i].get_length()) {
config.broadcast_axis = desc->input_k_transpose_order[i];
config.group_size = query_shape[i].get_length() / key_shape[i].get_length();
}

const auto num_heads_dim = 1;
if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,26 @@ const std::vector<std::vector<InputShape>> shapes{
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1},
{ov::Shape{1, 1, 7, 7}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 10, 10}}}
},
},
}
};

// Transpose-order sets used as the last test parameter of the suites below.
// NOTE(review): each set appears to hold one axis order per SDPA input
// (presumably query, key, value) — confirm against the test class.

// Empty set: presumably means "no transpose applied" (per the name) — confirm.
const std::vector<std::vector<int64_t>> disable_transpose{};
// First two entries keep the identity order {0, 1, 2, 3}; only the third entry swaps axes 1 and 2.
const std::vector<std::vector<int64_t>> transpose_value{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
// All three entries swap axes 1 and 2.
const std::vector<std::vector<int64_t>> transpose_all{{0, 2, 1, 3}, {0, 2, 1, 3}, {0, 2, 1, 3}};

// Parameter grid for the suite instantiated from `shapes`: the cartesian product of
// f16 precision (f32 is commented out), every entry of `shapes`, three boolean
// switches each taking both true and false, and two transpose variants
// (disable_transpose and transpose_value).
// NOTE(review): the meaning of the three boolean switches is defined by the
// test class parameter tuple, which is not visible here — confirm there.
const auto dynamic_shape_params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
testing::ValuesIn(shapes),
testing::Values(true, false),
testing::Values(true, false),
testing::Values(true, false),
testing::ValuesIn({disable_transpose, transpose_value}));

// Instantiates ScaledAttnLayerGPUTest over dynamic_shape_params under the
// smoke_ScaledAttn_GPU prefix; ScaledAttnLayerGPUTest::getTestCaseName
// supplies the per-case name.
INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU,
ScaledAttnLayerGPUTest,
dynamic_shape_params,
ScaledAttnLayerGPUTest::getTestCaseName);

const std::vector<std::vector<InputShape>> static_shapes{
// static shapes
{
// q shape
Expand All @@ -326,21 +345,32 @@ const std::vector<std::vector<InputShape>> shapes{
{ov::Shape{1, 1, 100, 100}}}
},
},
{
// q shape
{ov::test::InputShape{ov::PartialShape{1, 8, 64, 128},
{ov::Shape{1, 8, 64, 128}}}
},
// kv shape
{ov::test::InputShape{ov::PartialShape{1, 8, 13, 128},
{ov::Shape{1, 8, 13, 128}}}
},
// attn shape: [B, 1, -1, L0+L1]
{ov::test::InputShape{ov::PartialShape{1, 1, 64, 13},
{ov::Shape{1, 1, 64, 13}}}
},
},
};

const std::vector<std::vector<int64_t>> disable_transpose{};
const std::vector<std::vector<int64_t>> enable_transpose{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
// Parameter grid for the suite instantiated from `static_shapes`: the cartesian
// product of f16 precision, every entry of `static_shapes`, three boolean
// switches each taking both true and false, and two transpose variants
// (disable_transpose and transpose_all).
// NOTE(review): the meaning of the three boolean switches is defined by the
// test class parameter tuple, which is not visible here — confirm there.
const auto static_shape_params = testing::Combine(testing::Values(ov::element::f16),
testing::ValuesIn(static_shapes),
testing::Values(true, false),
testing::Values(true, false),
testing::Values(true, false),
testing::ValuesIn({disable_transpose, transpose_all}));

const auto params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
testing::ValuesIn(shapes),
testing::Values(true, false),
testing::Values(true, false),
testing::Values(true, false),
testing::ValuesIn({disable_transpose, enable_transpose}));

INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU,
INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttnStatic_GPU,
ScaledAttnLayerGPUTest,
params,
static_shape_params,
ScaledAttnLayerGPUTest::getTestCaseName);

} // namespace

0 comments on commit 398bde6

Please sign in to comment.