diff --git a/src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp b/src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp index 72623f6d120955..24b13e740be273 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp @@ -38,6 +38,7 @@ struct kernel_impl_params final { std::shared_ptr desc; size_t unique_id; bool _can_be_optimized = false; + bool _runtime_skippable = false; std::vector input_layouts; std::vector output_layouts; std::vector input_offsets; @@ -145,6 +146,10 @@ struct kernel_impl_params final { return _can_be_optimized; } + bool runtime_skippable() const { + return _runtime_skippable; + } + template std::shared_ptr typed_desc() const { return std::static_pointer_cast(desc); } diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/mark_runtime_skippable_nodes.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/mark_runtime_skippable_nodes.cpp index d9fab79d76ab2e..6fc98f5023d761 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/mark_runtime_skippable_nodes.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/mark_runtime_skippable_nodes.cpp @@ -14,6 +14,9 @@ #include "non_zero_inst.h" #include "non_max_suppression_inst.h" #include "unique_inst.hpp" +#include "scatter_elements_update_inst.h" +#include "scatter_update_inst.h" +#include "scatter_nd_update_inst.h" #include "program_helpers.h" using namespace cldnn; @@ -201,5 +204,56 @@ void mark_runtime_skippable_nodes::run(program& p) { GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl; } }); + + program_helpers::do_for_types(*node, [](scatter_elements_update_node & node){ + auto impl_params = node.get_kernel_impl_params(); + + if ((node.is_output() && node.get_dependency(0).is_input()) + || node.has_fused_primitives() + || (impl_params->get_input_layout(0).format != 
impl_params->get_output_layout().format) + || (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type)) + return; + + if (node.is_dynamic()) { + node.can_be_optimized(true); + // Set runtime skippable only when the node is set as can_be_optimized finally. + node.set_runtime_skippable(true); + GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl; + } + }); + + program_helpers::do_for_types(*node, [](scatter_update_node & node){ + auto impl_params = node.get_kernel_impl_params(); + + if ((node.is_output() && node.get_dependency(0).is_input()) + || node.has_fused_primitives() + || (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format) + || (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type)) + return; + + if (node.is_dynamic()) { + node.can_be_optimized(true); + // Set runtime skippable only when the node is set as can_be_optimized finally. + node.set_runtime_skippable(true); + GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl; + } + }); + + program_helpers::do_for_types(*node, [](scatter_nd_update_node & node){ + auto impl_params = node.get_kernel_impl_params(); + + if ((node.is_output() && node.get_dependency(0).is_input()) + || node.has_fused_primitives() + || (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format) + || (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type)) + return; + + if (node.is_dynamic()) { + node.can_be_optimized(true); + // Set runtime skippable only when the node is set as can_be_optimized finally. 
+ node.set_runtime_skippable(true); + GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl; + } + }); } } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp index 829cd23d0908f5..94d94f29846613 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp @@ -22,6 +22,9 @@ #include "permute_inst.h" #include "strided_slice_inst.h" #include "broadcast_inst.h" +#include "scatter_update_inst.h" +#include "scatter_elements_update_inst.h" +#include "scatter_nd_update_inst.h" #include #include @@ -88,12 +91,8 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl { static std::unique_ptr create(const typed_program_node& arg, const kernel_impl_params& impl_param) { // concat buffer fusing for dynamic shape is adaptively applied at runtime. So we need to build dynamic impl at build time. 
if (impl_param.can_be_optimized() && - !((impl_param.is_type() || - impl_param.is_type() || - impl_param.is_type() || - impl_param.is_type() || - impl_param.is_type() || - impl_param.is_type()) && impl_param.is_dynamic())) { + !((impl_param.runtime_skippable() || impl_param.is_type()) && + impl_param.is_dynamic())) { return make_unique(kernel_selector::kernel_data{}); } auto kernel_params = ImplType::get_kernel_params(ImplType::static_canonicalize_shapes(impl_param)); diff --git a/src/plugins/intel_gpu/src/graph/include/primitive_inst.h b/src/plugins/intel_gpu/src/graph/include/primitive_inst.h index 6899812ab28097..430d2652bb9293 100644 --- a/src/plugins/intel_gpu/src/graph/include/primitive_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/primitive_inst.h @@ -261,6 +261,7 @@ class primitive_inst { void do_runtime_in_place_concat(); void do_runtime_in_place_kv_cache(); void do_runtime_in_place_crop(); + void do_runtime_skip_scatter_update(); void configure_shape_of_dependencies(); memory::ptr fused_memory(size_t dep_id) const { diff --git a/src/plugins/intel_gpu/src/graph/include/program_node.h b/src/plugins/intel_gpu/src/graph/include/program_node.h index 8105a8bc07dec3..d1bbaa8a34cb8f 100644 --- a/src/plugins/intel_gpu/src/graph/include/program_node.h +++ b/src/plugins/intel_gpu/src/graph/include/program_node.h @@ -131,6 +131,7 @@ struct program_node { get_unique_id(), in_layouts, out_layouts, get_fused_primitives())); params->memory_deps = get_const_memory_deps(); params->_can_be_optimized = this->optimized; + params->_runtime_skippable = this->runtime_skippable; params->in_port_to_shape_info_offset = get_input_port_to_shape_info_offset_map(); params->out_port_to_shape_info_offset = get_output_port_to_shape_info_offset_map(); auto deps = get_dependencies(); diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index 0d0748311caa08..c92c5854a8199e 100644 --- 
a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -25,6 +25,9 @@ #include "shape_of_inst.h" #include "softmax_inst.h" #include "strided_slice_inst.h" +#include "scatter_elements_update_inst.h" +#include "scatter_nd_update_inst.h" +#include "scatter_update_inst.h" #include "gemm_inst.h" #include "assign_inst.h" #include "read_value_inst.h" @@ -737,15 +740,57 @@ event::ptr primitive_inst::realloc_if_needed() { // Clear out memory if was previously reused, but now primitive can't be optimized if (!_node->is_type() && (_node->is_runtime_skippable() || _node->is_type())) { + std::function reset_user_output_memory; + reset_user_output_memory = [&](cldnn::primitive_inst* curr_inst, cldnn::memory::ptr input_mem_ptr) { + auto curr_output_memory_ptr = curr_inst->output_memory_ptr(0); + if (curr_inst->can_be_optimized() + && (curr_output_memory_ptr + && get_network().get_engine().is_the_same_buffer(*curr_output_memory_ptr, *input_mem_ptr))) { + if (curr_inst->mem_allocated()) { + get_network().get_memory_pool().release_memory(curr_inst->_outputs[0].get(), + curr_inst->get_node().get_unique_id(), curr_inst->id(), get_network_id()); + _mem_allocated = false; + } + curr_inst->_outputs[0] = nullptr; + curr_inst->_max_output_layout_count[0] = 0; + for (auto& user_inst : curr_inst->get_user_insts()) { + reset_user_output_memory(user_inst, input_mem_ptr); + } + } + }; if (can_be_optimized()) { _max_output_layout_count = _deps[0].first->_max_output_layout_count; GPU_DEBUG_PROFILED_STAGE_MEMALLOC_INFO("can_be_optimized"); + // If the inst is optimized out but it executed at the previous iteration, + // reset all output memory of users which was optimized out at the previous iteration. + // Ex. 
+ // * iter0: node1(executed) -> node2(skipped) -> node3(skipped) + // * iter1: node1(skipped) -> node2(skipped) -> node3(executed) + if (_outputs[0] && dep_memory_ptr(0) + && !_network.get_engine().is_the_same_buffer(dep_memory(0), output_memory(0))) { + for (auto& user_inst : get_user_insts()) { + reset_user_output_memory(user_inst, dep_memory_ptr(0)); + } + } return ev; } else if (_outputs[0] && dep_memory_ptr(0) && _network.get_engine().is_the_same_buffer(dep_memory(0), output_memory(0))) { - // Clear out memory if was previously reused, but now primitive can't be optimized + if (mem_allocated()) { + get_network().get_memory_pool().release_memory(_outputs[0].get(), + get_node().get_unique_id(), id(), get_network_id()); + _mem_allocated = false; + } _outputs[0] = nullptr; _max_output_layout_count[0] = 0; + // Check users recursively and if the users is can_be_optimized && runtime_skippable + // && output_memory of user is same as current input memory, + // then reset the users output memory too. + // Ex. 
+ // * iter0: node1(skipped) -> node2(skipped) -> node3(skipped) + // * iter1: node1(executed) -> node2(skipped) -> node3(executed) + for (auto& user_inst : get_user_insts()) { + reset_user_output_memory(user_inst, dep_memory_ptr(0)); + } } } @@ -1583,6 +1628,44 @@ void primitive_inst::do_runtime_in_place_concat() { GPU_DEBUG_TRACE_DETAIL << "[In place concat] " << concat_inst->id() << ": can_be_optimized " << std::endl; } +void primitive_inst::do_runtime_skip_scatter_update() { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_skip_scatter_update: " + id())); + // Check pattern + if (!(get_node().is_type() + || get_node().is_type() + || get_node().is_type()) + || !get_node().can_be_optimized()) + return; + + GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_scatter_update] " << id() << " : check optimizability" << std::endl; + auto input_layout = _impl_params->get_input_layout(0); + auto output_layout = _impl_params->get_output_layout(); + auto idx_layout = _impl_params->get_input_layout(1); + auto update_layout = _impl_params->get_input_layout(2); + + if (idx_layout.count() > 0 && update_layout.count() > 0) { + // set shape_change to realloc memory for same input shapes + if (can_be_optimized()) { + set_shape_change(); + } + set_can_be_optimized(false); + GPU_DEBUG_TRACE_DETAIL << "--- Cannot optimize because idx_layout (" << idx_layout.to_short_string() + << ") and update_layout(" << update_layout.to_short_string() << ") are not zero" << std::endl; + return; + } + + GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_scatter_update] " << id() << " : can_be_optimized" << std::endl; + GPU_DEBUG_TRACE_DETAIL << " - Input layout : " << _impl_params->get_input_layout(0).to_short_string() << std::endl; + GPU_DEBUG_TRACE_DETAIL << " - Idx layout : " << _impl_params->get_input_layout(1).to_short_string() << std::endl; + GPU_DEBUG_TRACE_DETAIL << " - Update layout : " << _impl_params->get_input_layout(2).to_short_string() << 
std::endl; + GPU_DEBUG_TRACE_DETAIL << " - Output layout : " << _impl_params->get_output_layout().to_short_string() << std::endl; + // set shape_change to realloc memory for same input shapes + if (!can_be_optimized()) { + set_shape_change(); + } + set_can_be_optimized(true); +} + void primitive_inst::do_runtime_in_place_crop() { OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_in_place_crop: " + id())); GPU_DEBUG_GET_INSTANCE(debug_config); @@ -1712,6 +1795,7 @@ event::ptr primitive_inst::execute(const std::vector& events) { do_runtime_skip_permute(); do_runtime_skip_strided_slice(); do_runtime_skip_broadcast(); + do_runtime_skip_scatter_update(); do_runtime_in_place_crop(); if (!is_valid_fusion()) { diff --git a/src/plugins/intel_gpu/src/graph/scatter_elements_update.cpp b/src/plugins/intel_gpu/src/graph/scatter_elements_update.cpp index 1e901f991ddf55..ee8850fbd46220 100644 --- a/src/plugins/intel_gpu/src/graph/scatter_elements_update.cpp +++ b/src/plugins/intel_gpu/src/graph/scatter_elements_update.cpp @@ -53,16 +53,18 @@ std::string scatter_elements_update_inst::to_string(scatter_elements_update_node return primitive_description.str(); } -scatter_elements_update_inst::typed_primitive_inst(network& network, scatter_elements_update_node const& node) : parent(network, node) {} -void scatter_elements_update_inst::on_execute() { - auto input1_shape = _impl_params->input_layouts[1].get_partial_shape(); - auto input2_shape = _impl_params->input_layouts[2].get_partial_shape(); +scatter_elements_update_inst::typed_primitive_inst(network& network, scatter_elements_update_node const& node) : parent(network, node) { + update_output_memory(); +} - if ((ov::shape_size(input1_shape.to_shape()) == 0) || (ov::shape_size(input2_shape.to_shape()) == 0)) - update_output_memory(); +void scatter_elements_update_inst::on_execute() { + update_output_memory(); } void scatter_elements_update_inst::update_output_memory() { + if 
(!can_be_optimized() || _impl_params->is_dynamic()) + return; + if (_outputs.size() > 0 && static_cast(_outputs[0]) && _network.get_engine().is_the_same_buffer(output_memory(), input_memory())) return; diff --git a/src/plugins/intel_gpu/src/graph/scatter_nd_update.cpp b/src/plugins/intel_gpu/src/graph/scatter_nd_update.cpp index 71ef4d0520608d..ba0cea2e32299e 100644 --- a/src/plugins/intel_gpu/src/graph/scatter_nd_update.cpp +++ b/src/plugins/intel_gpu/src/graph/scatter_nd_update.cpp @@ -64,18 +64,18 @@ std::string scatter_nd_update_inst::to_string(scatter_nd_update_node const& node return primitive_description.str(); } -scatter_nd_update_inst::typed_primitive_inst(network& network, scatter_nd_update_node const& node) : parent(network, node) {} +scatter_nd_update_inst::typed_primitive_inst(network& network, scatter_nd_update_node const& node) : parent(network, node) { + update_output_memory(); +} void scatter_nd_update_inst::on_execute() { - auto input1_shape = _impl_params->input_layouts[1].get_partial_shape(); - auto input2_shape = _impl_params->input_layouts[2].get_partial_shape(); - auto same_layouts = _impl_params->input_layouts[0] == _impl_params->output_layouts[0]; - - if (same_layouts && ((ov::shape_size(input1_shape.to_shape()) == 0) || (ov::shape_size(input2_shape.to_shape()) == 0))) - update_output_memory(); + update_output_memory(); } void scatter_nd_update_inst::update_output_memory() { + if (!can_be_optimized() || _impl_params->is_dynamic()) + return; + if (_outputs.size() > 0 && static_cast(_outputs[0]) && _network.get_engine().is_the_same_buffer(output_memory(), input_memory())) return; diff --git a/src/plugins/intel_gpu/src/graph/scatter_update.cpp b/src/plugins/intel_gpu/src/graph/scatter_update.cpp index 8f9099572fa968..8d10f9ad2b4fd7 100644 --- a/src/plugins/intel_gpu/src/graph/scatter_update.cpp +++ b/src/plugins/intel_gpu/src/graph/scatter_update.cpp @@ -44,17 +44,18 @@ std::string scatter_update_inst::to_string(scatter_update_node const& 
node) { return primitive_description.str(); } -scatter_update_inst::typed_primitive_inst(network& network, scatter_update_node const& node) : parent(network, node) {} +scatter_update_inst::typed_primitive_inst(network& network, scatter_update_node const& node) : parent(network, node) { + update_output_memory(); +} void scatter_update_inst::on_execute() { - auto input1_shape = _impl_params->input_layouts[1].get_partial_shape(); - auto input2_shape = _impl_params->input_layouts[2].get_partial_shape(); - - if ((ov::shape_size(input1_shape.to_shape()) == 0) || (ov::shape_size(input2_shape.to_shape()) == 0)) - update_output_memory(); + update_output_memory(); } void scatter_update_inst::update_output_memory() { + if (!can_be_optimized() || _impl_params->is_dynamic()) + return; + if (_outputs.size() > 0 && static_cast(_outputs[0]) && _network.get_engine().is_the_same_buffer(output_memory(), input_memory())) return; diff --git a/src/plugins/intel_gpu/tests/unit/dynamic_execution/skip_scatter_update_at_runtime.cpp b/src/plugins/intel_gpu/tests/unit/dynamic_execution/skip_scatter_update_at_runtime.cpp new file mode 100644 index 00000000000000..44572bd7980131 --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/dynamic_execution/skip_scatter_update_at_runtime.cpp @@ -0,0 +1,184 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "permute_inst.h" +#include "program_wrapper.h" + +#include +#include + +using namespace cldnn; +using namespace ::tests; + +namespace skip_scatter_update_tests { + +enum scatter_update_type { + ScatterUpdate = 0, + ScatterNDUpdate = 1, + ScatterElementsUpdate = 2 +}; + +struct skip_scatter_update_params { + scatter_update_type scatter_type; + bool scatter_update_01_skipped; + bool scatter_update_02_skipped; +}; + +class skip_scatter_update_at_runtime_test : public testing::TestWithParam {}; + 
+TEST_P(skip_scatter_update_at_runtime_test, runtime_skip) { + auto p = GetParam(); + auto& engine = get_test_engine(); + + auto input_layout_static = layout{ov::PartialShape{1,16}, data_types::f16, format::bfyx}; + auto rank = input_layout_static.get_partial_shape().size(); + auto input_layout_dynamic = layout {ov::PartialShape::dynamic(rank), data_types::f16, format::get_default_format(rank)}; + + auto idx1_nonzero_layout = layout{ov::PartialShape{1,16}, data_types::f16, format::bfyx}; + auto idx1_zero_layout = layout{ov::PartialShape{0,16}, data_types::f16, format::bfyx}; + auto update1_nonzero_layout = layout{ov::PartialShape{1,16}, data_types::f16, format::bfyx}; + auto update1_zero_layout = layout{ov::PartialShape{0,16}, data_types::f16, format::bfyx}; + + auto idx2_nonzero_layout = layout{ov::PartialShape{1,16}, data_types::f16, format::bfyx}; + auto idx2_zero_layout = layout{ov::PartialShape{0,16}, data_types::f16, format::bfyx}; + auto update2_nonzero_layout = layout{ov::PartialShape{1,16}, data_types::f16, format::bfyx}; + auto update2_zero_layout = layout{ov::PartialShape{0,16}, data_types::f16, format::bfyx}; + + ExecutionConfig config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + config.set_property(ov::intel_gpu::optimize_data(true)); + + cldnn::network::ptr network = nullptr; + + if (p.scatter_type == scatter_update_type::ScatterElementsUpdate) { + topology topology(input_layout("input", input_layout_dynamic), + input_layout("idx1", input_layout_dynamic), + input_layout("idx2", input_layout_dynamic), + input_layout("update1", input_layout_dynamic), + input_layout("update2", input_layout_dynamic), + scatter_elements_update("scatter1", input_info("input"), input_info("idx1"), input_info("update1"), 0), + scatter_elements_update("scatter2", input_info("scatter1"), input_info("idx2"), input_info("update2"), 0), + reorder("reorder", input_info("scatter2"), format::get_default_format(rank), 
data_types::f32)); + + network = get_network(engine, topology, config, get_test_stream_ptr(), false); + } else if (p.scatter_type == scatter_update_type::ScatterUpdate) { + topology topology(input_layout("input", input_layout_dynamic), + input_layout("idx1", input_layout_dynamic), + input_layout("idx2", input_layout_dynamic), + input_layout("update1", input_layout_dynamic), + input_layout("update2", input_layout_dynamic), + scatter_update("scatter1", input_info("input"), input_info("idx1"), input_info("update1"), 0), + scatter_update("scatter2", input_info("scatter1"), input_info("idx2"), input_info("update2"), 0), + reorder("reorder", input_info("scatter2"), format::get_default_format(rank), data_types::f32)); + + network = get_network(engine, topology, config, get_test_stream_ptr(), false); + } else if (p.scatter_type == scatter_update_type::ScatterNDUpdate) { + input_layout_static = layout{ov::PartialShape{12}, data_types::f16, format::bfyx}; + rank = input_layout_static.get_partial_shape().size(); + input_layout_dynamic = layout {ov::PartialShape::dynamic(rank), data_types::f16, format::get_default_format(rank)}; + + idx1_nonzero_layout = layout{ov::PartialShape{12,1}, data_types::f16, format::bfyx}; + idx1_zero_layout = layout{ov::PartialShape{0,1}, data_types::f16, format::bfyx}; + update1_nonzero_layout = layout{ov::PartialShape{12}, data_types::f16, format::bfyx}; + update1_zero_layout = layout{ov::PartialShape{0}, data_types::f16, format::bfyx}; + rank = idx1_nonzero_layout.get_partial_shape().size(); + auto idx_layout_dynamic = layout {ov::PartialShape::dynamic(rank), data_types::f16, format::get_default_format(rank)}; + + idx2_nonzero_layout = layout{ov::PartialShape{12,1}, data_types::f16, format::bfyx}; + idx2_zero_layout = layout{ov::PartialShape{0,1}, data_types::f16, format::bfyx}; + update2_nonzero_layout = layout{ov::PartialShape{12}, data_types::f16, format::bfyx}; + update2_zero_layout = layout{ov::PartialShape{0}, data_types::f16, 
format::bfyx}; + + + + topology topology(input_layout("input", input_layout_dynamic), + input_layout("idx1", idx_layout_dynamic), + input_layout("idx2", idx_layout_dynamic), + input_layout("update1", input_layout_dynamic), + input_layout("update2", input_layout_dynamic), + scatter_nd_update("scatter1", input_info("input"), input_info("idx1"), input_info("update1"), 2), + scatter_nd_update("scatter2", input_info("scatter1"), input_info("idx2"), input_info("update2"), 2), + reorder("reorder", input_info("scatter2"), format::get_default_format(rank), data_types::f32)); + + network = get_network(engine, topology, config, get_test_stream_ptr(), false); + } + + auto input_mem = engine.allocate_memory(input_layout_static); + + auto idx1_layout_static = p.scatter_update_01_skipped? idx1_zero_layout : idx1_nonzero_layout; + auto update1_layout_static = p.scatter_update_01_skipped? update1_zero_layout : update1_nonzero_layout; + + auto idx1_mem = engine.allocate_memory(idx1_nonzero_layout); + auto update1_mem = engine.allocate_memory(update1_nonzero_layout); + + if (p.scatter_update_01_skipped) { + idx1_mem = engine.reinterpret_buffer(*idx1_mem, idx1_zero_layout); + update1_mem = engine.reinterpret_buffer(*update1_mem, update1_zero_layout); + } + + auto idx2_layout_static = p.scatter_update_02_skipped? idx2_zero_layout : idx2_nonzero_layout; + auto update2_layout_static = p.scatter_update_02_skipped? 
update2_zero_layout : update2_nonzero_layout; + + auto idx2_mem = engine.allocate_memory(idx2_nonzero_layout); + auto update2_mem = engine.allocate_memory(update2_nonzero_layout); + if (p.scatter_update_02_skipped) { + idx2_mem = engine.reinterpret_buffer(*idx2_mem, idx2_zero_layout); + update2_mem = engine.reinterpret_buffer(*update2_mem, update2_zero_layout); + } + network->set_input_data("input", input_mem); + network->set_input_data("idx1", idx1_mem); + network->set_input_data("idx2", idx2_mem); + network->set_input_data("update1", update1_mem); + network->set_input_data("update2", update2_mem); + auto outputs = network->execute(); + outputs.begin()->second.get_memory(); + + auto input_inst = network->get_primitive("input"); + auto scatter1_inst = network->get_primitive("scatter1"); + auto scatter2_inst = network->get_primitive("scatter2"); + + ASSERT_EQ(scatter1_inst->can_be_optimized(), p.scatter_update_01_skipped); + ASSERT_EQ(scatter2_inst->can_be_optimized(), p.scatter_update_02_skipped); + + if (scatter1_inst->can_be_optimized()) { + ASSERT_TRUE(engine.is_the_same_buffer(scatter1_inst->dep_memory(0), scatter1_inst->output_memory(0))); + } else { + ASSERT_FALSE(engine.is_the_same_buffer(scatter1_inst->dep_memory(0), scatter1_inst->output_memory(0))); + } + + if (scatter2_inst->can_be_optimized()) { + ASSERT_TRUE(engine.is_the_same_buffer(scatter2_inst->dep_memory(0), scatter2_inst->output_memory(0))); + } else { + ASSERT_FALSE(engine.is_the_same_buffer(scatter2_inst->dep_memory(0), scatter2_inst->output_memory(0))); + } +} + +INSTANTIATE_TEST_SUITE_P(smoke, skip_scatter_update_at_runtime_test, + testing::ValuesIn(std::vector { + { scatter_update_type::ScatterUpdate, true, true }, + { scatter_update_type::ScatterUpdate, true, false}, + { scatter_update_type::ScatterUpdate, false, true }, + { scatter_update_type::ScatterUpdate, false, false}, + + { scatter_update_type::ScatterNDUpdate, true, true }, + { scatter_update_type::ScatterNDUpdate, true, false}, + { 
scatter_update_type::ScatterNDUpdate, false, true }, + { scatter_update_type::ScatterNDUpdate, false, false}, + + { scatter_update_type::ScatterElementsUpdate, true, true }, + { scatter_update_type::ScatterElementsUpdate, true, false}, + { scatter_update_type::ScatterElementsUpdate, false, true }, + { scatter_update_type::ScatterElementsUpdate, false, false}, + + })); +} // namespace skip_scatter_update_tests