From bb5a9d48059abb5223e785b04b0d0042ec80a3dc Mon Sep 17 00:00:00 2001 From: Andrew Kwangwoong Park Date: Wed, 24 Jul 2024 03:09:52 +0900 Subject: [PATCH] [GPU] Fix issue to calculate present layout's padding for KVCache (#25682) ### Details: - Fix issue to calculate present layout's padding for KVCache ### Tickets: - 146876 --- src/plugins/intel_gpu/src/graph/primitive_inst.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index 522fb03f15c5bd..9fb822955c41a4 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -1189,6 +1189,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() { } const auto& desc = _node->as().get_primitive(); auto& past_layout = _impl_params->input_layouts[0]; + auto& new_layout = _impl_params->input_layouts[1]; auto& present_layout = _impl_params->output_layouts[0]; const auto& sequence_axis = desc->concat_axis; const auto& gather_axis = desc->gather_axis; @@ -1209,8 +1210,12 @@ void primitive_inst::do_runtime_in_place_kv_cache() { auto max_pad = kv_cache_inst::get_max_pad(past_layout, _deps[0].first->_max_output_layout_count[0], sequence_axis_legacy, "past_layout"); if (max_pad > 0) { - kv_cache_inst::update_pad(present_layout, max_pad - 1, sequence_axis_legacy); - GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_layout's pad : " << present_layout.to_string() << std::endl; + const auto new_seq_len = static_cast(new_layout.get_shape()[sequence_axis]); + if (max_pad - new_seq_len >= 0) { + kv_cache_inst::update_pad(present_layout, max_pad - new_seq_len, sequence_axis_legacy); + GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_layout's pad : " + << present_layout.to_string() << std::endl; + } auto& variable = get_network().get_variable(desc->variable_info.variable_id); variable.set_layout(present_layout); GPU_DEBUG_TRACE_DETAIL << "[do_runtime_in_place_kv_cache] " << id() << "Updated variable with present_layout"