Implement runtime scatter update skippable (openvinotoolkit#27439)
### Details:
 - *Fix the memory corruption error found in runtime-skipped nodes for scatter elements update*
 - *Implement do_runtime_skip_scatter_update*
 - *Add scatter_update, scatter_elements_update, and scatter_nd_update to mark_runtime_skippable_nodes*
 - *Update the memory clearing code to support the following cases (illustrated in the sketch below):*
   - ***Issued case1***
     - *iter0: node1(executed) -> node2(skipped) -> node3(skipped)*
     - *iter1: node1(skipped)  -> node2(skipped) -> node3(executed)*
   - ***Issued case2***
     - *iter0: node1(skipped)  -> node2(skipped) -> node3(skipped)*
     - *iter1: node1(executed) -> node2(skipped) -> node3(executed)*

### Tickets:
 - *154591*
ahnyoung-paul authored Nov 11, 2024
1 parent 05e12d0 commit 602f701
Showing 10 changed files with 357 additions and 26 deletions.
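
The two issued cases above reduce to one problem: a runtime-skipped node forwards its producer's buffer as its own output, so when the skip pattern changes between iterations a consumer can keep holding a stale alias to memory that is about to be released or overwritten. A minimal sketch of the recursive reset idea, using hypothetical types rather than the actual cldnn classes:

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-ins for cldnn::memory and cldnn::primitive_inst, for illustration only.
struct buffer { std::vector<float> data; };

struct node {
    bool skipped = false;               // runtime-skip decision for the current iteration
    std::shared_ptr<buffer> output;     // aliases the producer's buffer while skipped
    std::vector<node*> users;
};

// When a node flips between skipped and executed across iterations, every user that is
// still skipped and still aliases the same buffer must drop that alias recursively before
// the buffer is released back to the pool or overwritten by the executed node.
void reset_user_output_memory(node* n, const std::shared_ptr<buffer>& input) {
    for (node* user : n->users) {
        if (user->skipped && user->output == input) {
            user->output.reset();                   // stale alias from a previous iteration
            reset_user_output_memory(user, input);  // aliases can chain through several skipped nodes
        }
    }
}
```

This mirrors the `reset_user_output_memory` lambda added to `primitive_inst::realloc_if_needed()` below, which additionally returns pooled memory via `release_memory()` when the instance owned its allocation.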
@@ -38,6 +38,7 @@ struct kernel_impl_params final {
std::shared_ptr<const primitive> desc;
size_t unique_id;
bool _can_be_optimized = false;
bool _runtime_skippable = false;
std::vector<layout> input_layouts;
std::vector<layout> output_layouts;
std::vector<tensor> input_offsets;
@@ -145,6 +146,10 @@ struct kernel_impl_params final {
return _can_be_optimized;
}

bool runtime_skippable() const {
return _runtime_skippable;
}

template <class PType>
std::shared_ptr<const PType> typed_desc() const { return std::static_pointer_cast<const PType>(desc); }

@@ -14,6 +14,9 @@
#include "non_zero_inst.h"
#include "non_max_suppression_inst.h"
#include "unique_inst.hpp"
#include "scatter_elements_update_inst.h"
#include "scatter_update_inst.h"
#include "scatter_nd_update_inst.h"
#include "program_helpers.h"

using namespace cldnn;
@@ -201,5 +204,56 @@ void mark_runtime_skippable_nodes::run(program& p) {
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl;
}
});

program_helpers::do_for_types<scatter_elements_update>(*node, [](scatter_elements_update_node & node){
auto impl_params = node.get_kernel_impl_params();

if ((node.is_output() && node.get_dependency(0).is_input())
|| node.has_fused_primitives()
|| (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format)
|| (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type))
return;

if (node.is_dynamic()) {
node.can_be_optimized(true);
// Set runtime skippable only when the node is set as can_be_optimized finally.
node.set_runtime_skippable(true);
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl;
}
});

program_helpers::do_for_types<scatter_update>(*node, [](scatter_update_node & node){
auto impl_params = node.get_kernel_impl_params();

if ((node.is_output() && node.get_dependency(0).is_input())
|| node.has_fused_primitives()
|| (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format)
|| (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type))
return;

if (node.is_dynamic()) {
node.can_be_optimized(true);
// Set runtime skippable only when the node is set as can_be_optimized finally.
node.set_runtime_skippable(true);
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl;
}
});

program_helpers::do_for_types<scatter_nd_update>(*node, [](scatter_nd_update_node & node){
auto impl_params = node.get_kernel_impl_params();

if ((node.is_output() && node.get_dependency(0).is_input())
|| node.has_fused_primitives()
|| (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format)
|| (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type))
return;

if (node.is_dynamic()) {
node.can_be_optimized(true);
// Set runtime skippable only when the node is set as can_be_optimized finally.
node.set_runtime_skippable(true);
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl;
}
});
}
}
11 changes: 5 additions & 6 deletions src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp
@@ -22,6 +22,9 @@
#include "permute_inst.h"
#include "strided_slice_inst.h"
#include "broadcast_inst.h"
#include "scatter_update_inst.h"
#include "scatter_elements_update_inst.h"
#include "scatter_nd_update_inst.h"

#include <vector>
#include <list>
@@ -88,12 +91,8 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
static std::unique_ptr<primitive_impl> create(const typed_program_node<PType>& arg, const kernel_impl_params& impl_param) {
// concat buffer fusing for dynamic shape is adaptively applied at runtime. So we need to build dynamic impl at build time.
if (impl_param.can_be_optimized() &&
- !((impl_param.is_type<concatenation>() ||
-    impl_param.is_type<gather>() ||
-    impl_param.is_type<permute>() ||
-    impl_param.is_type<strided_slice>() ||
-    impl_param.is_type<broadcast>() ||
-    impl_param.is_type<crop>()) && impl_param.is_dynamic())) {
+ !((impl_param.runtime_skippable() || impl_param.is_type<crop>()) &&
+    impl_param.is_dynamic())) {
return make_unique<ImplType>(kernel_selector::kernel_data{});
}
auto kernel_params = ImplType::get_kernel_params(ImplType::static_canonicalize_shapes(impl_param));
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/include/primitive_inst.h
@@ -261,6 +261,7 @@ class primitive_inst {
void do_runtime_in_place_concat();
void do_runtime_in_place_kv_cache();
void do_runtime_in_place_crop();
void do_runtime_skip_scatter_update();
void configure_shape_of_dependencies();

memory::ptr fused_memory(size_t dep_id) const {
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/include/program_node.h
@@ -131,6 +131,7 @@ struct program_node {
get_unique_id(), in_layouts, out_layouts, get_fused_primitives()));
params->memory_deps = get_const_memory_deps();
params->_can_be_optimized = this->optimized;
params->_runtime_skippable = this->runtime_skippable;
params->in_port_to_shape_info_offset = get_input_port_to_shape_info_offset_map();
params->out_port_to_shape_info_offset = get_output_port_to_shape_info_offset_map();
auto deps = get_dependencies();
86 changes: 85 additions & 1 deletion src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -25,6 +25,9 @@
#include "shape_of_inst.h"
#include "softmax_inst.h"
#include "strided_slice_inst.h"
#include "scatter_elements_update_inst.h"
#include "scatter_nd_update_inst.h"
#include "scatter_update_inst.h"
#include "gemm_inst.h"
#include "assign_inst.h"
#include "read_value_inst.h"
@@ -737,15 +740,57 @@ event::ptr primitive_inst::realloc_if_needed() {

// Clear out memory if was previously reused, but now primitive can't be optimized
if (!_node->is_type<concatenation>() && (_node->is_runtime_skippable() || _node->is_type<crop>())) {
std::function<void(cldnn::primitive_inst*, cldnn::memory::ptr)> reset_user_output_memory;
reset_user_output_memory = [&](cldnn::primitive_inst* curr_inst, cldnn::memory::ptr input_mem_ptr) {
auto curr_output_memory_ptr = curr_inst->output_memory_ptr(0);
if (curr_inst->can_be_optimized()
&& (curr_output_memory_ptr
&& get_network().get_engine().is_the_same_buffer(*curr_output_memory_ptr, *input_mem_ptr))) {
if (curr_inst->mem_allocated()) {
get_network().get_memory_pool().release_memory(curr_inst->_outputs[0].get(),
curr_inst->get_node().get_unique_id(), curr_inst->id(), get_network_id());
_mem_allocated = false;
}
curr_inst->_outputs[0] = nullptr;
curr_inst->_max_output_layout_count[0] = 0;
for (auto& user_inst : curr_inst->get_user_insts()) {
reset_user_output_memory(user_inst, input_mem_ptr);
}
}
};
if (can_be_optimized()) {
_max_output_layout_count = _deps[0].first->_max_output_layout_count;
GPU_DEBUG_PROFILED_STAGE_MEMALLOC_INFO("can_be_optimized");
// If this inst is optimized out now but was executed at the previous iteration,
// reset the output memory of all users that were optimized out at the previous iteration.
// Ex.
// * iter0: node1(executed) -> node2(skipped) -> node3(skipped)
// * iter1: node1(skipped) -> node2(skipped) -> node3(executed)
if (_outputs[0] && dep_memory_ptr(0)
&& !_network.get_engine().is_the_same_buffer(dep_memory(0), output_memory(0))) {
for (auto& user_inst : get_user_insts()) {
reset_user_output_memory(user_inst, dep_memory_ptr(0));
}
}
return ev;
} else if (_outputs[0] && dep_memory_ptr(0) &&
_network.get_engine().is_the_same_buffer(dep_memory(0), output_memory(0))) {
// Clear out memory if was previously reused, but now primitive can't be optimized
if (mem_allocated()) {
get_network().get_memory_pool().release_memory(_outputs[0].get(),
get_node().get_unique_id(), id(), get_network_id());
_mem_allocated = false;
}
_outputs[0] = nullptr;
_max_output_layout_count[0] = 0;
// Check users recursively: if a user is can_be_optimized && runtime_skippable
// and its output memory is the same as the current input memory,
// then reset that user's output memory too.
// Ex.
// * iter0: node1(skipped) -> node2(skipped) -> node3(skipped)
// * iter1: node1(executed) -> node2(skipped) -> node3(executed)
for (auto& user_inst : get_user_insts()) {
reset_user_output_memory(user_inst, dep_memory_ptr(0));
}
}
}

@@ -1583,6 +1628,44 @@ void primitive_inst::do_runtime_in_place_concat() {
GPU_DEBUG_TRACE_DETAIL << "[In place concat] " << concat_inst->id() << ": can_be_optimized " << std::endl;
}

void primitive_inst::do_runtime_skip_scatter_update() {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_skip_scatter_update: " + id()));
// Check pattern
if (!(get_node().is_type<scatter_update>()
|| get_node().is_type<scatter_elements_update>()
|| get_node().is_type<scatter_nd_update>())
|| !get_node().can_be_optimized())
return;

GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_scatter_update] " << id() << " : check optimizability" << std::endl;
auto input_layout = _impl_params->get_input_layout(0);
auto output_layout = _impl_params->get_output_layout();
auto idx_layout = _impl_params->get_input_layout(1);
auto update_layout = _impl_params->get_input_layout(2);

if (idx_layout.count() > 0 && update_layout.count() > 0) {
// set shape_change to realloc memory for same input shapes
if (can_be_optimized()) {
set_shape_change();
}
set_can_be_optimized(false);
GPU_DEBUG_TRACE_DETAIL << "--- Cannot optimize because idx_layout (" << idx_layout.to_short_string()
<< ") and update_layout(" << update_layout.to_short_string() << ") are not zero" << std::endl;
return;
}

GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_scatter_update] " << id() << " : can_be_optimized" << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Input layout : " << _impl_params->get_input_layout(0).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Idx layout : " << _impl_params->get_input_layout(1).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Update layout : " << _impl_params->get_input_layout(2).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Output layout : " << _impl_params->get_output_layout().to_short_string() << std::endl;
// set shape_change to realloc memory for same input shapes
if (!can_be_optimized()) {
set_shape_change();
}
set_can_be_optimized(true);
}

void primitive_inst::do_runtime_in_place_crop() {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_in_place_crop: " + id()));
GPU_DEBUG_GET_INSTANCE(debug_config);
@@ -1712,6 +1795,7 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
do_runtime_skip_permute();
do_runtime_skip_strided_slice();
do_runtime_skip_broadcast();
do_runtime_skip_scatter_update();
do_runtime_in_place_crop();

if (!is_valid_fusion()) {
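
For context on the skip condition used by `do_runtime_skip_scatter_update()` above: a scatter update whose indices and updates tensors are empty never writes anything, so its output equals its data input and the node can forward that buffer instead of launching a kernel. A plain 1-D reference loop (a sketch of the operator's semantics, not the GPU implementation) makes the degenerate case obvious:

```cpp
#include <cstddef>
#include <vector>

// Reference semantics of a 1-D scatter update; the real primitive handles N-D shapes and
// an axis attribute, but the empty-indices case degenerates the same way.
std::vector<float> scatter_update_ref(std::vector<float> data,
                                      const std::vector<std::size_t>& indices,
                                      const std::vector<float>& updates) {
    for (std::size_t i = 0; i < indices.size(); ++i)
        data[indices[i]] = updates[i];  // with empty indices/updates this body never runs...
    return data;                        // ...so the result is the untouched data input
}
```

This is why the pass keeps the node executable only when both `idx_layout.count()` and `update_layout.count()` are greater than zero, and marks it skippable otherwise.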
14 changes: 8 additions & 6 deletions src/plugins/intel_gpu/src/graph/scatter_elements_update.cpp
@@ -53,16 +53,18 @@ std::string scatter_elements_update_inst::to_string(scatter_elements_update_node
return primitive_description.str();
}

- scatter_elements_update_inst::typed_primitive_inst(network& network, scatter_elements_update_node const& node) : parent(network, node) {}
- void scatter_elements_update_inst::on_execute() {
-     auto input1_shape = _impl_params->input_layouts[1].get_partial_shape();
-     auto input2_shape = _impl_params->input_layouts[2].get_partial_shape();
+ scatter_elements_update_inst::typed_primitive_inst(network& network, scatter_elements_update_node const& node) : parent(network, node) {
+     update_output_memory();
+ }

-     if ((ov::shape_size(input1_shape.to_shape()) == 0) || (ov::shape_size(input2_shape.to_shape()) == 0))
-         update_output_memory();
+ void scatter_elements_update_inst::on_execute() {
+     update_output_memory();
}

void scatter_elements_update_inst::update_output_memory() {
+     if (!can_be_optimized() || _impl_params->is_dynamic())
+         return;
+
if (_outputs.size() > 0 && static_cast<bool>(_outputs[0])
&& _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
return;
14 changes: 7 additions & 7 deletions src/plugins/intel_gpu/src/graph/scatter_nd_update.cpp
@@ -64,18 +64,18 @@ std::string scatter_nd_update_inst::to_string(scatter_nd_update_node const& node
return primitive_description.str();
}

- scatter_nd_update_inst::typed_primitive_inst(network& network, scatter_nd_update_node const& node) : parent(network, node) {}
+ scatter_nd_update_inst::typed_primitive_inst(network& network, scatter_nd_update_node const& node) : parent(network, node) {
+     update_output_memory();
+ }

void scatter_nd_update_inst::on_execute() {
-     auto input1_shape = _impl_params->input_layouts[1].get_partial_shape();
-     auto input2_shape = _impl_params->input_layouts[2].get_partial_shape();
-     auto same_layouts = _impl_params->input_layouts[0] == _impl_params->output_layouts[0];
-
-     if (same_layouts && ((ov::shape_size(input1_shape.to_shape()) == 0) || (ov::shape_size(input2_shape.to_shape()) == 0)))
-         update_output_memory();
+     update_output_memory();
}

void scatter_nd_update_inst::update_output_memory() {
+     if (!can_be_optimized() || _impl_params->is_dynamic())
+         return;
+
if (_outputs.size() > 0 && static_cast<bool>(_outputs[0])
&& _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
return;
13 changes: 7 additions & 6 deletions src/plugins/intel_gpu/src/graph/scatter_update.cpp
@@ -44,17 +44,18 @@ std::string scatter_update_inst::to_string(scatter_update_node const& node) {
return primitive_description.str();
}

- scatter_update_inst::typed_primitive_inst(network& network, scatter_update_node const& node) : parent(network, node) {}
+ scatter_update_inst::typed_primitive_inst(network& network, scatter_update_node const& node) : parent(network, node) {
+     update_output_memory();
+ }

void scatter_update_inst::on_execute() {
-     auto input1_shape = _impl_params->input_layouts[1].get_partial_shape();
-     auto input2_shape = _impl_params->input_layouts[2].get_partial_shape();
-
-     if ((ov::shape_size(input1_shape.to_shape()) == 0) || (ov::shape_size(input2_shape.to_shape()) == 0))
-         update_output_memory();
+     update_output_memory();
}

void scatter_update_inst::update_output_memory() {
+     if (!can_be_optimized() || _impl_params->is_dynamic())
+         return;
+
if (_outputs.size() > 0 && static_cast<bool>(_outputs[0])
&& _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
return;
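
The three scatter source files above apply the same refactor: the constructor and `on_execute()` now call `update_output_memory()` unconditionally, while the decision about whether the node is really skippable moves into the new early-return guard and the runtime check in `do_runtime_skip_scatter_update()`, instead of the old ad-hoc shape tests inside `on_execute()`. A heavily hedged sketch of that control flow, with hypothetical member names (the visible `is_the_same_buffer()` check suggests the output is made to alias the data input, but the full function body is not shown in this diff):

```cpp
// Sketch only: models the guard added in this commit plus the assumed buffer forwarding.
struct scatter_inst_sketch {
    bool optimized = false;   // set by mark_runtime_skippable_nodes / do_runtime_skip_scatter_update
    bool dynamic = true;      // true while shapes are not resolved yet
    const void* input_mem = nullptr;   // data input buffer
    const void* output_mem = nullptr;  // output buffer; may alias input_mem when skipped

    void update_output_memory() {
        if (!optimized || dynamic)
            return;                    // new guard: too early, or the node will actually execute
        if (output_mem == input_mem)
            return;                    // already forwarded on a previous call
        output_mem = input_mem;        // assumed forwarding: reuse the data input as the output
    }
};
```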