From df6a25800d3bdfa6a11b675ae908db89e05a342c Mon Sep 17 00:00:00 2001
From: Roman Lyamin
Date: Mon, 3 Jun 2024 14:34:13 +0400
Subject: [PATCH] [GPU] Added RoPE support for ChatGLM and Qwen (#24756)

### Details:
 - Added RoPE support for the ChatGLM and Qwen models
 - Moved and refactored the RoPE functional tests

### Tickets:
 - *[119150](https://jira.devtools.intel.com/browse/CVS-119150)*
---
 .../transformations/utils/gen_pattern.hpp     |  12 +-
 .../fuse_rotary_positional_embeddings.cpp     |  15 +-
 .../subgraph_tests/src/rotary_pos_emb.cpp     | 621 ------------------
 .../skip_tests_config.cpp                     |   2 +
 .../subgraph_tests/rotary_pos_emb.cpp         |  32 +
 .../intel_gpu/plugin/primitives_list.hpp      |   1 +
 .../include/intel_gpu/primitives/rope.hpp     |  92 +++
 .../src/graph/impls/ocl/register.cpp          |   1 +
 .../src/graph/impls/ocl/register.hpp          |   2 +
 .../intel_gpu/src/graph/impls/ocl/rope.cpp    |  88 +++
 .../intel_gpu/src/graph/include/rope_inst.h   |  39 ++
 src/plugins/intel_gpu/src/graph/rope.cpp      |  76 +++
 .../kernel_selector/cl_kernels/rope_ref.cl    |  86 +++
 .../src/kernel_selector/common_types.h        |   1 +
 .../kernels/rope/rope_kernel_base.cpp         | 114 ++++
 .../kernels/rope/rope_kernel_base.h           |  45 ++
 .../kernels/rope/rope_kernel_ref.cpp          |  34 +
 .../kernels/rope/rope_kernel_ref.h            |  20 +
 .../kernels/rope/rope_kernel_selector.cpp     |  16 +
 .../kernels/rope/rope_kernel_selector.h       |  23 +
 src/plugins/intel_gpu/src/plugin/graph.cpp    |   1 +
 src/plugins/intel_gpu/src/plugin/ops/rope.cpp |  37 ++
 .../src/plugin/transformations_pipeline.cpp   |  12 +-
 .../skip_tests_config.cpp                     |   1 +
 .../subgraph_tests/rotary_pos_emb.cpp         |  22 +
 .../include/subgraph_tests/rotary_pos_emb.hpp |  56 ++
 .../subgraph/rotary_pos_emb.hpp               |  67 ++
 .../src/subgraph/rotary_pos_emb.cpp           | 590 +++++++++++++++++
 28 files changed, 1475 insertions(+), 631 deletions(-)
 delete mode 100644 src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp
 create mode 100644 src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp
 create mode 100644 src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp
 create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp
 create mode 100644 src/plugins/intel_gpu/src/graph/include/rope_inst.h
 create mode 100644 src/plugins/intel_gpu/src/graph/rope.cpp
 create mode 100644 src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl
 create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp
 create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h
 create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp
 create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.h
 create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.cpp
 create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.h
 create mode 100644 src/plugins/intel_gpu/src/plugin/ops/rope.cpp
 create mode 100644 src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp
 create mode 100644 src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp
 create mode 100644 src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp
 create mode 100644 src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp

diff --git a/src/common/transformations/include/transformations/utils/gen_pattern.hpp
b/src/common/transformations/include/transformations/utils/gen_pattern.hpp index 711a96f460c86d..4aa2e7484bed4e 100644 --- a/src/common/transformations/include/transformations/utils/gen_pattern.hpp +++ b/src/common/transformations/include/transformations/utils/gen_pattern.hpp @@ -1212,7 +1212,8 @@ class PatternValidator { return false; } - if (ele_type == ov::element::i32 || ele_type == ov::element::f32 || ele_type == ov::element::i64) { + if (ele_type == ov::element::i32 || ele_type == ov::element::i64 || ele_type == ov::element::f16 || + ele_type == ov::element::f32) { auto observed = constop->cast_vector(); for (size_t i = 0; i < symbols.size(); i++) detail::add_symbol_observed(sov, symbols[i], observed[i]); @@ -1259,6 +1260,15 @@ class PatternValidator { } } + if (pconst_node->get_output_element_type(0).is_real() && + vconst_node->get_output_element_type(0).is_real()) { + auto p_values = pconst_node->cast_vector(); + auto v_values = vconst_node->cast_vector(); + if (p_values == v_values) { + continue; + } + } + _VERBOSE_LOG("expecting Constant of type ", pconst_node->get_output_element_type(0), " but got ", diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 5516c20266a542..2f3ddd5d843ae3 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -9,6 +9,7 @@ #include "itt.hpp" #include "openvino/core/rt_info.hpp" +#include "openvino/op/util/shape_of_base.hpp" #include "openvino/opsets/opset1.hpp" #include "openvino/opsets/opset6.hpp" #include "openvino/opsets/opset8.hpp" @@ -415,9 +416,9 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) { MATCHER_SCOPE(RoPEFusionChatGLM); - auto qkv_linear = makePattern("f32[?,?,?]"); // f32[seq_length, batch_size, 4608] + auto qkv_linear = makePattern("[?,?,?]"); // [seq_length, batch_size, 4608] auto seq_length = makePattern("i32[1]"); - auto cos_sin_cache = makePattern("f32[?,?,?,?]"); // [max_pos_embeddings, batch_size, 32, 2] + auto cos_sin_cache = makePattern("[?,?,?,?]"); // [max_pos_embeddings, batch_size, 32, 2] auto ndims = ov::gen_pattern::Symbol("ndims"); auto head_cnt = ov::gen_pattern::Symbol("head_cnt"); @@ -538,9 +539,9 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { MATCHER_SCOPE(RoPEFusionQwen); // rotary_emb_cos & rotary_emb_sin are sliced by present kv-length (past-kv-length + cur_len) - auto rotary_emb_cos = makePattern("f32[1,?,1,?]"); // [1,..4096,1,128] - auto rotary_emb_sin = makePattern("f32[1,?,1,?]"); // [1,..4096,1,128] - auto qkv_proj = makePattern("f32[?,?,?]"); // f32[?,?,12288] + auto rotary_emb_cos = makePattern("[1,?,1,?]"); // [1,..4096,1,128] + auto rotary_emb_sin = makePattern("[1,?,1,?]"); // [1,..4096,1,128] + auto qkv_proj = makePattern("[?,?,?]"); // [?,?,12288] auto head_cnt = ov::gen_pattern::Symbol("head_cnt"); auto head_size = ov::gen_pattern::Symbol("head_size"); @@ -559,8 +560,8 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { auto Multiply_567524 = makePattern({ShapeOf_485735, {-1}}, {{"auto_broadcast", "numpy"}}); auto Gather_377635 = makePattern({Multiply_567524, {1}, 0}, {{"batch_dims", 0}}); - auto input_ids = makePattern("i32[?,?]"); // [batch, 
length] - auto ShapeOf_409241 = makePattern({input_ids}, {}); + auto input_ids = makePattern(); // [batch, length] + auto ShapeOf_409241 = makePattern({input_ids}, {}); auto Gather_311651 = makePattern({ShapeOf_409241, {1}, 0}, {{"batch_dims", 0}}); auto neg_Multiply = makePattern({Gather_311651, {-1}}, {{"auto_broadcast", "numpy"}}); diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp deleted file mode 100644 index a505a010a20910..00000000000000 --- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp +++ /dev/null @@ -1,621 +0,0 @@ -// Copyright (C) 2018-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include -#include -#include -#include -#include - -#include "common_test_utils/common_utils.hpp" -#include "shared_test_classes/base/ov_subgraph.hpp" -#include "utils/cpu_test_utils.hpp" -#include "utils/fusing_test_utils.hpp" -#include "transformations/utils/gen_pattern.hpp" - -using namespace CPUTestUtils; -using namespace ov::gen_pattern; -using namespace ov; - -namespace ov { -namespace test { - -static ov::OutputVector makeCosSinCache(int max_position_embeddings, int rotary_ndims) { - std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); - std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f); - - // rotate_half style cos/sin table: - // y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2 - // y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1 - // - for (int i = 0, k = 0; i < rotary_ndims; i += 2, k++) { - auto xita_i = 1.0 / std::pow(10000.0, static_cast(i) / rotary_ndims); - float* psin = lut_sin.data(); - float* pcos = lut_cos.data(); - for (int m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) { - auto vsin = std::sin(xita_i * m); - auto vcos = std::cos(xita_i * m); - pcos[k] = pcos[k + rotary_ndims / 2] = vcos; - psin[k] = psin[k + rotary_ndims / 2] = vsin; - } - } - auto shape = ov::Shape({1, 1, static_cast(max_position_embeddings), static_cast(rotary_ndims)}); - auto Cos = makeConst(ov::element::f32, shape, lut_cos); - auto Sin = makeConst(ov::element::f32, shape, lut_sin); - return {Cos, Sin}; -} - -static std::shared_ptr buildROPE_Llama2(const int batch, - const int seq_length, - const int max_position_embeddings, - const int num_head, - const int ndims) { - auto input = std::make_shared(ov::element::f32, PartialShape{batch, -1, num_head, ndims}); - auto pos_id_end = std::make_shared(ov::element::i32, ov::Shape{}); - auto pos_ids = std::make_shared(ov::element::i32, PartialShape{1, -1}); - - auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); - auto Constant582 = cos_sin_cache[0]; - auto Constant585 = cos_sin_cache[1]; - - // concat KV length - auto transpose_Transpose = makeOP({input, {0, 2, 1, 3}}); - auto slice_Unsqueeze_426 = makeOP({pos_id_end, 0}); - auto ScatterUpdate_152236 = makeOP({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}}); - auto slice_Slice = makeOP({Constant582, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}}, - {{"begin_mask", {1, 1, 0}}, - {"end_mask", {1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto squeeze_Squeeze = makeOP({slice_Slice, 1}); - auto squeeze_Squeeze_435 = makeOP({squeeze_Squeeze, 0}); - auto index_441_Gather = makeOP({squeeze_Squeeze_435, pos_ids, 0}, {{"batch_dims", 0}}); - auto unsqueeze_Unsqueeze = makeOP({index_441_Gather, 1}); - auto mul_Multiply 
= - makeOP({transpose_Transpose, unsqueeze_Unsqueeze}, {{"auto_broadcast", "numpy"}}); - auto size_ShapeOf_448 = makeOP({transpose_Transpose}, {{"output_type", "i32"}}); - auto size_Gather_450 = makeOP({size_ShapeOf_448, 3, 0}, {{"batch_dims", 0}}); - auto floor_divide_Divide = - makeOP({size_Gather_450, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto floor_divide_Floor = makeOP({floor_divide_Divide}); - auto slice_Unsqueeze_452 = makeOP({floor_divide_Floor, 0}); - auto ScatterUpdate_152312 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}}); - auto slice_Slice_459 = makeOP( - {transpose_Transpose, ScatterUpdate_152312, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Constant_182988 = makeConst(element::f32, - ov::Shape({ - 1, - 1, - 1, - 1, - }), - {-1.000000f}); - auto neg_Multiply = makeOP({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}}); - auto ScatterUpdate_152368 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}}); - auto slice_Slice2 = - makeOP({transpose_Transpose, {0, 0, 0, 0}, ScatterUpdate_152368, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto cat_Concat = makeOP({neg_Multiply, slice_Slice2}, {{"axis", -1}}); - auto ScatterUpdate_152421 = makeOP({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}}); - auto slice_Slice_433 = makeOP({Constant585, {0, 0, 0}, ScatterUpdate_152421, {1, 1, 1}}, - {{"begin_mask", {1, 1, 0}}, - {"end_mask", {1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto squeeze_Squeeze_436 = makeOP({slice_Slice_433, 1}); - auto squeeze_Squeeze_437 = makeOP({squeeze_Squeeze_436, 0}); - auto index_446_Gather = makeOP({squeeze_Squeeze_437, pos_ids, 0}, {{"batch_dims", 0}}); - auto unsqueeze_Unsqueeze_447 = makeOP({index_446_Gather, 1}); - auto mul_Multiply_463 = - makeOP({cat_Concat, unsqueeze_Unsqueeze_447}, {{"auto_broadcast", "numpy"}}); - auto add_Add = makeOP({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}}); - - return std::make_shared(ov::NodeVector{add_Add}, ov::ParameterVector{input, pos_id_end, pos_ids}); -} - -class RoPECPUTestLlama2 : public SubgraphBaseTest { -public: - ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) { - auto tensor = ov::Tensor(ov::element::i32, shape); - auto* ptr = static_cast(tensor.data()); - for (size_t i = 0; i < tensor.get_size(); i++) { - ptr[i] = start; - start += step; - } - return tensor; - } - - void generate_inputs(const std::vector& targetInputStaticShapes) override { - const auto& funcInputs = function->inputs(); - - const int position_id_start = 15; - auto& input_shape = targetInputStaticShapes[0]; - auto seq_length = input_shape[1]; - - ov::test::utils::InputGenerateData in_data; - in_data.start_from = -1; - in_data.range = 2; - in_data.resolution = 32768; - ov::Tensor t_input = utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, in_data); - ov::Tensor t_position_id_end = create_i32_tensor(ov::Shape({}), position_id_start + seq_length); - ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), position_id_start); - - inputs.clear(); - inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); - inputs.insert({funcInputs[1].get_node_shared_ptr(), t_position_id_end}); - 
inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids}); - } - -protected: - void SetUp() override { - targetDevice = ov::test::utils::DEVICE_CPU; - - const int batch = 2; - const int seq_length = 7; - const size_t max_position_embeddings = 2048; - const size_t ndims = 128; - const size_t num_head = 32; - - InputShape inpShape = {{batch, seq_length, num_head, ndims}, {{batch, seq_length, num_head, ndims}}}; - init_input_shapes({inpShape}); - function = buildROPE_Llama2(batch, seq_length, max_position_embeddings, num_head, ndims); - } -}; - -TEST_F(RoPECPUTestLlama2, smoke_CompareWithRefs) { - run(); - CheckNumberOfNodesWithType(compiledModel, "RoPE", 1); -} - -class RoPECPUTestChatGLM : public SubgraphBaseTest { -public: - ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) { - auto tensor = ov::Tensor(ov::element::i32, shape); - auto* ptr = static_cast(tensor.data()); - for (size_t i = 0; i < tensor.get_size(); i++) { - ptr[i] = start; - start += step; - } - return tensor; - } - - void generate_inputs(const std::vector& targetInputStaticShapes) override { - const auto& funcInputs = function->inputs(); - - auto& input_shape = targetInputStaticShapes[0]; - auto seq_length = input_shape[0]; - // auto batch = input_shape[1]; - - ov::Tensor t_input = - utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768); - ov::Tensor t_cos_sin_cache = - utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {32768, 32, 2}, 2, -1.0f, 32768); - ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), 15); - - inputs.clear(); - inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); - inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache}); - inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids}); - } - -protected: - std::shared_ptr buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims) { - auto input = - std::make_shared(ov::element::f32, PartialShape{-1, batch, 4096 + 256 + 256}); - auto cos_sin_cache = std::make_shared(ov::element::f32, PartialShape{32768, 32, 2}); - auto position_ids = std::make_shared(ov::element::i32, PartialShape{-1, -1}); - - auto __module_transformer_index_67_Gather = - makeOP({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}}); - auto __module_transformer_transpose_Transpose = - makeOP({__module_transformer_index_67_Gather, {1, 0, 2, 3}}); - auto size_ShapeOf_110 = - makeOP({__module_transformer_transpose_Transpose}, {{"output_type", "i32"}}); - auto __getitem___Gather = makeOP({size_ShapeOf_110, -2, 0}, {{"batch_dims", 0}}); - auto mul_Multiply = makeOP({__getitem___Gather, 2}, {{"auto_broadcast", "numpy"}}); - auto slice_Unsqueeze_112 = makeOP({mul_Multiply, 0}); - - auto floordiv_Divide = - makeOP({mul_Multiply, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto floordiv_Floor = makeOP({floordiv_Divide}); - auto ListConstruct_126_Reshape_2 = makeOP({floordiv_Floor, {-1}}, {{"special_zero", false}}); - - auto ListUnpack_321 = makeOP({input, -1, {4096, 256, 256}}); - auto view_Reshape = - makeOP({ListUnpack_321->output(0), {0, 0, 32, 128}}, {{"special_zero", true}}); - - auto ScatterUpdate_229053 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}}); - auto slice_Slice_357 = - makeOP({view_Reshape, {0, 0, 0, 0}, ScatterUpdate_229053, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto size_ShapeOf_346 = 
makeOP({view_Reshape}, {{"output_type", "i32"}}); - auto size_Gather_348 = makeOP({size_ShapeOf_346, 0, 0}, {{"batch_dims", 0}}); - auto ListConstruct_372_Reshape = makeOP({size_Gather_348, {-1}}, {{"special_zero", false}}); - auto size_Gather_351 = makeOP({size_ShapeOf_346, {2}, 0}, {{"batch_dims", 0}}); - auto ListConstruct_372_Concat = - makeOP({ListConstruct_372_Reshape, {-1}, size_Gather_351, ListConstruct_126_Reshape_2, {2}}, - {{"axis", 0}}); - auto reshape_Reshape_373 = - makeOP({slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}}); - auto select_Gather_381 = makeOP({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}}); - auto slice_Unsqueeze_367 = makeOP({size_Gather_348, 0}); - auto slice_Slice_369 = - makeOP({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto size_ShapeOf_374 = makeOP({reshape_Reshape_373}, {{"output_type", "i32"}}); - auto size_Gather_376 = makeOP({size_ShapeOf_374, {3}, 0}, {{"batch_dims", 0}}); - auto ListConstruct_379_Concat = - makeOP({ListConstruct_372_Reshape, {-1}, {1}, size_Gather_376, {2}}, {{"axis", 0}}); - auto view_Reshape_380 = - makeOP({slice_Slice_369, ListConstruct_379_Concat}, {{"special_zero", false}}); - auto select_Gather_382 = makeOP({view_Reshape_380, 0, -1}, {{"batch_dims", 0}}); - auto mul_Multiply_383 = - makeOP({select_Gather_381, select_Gather_382}, {{"auto_broadcast", "numpy"}}); - auto select_Gather_384 = makeOP({reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}}); - auto select_Gather_385 = makeOP({view_Reshape_380, 1, -1}, {{"batch_dims", 0}}); - auto mul_Multiply_386 = - makeOP({select_Gather_384, select_Gather_385}, {{"auto_broadcast", "numpy"}}); - auto sub_Subtract_389 = - makeOP({mul_Multiply_383, mul_Multiply_386}, {{"auto_broadcast", "numpy"}}); - auto Unsqueeze_62716 = makeOP({sub_Subtract_389, -1}); - auto mul_Multiply_391 = - makeOP({select_Gather_384, select_Gather_382}, {{"auto_broadcast", "numpy"}}); - auto mul_Multiply_393 = - makeOP({select_Gather_381, select_Gather_385}, {{"auto_broadcast", "numpy"}}); - auto add_Add_396 = makeOP({mul_Multiply_391, mul_Multiply_393}, {{"auto_broadcast", "numpy"}}); - auto Unsqueeze_62717 = makeOP({add_Add_396, -1}); - auto stack_401 = makeOP({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}}); - auto flatten_ShapeOf_402 = makeOP({stack_401}, {{"output_type", "i32"}}); - auto flatten_Slice_417 = makeOP({flatten_ShapeOf_402, {0}, {3}, {1}}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto flatten_Concat_420 = makeOP({flatten_Slice_417, {-1}}, {{"axis", 0}}); - auto flatten_Reshape_421 = makeOP({stack_401, flatten_Concat_420}, {{"special_zero", true}}); - auto ScatterUpdate_229067 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}}); - auto slice_Slice_363 = - makeOP({view_Reshape, ScatterUpdate_229067, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto cat_Concat_425 = makeOP({flatten_Reshape_421, slice_Slice_363}, {{"axis", -1}}); - return std::make_shared(ov::NodeVector{cat_Concat_425}, - ov::ParameterVector{input, cos_sin_cache, position_ids}); - } - void SetUp() override { - targetDevice = ov::test::utils::DEVICE_CPU; - - const int batch = 2; - const int seq_length = 7; - const int num_head = 32; - 
const int rotary_dims = 64; - - InputShape inpShape = {{-1, batch, 4096 + 256 + 256}, {{seq_length, batch, 4096 + 256 + 256}}}; - init_input_shapes({inpShape}); - function = buildROPE_ChatGLM(batch, num_head, rotary_dims); - } -}; - -TEST_F(RoPECPUTestChatGLM, smoke_CompareWithRefs) { - run(); - CheckNumberOfNodesWithType(compiledModel, "RoPE", 1); -} - -class RoPECPUTestQwen7b : public SubgraphBaseTest, public testing::WithParamInterface { -public: - static std::string getTestCaseName(const testing::TestParamInfo& obj) { - const bool specialReshape = obj.param; - std::ostringstream result; - result << "specialReshape=" << specialReshape << std::endl; - return result.str(); - } - void generate_inputs(const std::vector& targetInputStaticShapes) override { - const auto& funcInputs = function->inputs(); - - auto& input_shape = targetInputStaticShapes[0]; - - ov::Tensor t_input = - utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768); - ov::Tensor t_cos_cache = - utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {1, 4096, 1, 128}, 2, -1.0f, 32768); - ov::Tensor t_sin_cache = - utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {1, 4096, 1, 128}, 2, -1.0f, 32768); - - inputs.clear(); - inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); - inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_cache}); - inputs.insert({funcInputs[2].get_node_shared_ptr(), t_sin_cache}); - } - -protected: - std::shared_ptr buildROPE_QWen7b(bool specialReshape) { - auto input = - std::make_shared(ov::element::f32, PartialShape{-1, -1, 4096 + 4096 + 4096}); - auto cos_cache = std::make_shared(ov::element::f32, PartialShape{1, -1, 1, 128}); - auto sin_cache = std::make_shared(ov::element::f32, PartialShape{1, -1, 1, 128}); - - auto ListUnpack_389_VariadicSplit = makeOP({input, 2, {4096, 4096, -1}}); - auto view_Reshape = makeOP({ListUnpack_389_VariadicSplit->output(0), {0, 0, 32, 128}}, - {{"special_zero", true}}); - auto size_ShapeOf_414 = makeOP({view_Reshape}, {{"output_type", "i32"}}); - auto size_Gather_416 = makeOP({size_ShapeOf_414, 1, 0}, {{"batch_dims", 0}}); - auto neg_Multiply = makeOP({size_Gather_416, -1}, {{"auto_broadcast", "numpy"}}); - auto slice_Unsqueeze_422 = makeOP({neg_Multiply, 0}); - auto ScatterUpdate_261437 = makeOP({{0, 0}, {1}, slice_Unsqueeze_422, {0}}); - auto slice_Slice_425 = makeOP({cos_cache, ScatterUpdate_261437, {0ll, LLONG_MAX}, {1, 1}}, - {{"begin_mask", {1, 0}}, - {"end_mask", {1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto slice_Slice_431 = - makeOP({slice_Slice_425, {0, 0, 0}, {0ll, 0ll, LLONG_MAX}, {1, 1, 1}}, - {{"begin_mask", {1, 1, 0}}, - {"end_mask", {1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto slice_Slice_437 = - makeOP({slice_Slice_431, {0, 0, 0, 0}, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto size_ShapeOf_462 = makeOP({slice_Slice_437}, {{"output_type", "i32"}}); - auto size_Gather_464 = makeOP({size_ShapeOf_462, {3}, 0}, {{"batch_dims", 0}}); - auto ScatterUpdate_261533 = makeOP({{0, 0, 0, 0}, {3}, size_Gather_464, {0}}); - auto slice_Slice_470 = - makeOP({view_Reshape, {0, 0, 0, 0}, ScatterUpdate_261533, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - 
{"ellipsis_mask", {}}}); - auto mul_Multiply = makeOP({slice_Slice_470, slice_Slice_437}, {{"auto_broadcast", "numpy"}}); - auto size_ShapeOf_478 = makeOP({slice_Slice_470}, {{"output_type", "i32"}}); - auto Gather_239390 = makeOP({size_ShapeOf_478, {0, 1, 2}, 0}, {{"batch_dims", 0}}); - auto size_Gather_489 = makeOP({size_ShapeOf_478, 3, 0}, {{"batch_dims", 0}}); - auto floor_divide_Divide = - makeOP({size_Gather_489, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto floor_divide_Floor = makeOP({floor_divide_Divide}); - auto ListConstruct_493_Reshape_3 = - makeOP({floor_divide_Floor, {-1}}, {{"special_zero", false}}); - auto ListConstruct_493_Concat = - makeOP({Gather_239390, {2}, ListConstruct_493_Reshape_3}, {{"axis", 0}}); - std::shared_ptr reshape_Reshape = nullptr; - if (specialReshape) { - reshape_Reshape = makeOP({slice_Slice_470, {0, 0, 32, 2, 64}}, {{"special_zero", true}}); - } else { - reshape_Reshape = - makeOP({slice_Slice_470, ListConstruct_493_Concat}, {{"special_zero", false}}); - } - auto ListUnpack_496_Split = makeOP({reshape_Reshape, -2}, {{"num_splits", 2}}); - auto ListUnpack_496_Squeeze_0 = makeOP({ListUnpack_496_Split->output(1), -2}); - auto Constant_296840_compressed = makeConst(element::f16, - ov::Shape({ - 1, - 1, - 1, - 1, - }), - {-1}); - auto Constant_296840 = makeOP({Constant_296840_compressed}, {{"destination_type", "f32"}}); - auto neg_Multiply_499 = - makeOP({ListUnpack_496_Squeeze_0, Constant_296840}, {{"auto_broadcast", "numpy"}}); - auto ListUnpack_496_Squeeze = makeOP({ListUnpack_496_Split->output(0), -2}); - auto cat_Concat = makeOP({neg_Multiply_499, ListUnpack_496_Squeeze}, {{"axis", -1}}); - auto slice_Slice_449 = makeOP({sin_cache, ScatterUpdate_261437, {0ll, LLONG_MAX}, {1, 1}}, - {{"begin_mask", {1, 0}}, - {"end_mask", {1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto slice_Slice_455 = - makeOP({slice_Slice_449, {0, 0, 0}, {0ll, 0ll, LLONG_MAX}, {1, 1, 1}}, - {{"begin_mask", {1, 1, 0}}, - {"end_mask", {1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto slice_Slice_461 = - makeOP({slice_Slice_455, {0, 0, 0, 0}, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto mul_Multiply_503 = makeOP({cat_Concat, slice_Slice_461}, {{"auto_broadcast", "numpy"}}); - auto add_Add = makeOP({mul_Multiply, mul_Multiply_503}, {{"auto_broadcast", "numpy"}}); - return std::make_shared(ov::NodeVector{add_Add}, ov::ParameterVector{input, cos_cache, sin_cache}); - } - void SetUp() override { - targetDevice = ov::test::utils::DEVICE_CPU; - const bool specialReshape = this->GetParam(); - const int batch = 2; - const int seq_length = 7; - InputShape inpShape = {{batch, -1, 4096 + 4096 + 4096}, {{batch, seq_length, 4096 + 4096 + 4096}}}; - init_input_shapes({inpShape}); - function = buildROPE_QWen7b(specialReshape); - } -}; - -TEST_P(RoPECPUTestQwen7b, smoke_CompareWithRefs) { - run(); - CheckNumberOfNodesWithType(compiledModel, "RoPE", 1); -} - -INSTANTIATE_TEST_SUITE_P(smoke_RoPECPUTestQwen7b, - RoPECPUTestQwen7b, - ::testing::Values(true, false), - RoPECPUTestQwen7b::getTestCaseName); - -class RoPECPUTestGPTJ : public SubgraphBaseTest, public testing::WithParamInterface { -public: - static std::string getTestCaseName(const testing::TestParamInfo& obj) { - bool hasShapeOf; - hasShapeOf = obj.param; - std::ostringstream result; 
- result << "hasShapeOf=" << hasShapeOf << std::endl; - return result.str(); - } - void generate_inputs(const std::vector& targetInputStaticShapes) override { - const auto& funcInputs = function->inputs(); - - auto& input_shape = targetInputStaticShapes[0]; - auto& sincos_shape = targetInputStaticShapes[1]; - ov::Tensor t_input = - utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768); - ov::Tensor t_cos_sin_cache = - utils::create_and_fill_tensor(funcInputs[1].get_element_type(), sincos_shape, 2, -1.0f, 32768); - - inputs.clear(); - inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); - inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache}); - } - -protected: - std::shared_ptr buildROPE_GPTJ(const int num_head, - const int hidden_dims, - const int rotary_dims, - bool hasShapeOf) { - auto int32_max = std::numeric_limits::max(); - auto input = - std::make_shared(ov::element::f32, PartialShape{-1, -1, num_head, hidden_dims}); - auto sincos = std::make_shared(ov::element::f32, PartialShape{-1, -1, rotary_dims}); - - auto slice_Slice_965 = - makeOP({input, {0, 0, 0, 0}, {0, 0, 0, rotary_dims}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - slice_Slice_965->set_friendly_name("slice_Slice_965"); - - auto varsplit = makeOP({sincos, -1, {rotary_dims / 2, -1}}); - varsplit->set_output_size(2); - varsplit->set_friendly_name("varsplit"); - auto unsqueeze_sin = makeOP({varsplit->output(0), 2}); - auto unsqueeze_cos = makeOP({varsplit->output(1), 2}); - std::vector gather_idx(rotary_dims, 1); - int32_t v = 0; - for (size_t i = 0; i < gather_idx.size(); i += 2, v++) { - gather_idx[i] = v; - gather_idx[i + 1] = v; - } - - auto const_idx = makeConst(ov::element::i32, ov::Shape({static_cast(rotary_dims)}), gather_idx); - auto constant_155588 = makeConst(element::f32, - ov::Shape({ - 1, - 1, - 1, - 1, - }), - {-1.000000f}); - auto repeat_interleave_sin = makeOP({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}}); - auto repeat_interleave_cos = makeOP({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}}); - repeat_interleave_sin->set_friendly_name("repeat_interleave_sin"); - repeat_interleave_cos->set_friendly_name("repeat_interleave_cos"); - // x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2]) - auto slice_Slice_1174 = - makeOP({slice_Slice_965, {0, 0, 0, 1}, {0, 0, 0, int32_max}, {1, 1, 1, 2}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto neg_Multiply_1177 = - makeOP({slice_Slice_1174, constant_155588}, {{"auto_broadcast", "numpy"}}); - auto Unsqueeze_65524 = makeOP({neg_Multiply_1177, -1}); - - auto slice_Slice_1168 = - makeOP({slice_Slice_965, {0, 0, 0, 0}, {0, 0, 0, int32_max}, {1, 1, 1, 2}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Unsqueeze_65525 = makeOP({slice_Slice_1168, -1}); - auto stack_1182 = makeOP({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}}); - auto flatten_Reshape_1198 = - makeOP({stack_1182, {0, 0, num_head, rotary_dims}}, {{"special_zero", true}}); - // x*cos [B,L,H,ndims] - auto mul_cos = - makeOP({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}}); - mul_cos->set_friendly_name("mul_cos"); - auto mul_sin = - makeOP({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", 
"numpy"}}); - // *cos + *sin - auto rotary_emb = makeOP({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}}); - - auto slice_Slice_971 = - makeOP({input, {0, 0, 0, rotary_dims}, {0, 0, 0, int32_max}, {1, 1, 1, 1}}, - {{"begin_mask", {1, 1, 1, 0}}, - {"end_mask", {1, 1, 1, 0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto cat_Concat_1211 = makeOP({rotary_emb, slice_Slice_971}, {{"axis", -1}}); - auto permute_Transpose_1213 = makeOP({cat_Concat_1211, {0, 2, 1, 3}}); - ov::NodeVector model_output = {permute_Transpose_1213}; - if (hasShapeOf) { - auto shapeOf = makeOP({rotary_emb}, {{"output_type", "i32"}}); - auto gather = makeOP({shapeOf, {1}, 0}, {{"batch_dims", 0}}); - model_output.push_back(gather); - } - return std::make_shared(model_output, ov::ParameterVector{input, sincos}); - } - void SetUp() override { - targetDevice = ov::test::utils::DEVICE_CPU; - bool hasShapeOf = this->GetParam(); - const int batch = 2; - const int seq_length = 7; - const int num_head = 16; - const int hidden_dims = 256; - const int rotary_dims = 64; - - InputShape input = {{batch, seq_length, num_head, hidden_dims}, {{batch, seq_length, num_head, hidden_dims}}}; - InputShape sincos = {{batch, seq_length, rotary_dims}, {{batch, seq_length, rotary_dims}}}; - init_input_shapes({input, sincos}); - function = buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf); - } -}; - -TEST_P(RoPECPUTestGPTJ, smoke_CompareWithRefs) { - run(); - CheckNumberOfNodesWithType(compiledModel, "RoPE", 1); -} - -INSTANTIATE_TEST_SUITE_P(smoke_RoPECPUTestGPTJ, - RoPECPUTestGPTJ, - ::testing::Values(true, false), - RoPECPUTestGPTJ::getTestCaseName); - -} // namespace test -} // namespace ov diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 0d55c708f94405..f9bbebb878ba81 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -371,6 +371,8 @@ std::vector disabledTestPatterns() { retVector.emplace_back(R"(smoke_VariableState/OVInferRequestVariableStateTest.*)"); // Issue: 141705 retVector.emplace_back(R"(.*smoke_arm_Deconv_2D_Planar_FP16/DeconvolutionLayerCPUTest.*INFERENCE_PRECISION_HINT=f16.*)"); + + retVector.emplace_back(R"(.*smoke_RoPETest.*)"); #endif #if defined(OPENVINO_ARCH_ARM) diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp new file mode 100644 index 00000000000000..0ff1d18ae09ff7 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "subgraph_tests/rotary_pos_emb.hpp" + +namespace ov { +namespace test { + +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2, + RoPETestLlama2, + ::testing::Values(ov::test::utils::DEVICE_CPU), + RoPETestLlama2::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, + RoPETestChatGLM, + ::testing::Values(ov::test::utils::DEVICE_CPU), + RoPETestChatGLM::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b, + RoPETestQwen7b, + ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + 
RoPETestQwen7b::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJ, + RoPETestGPTJ, + ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + RoPETestGPTJ::getTestCaseName); +} // namespace test +} // namespace ov diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index a20017540379af..6ce8bb62407aa5 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -286,3 +286,4 @@ REGISTER_FACTORY(internal, Convolution); REGISTER_FACTORY(internal, Placeholder); REGISTER_FACTORY(internal, SDPA); REGISTER_FACTORY(internal, IndirectSDPA); +REGISTER_FACTORY(internal, RoPE); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp new file mode 100644 index 00000000000000..b9568c766412de --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp @@ -0,0 +1,92 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "primitive.hpp" +#include "ov_ops/rotary_positional_embeddings.hpp" + +namespace cldnn { +using RoPE = ov::op::internal::RoPE; + +/// @brief Rotary Position Embedding primitive +struct rope : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(rope); + + rope() : primitive_base("", {}) {} + + /// @brief Constructs rope primitive + /// @param id This primitive id + /// @param inputs Inputs primitive ids + /// @param config Specific RoPE config + rope(const primitive_id& id, + const std::vector& inputs, + const RoPE::Config& config, + const padding& output_padding = padding()) + : primitive_base(id, inputs, {output_padding}), + config(config) {} + + RoPE::Config config; + + size_t hash() const override { + size_t seed = primitive::hash(); + seed = hash_combine(seed, config.gather_position_arg_id); + seed = hash_combine(seed, config.head_cnt); + seed = hash_combine(seed, config.head_size); + seed = hash_combine(seed, config.input_trans0213); + seed = hash_combine(seed, config.is_chatglm); + seed = hash_combine(seed, config.is_interleaved); + seed = hash_combine(seed, config.is_qwen); + seed = hash_combine(seed, config.rotary_ndims); + seed = hash_combine(seed, config.slice_start); + seed = hash_combine(seed, config.slice_stop); + return seed; + } + + bool operator==(const primitive& rhs) const override { + if (!compare_common_params(rhs)) + return false; + + auto rhs_casted = downcast(rhs); + + return config.gather_position_arg_id == rhs_casted.config.gather_position_arg_id && + config.head_cnt == rhs_casted.config.head_cnt && + config.head_size == rhs_casted.config.head_size && + config.input_trans0213 == rhs_casted.config.input_trans0213 && + config.is_chatglm == rhs_casted.config.is_chatglm && + config.is_interleaved == rhs_casted.config.is_interleaved && + config.is_qwen == rhs_casted.config.is_qwen && + config.rotary_ndims == rhs_casted.config.rotary_ndims && + config.slice_start == rhs_casted.config.slice_start && + config.slice_stop == rhs_casted.config.slice_stop; + } + + void save(BinaryOutputBuffer& ob) const override { + primitive_base::save(ob); + ob << config.gather_position_arg_id; + ob << config.head_cnt; + ob << config.head_size; + ob << config.input_trans0213; + ob << config.is_chatglm; + ob << config.is_interleaved; + ob << config.is_qwen; + ob << config.rotary_ndims; + 
ob << config.slice_start; + ob << config.slice_stop; + } + + void load(BinaryInputBuffer& ib) override { + primitive_base::load(ib); + ib >> config.gather_position_arg_id; + ib >> config.head_cnt; + ib >> config.head_size; + ib >> config.input_trans0213; + ib >> config.is_chatglm; + ib >> config.is_interleaved; + ib >> config.is_qwen; + ib >> config.rotary_ndims; + ib >> config.slice_start; + ib >> config.slice_stop; + } +}; +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp index 855ae9c421b235..cadab1b29ec711 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp @@ -94,6 +94,7 @@ void register_implementations() { REGISTER_OCL(unique_count); REGISTER_OCL(unique_gather); REGISTER_OCL(scaled_dot_product_attention); + REGISTER_OCL(rope); } } // namespace ocl diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp index f0d2a72e51d848..bacb3c60023c76 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp @@ -75,6 +75,7 @@ #include "intel_gpu/primitives/unique.hpp" #include "intel_gpu/primitives/kv_cache.hpp" #include "intel_gpu/primitives/scaled_dot_product_attention.hpp" +#include "intel_gpu/primitives/rope.hpp" namespace cldnn { namespace ocl { @@ -174,6 +175,7 @@ REGISTER_OCL(eye); REGISTER_OCL(unique_count); REGISTER_OCL(unique_gather); REGISTER_OCL(scaled_dot_product_attention); +REGISTER_OCL(rope); #undef REGISTER_OCL diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp new file mode 100644 index 00000000000000..02ef732daaafad --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp @@ -0,0 +1,88 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "primitive_base.hpp" + +#include "rope_inst.h" +#include "rope/rope_kernel_selector.h" +#include "rope/rope_kernel_ref.h" + +namespace cldnn { +namespace ocl { + +struct rope_impl : typed_primitive_impl_ocl { + using parent = typed_primitive_impl_ocl; + using parent::parent; + using kernel_selector_t = kernel_selector::rope_kernel_selector; + using kernel_params_t = kernel_selector::rope_params; + + DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::rope_impl); + + std::unique_ptr clone() const override { + return make_unique(*this); + } + + void load(BinaryInputBuffer& ib) override { + parent::load(ib); + if (is_dynamic()) { + auto& kernel_selector = kernel_selector_t::Instance(); + auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName); + kernel_impl->GetUpdateDispatchDataFunc(_kernel_data); + } + } + + static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) { + const auto& primitive = impl_param.typed_desc(); + auto params = get_default_params(impl_param, is_shape_agnostic); + + params.head_cnt = primitive->config.head_cnt; + params.head_size = primitive->config.head_size; + params.rotary_ndims = primitive->config.rotary_ndims; + + params.slice_start = primitive->config.slice_start; + params.slice_stop = primitive->config.slice_stop; + + params.axis = primitive->config.is_qwen || primitive->config.is_chatglm ? 2 : 3; + params.num_of_inputs = primitive->config.is_chatglm || primitive->config.is_interleaved ? 
2 : 3; + + params.is_qwen = primitive->config.is_qwen; + params.is_chatglm = primitive->config.is_chatglm; + + for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) { + params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(i))); + } + return params; + } + + void update_dispatch_data(const kernel_impl_params& impl_param) override { + auto kernel_params = get_kernel_params(impl_param, true); + (_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data); + } +}; + +namespace detail { + +attach_rope_impl::attach_rope_impl() { + auto types = { + data_types::f32, + data_types::f16 + }; + + auto formats = { + format::bfyx + }; + + implementation_map::add(impl_types::ocl, + shape_types::any, + typed_primitive_impl_ocl::create, + types, + formats); +} + +} // namespace detail +} // namespace ocl +} // namespace cldnn + +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::rope_impl) +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::rope) diff --git a/src/plugins/intel_gpu/src/graph/include/rope_inst.h b/src/plugins/intel_gpu/src/graph/include/rope_inst.h new file mode 100644 index 00000000000000..669032ac6965f9 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/include/rope_inst.h @@ -0,0 +1,39 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "intel_gpu/primitives/rope.hpp" +#include "primitive_inst.h" + +#include + +namespace cldnn { +template <> +struct typed_program_node : public typed_program_node_base { + using parent = typed_program_node_base; + +public: + using parent::parent; + + program_node& input(size_t idx = 0) const { return get_dependency(idx); } + std::vector get_shape_infer_dependencies() const override { return {}; } +}; + +using rope_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + using parent::parent; + +public: + template + static std::vector calc_output_layouts(const rope_node& /*node*/, const kernel_impl_params& impl_param); + static layout calc_output_layout(rope_node const& node, kernel_impl_params const& impl_param); + static std::string to_string(rope_node const& node); +}; + +using rope_inst = typed_primitive_inst; +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/rope.cpp b/src/plugins/intel_gpu/src/graph/rope.cpp new file mode 100644 index 00000000000000..8a2307e180bfe7 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/rope.cpp @@ -0,0 +1,76 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "rope_inst.h" + +#include "primitive_type_base.h" +#include "json_object.h" +#include + +namespace cldnn { +GPU_DEFINE_PRIMITIVE_TYPE_ID(rope) + +layout rope_inst::calc_output_layout(rope_node const& node, kernel_impl_params const& impl_param) { + return calc_output_layouts(node, impl_param)[0]; +} + +template +std::vector rope_inst::calc_output_layouts(rope_node const& node, kernel_impl_params const& impl_param) { + auto desc = impl_param.typed_desc(); + + const auto& input0_layout = impl_param.get_input_layout(0); + const auto& input0_shape = input0_layout.get(); + auto output_format = input0_layout.format; + + auto output_type = desc->output_data_types[0].value_or(input0_layout.data_type); + if (impl_param.has_fused_primitives()) { + output_type = impl_param.get_output_element_type(); + } + + ShapeType output_shape = input0_shape; + + if (desc->config.is_qwen || desc->config.is_chatglm) { + output_shape = { input0_shape[0], + 
input0_shape[1], + ov::Dimension(desc->config.head_cnt), + ov::Dimension(desc->config.head_size) }; + } else { + auto input_slice_size = desc->config.slice_stop - desc->config.slice_start; + if (input_slice_size > 0) { + output_shape[3] = input_slice_size; + } + + if (desc->config.input_trans0213 || desc->config.is_interleaved) { + std::swap(output_shape[2], output_shape[1]); + } + } + return { layout(output_shape, output_type, output_format) }; +} + +template std::vector rope_inst::calc_output_layouts(rope_node const& node, const kernel_impl_params& impl_param); + +std::string rope_inst::to_string(rope_node const& node) { + auto desc = node.get_primitive(); + auto node_info = node.desc_to_json(); + + std::stringstream primitive_description; + + json_composite rope_info; + rope_info.add("gather_position_arg_id", desc->config.gather_position_arg_id); + rope_info.add("head_cnt", desc->config.head_cnt); + rope_info.add("head_size", desc->config.head_size); + rope_info.add("input_trans0213", desc->config.input_trans0213); + rope_info.add("is_chatglm", desc->config.is_chatglm); + rope_info.add("is_interleaved", desc->config.is_interleaved); + rope_info.add("is_qwen", desc->config.is_qwen); + rope_info.add("rotary_ndims", desc->config.rotary_ndims); + rope_info.add("slice_start", desc->config.slice_start); + rope_info.add("slice_stop", desc->config.slice_stop); + node_info->add("rope info", rope_info); + node_info->dump(primitive_description); + + return primitive_description.str(); +} + +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl new file mode 100644 index 00000000000000..1b6b9fe65491a7 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl @@ -0,0 +1,86 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "include/fetch_utils.cl" + +#ifdef CHATGLM +KERNEL(rope_ref)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* input, + const __global INPUT1_TYPE* cos_sin, + __global OUTPUT_TYPE* output) +{ + const uint p = get_global_id(0); + const uint b = get_global_id(1); + const uint h = get_global_id(2) % HEAD_COUNT; + const uint rf = get_global_id(2) / HEAD_COUNT; + uint r = rf < HALF_ROTARY_NDIMS ? rf * 2 : 0; + uint f = rf < HEAD_SIZE - ROTARY_NDIMS ? rf : 0; + +#ifdef ENABLE_SLICE + uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, p, b, h * HEAD_SIZE, 0); + + input_idx += SLICED_FROM_START * (p * INPUT0_FEATURE_NUM + b + 1) + + SLICED_FROM_END * (p * INPUT0_FEATURE_NUM + b); +#else + uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0); +#endif + uint cos_sin_p = p < INPUT1_BATCH_NUM ? p : 0; + uint cos_sin_b = b < INPUT1_FEATURE_NUM ? 
b : 0; + uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_p, cos_sin_b, 0, 0); + + uint output_idx = OUTPUT_GET_INDEX(p, b, h, 0); + + INPUT1_TYPE cosv = cos_sin[cos_sin_idx + r]; + INPUT1_TYPE sinv = cos_sin[cos_sin_idx + r + 1]; + + INPUT0_TYPE in1 = input[input_idx + r]; + INPUT0_TYPE in2 = input[input_idx + r + 1]; + + output[output_idx + r] = cosv * in1 - sinv * in2; + output[output_idx + r + 1] = sinv * in1 + cosv * in2; + +#ifdef ENABLE_IO_COPY + output[output_idx + ROTARY_NDIMS + f] = input[input_idx + ROTARY_NDIMS + f]; +#endif +} +#endif + +#ifdef QWEN +KERNEL(rope_ref)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* input, + const __global INPUT1_TYPE* cos, + const __global INPUT1_TYPE* sin, + __global OUTPUT_TYPE* output) +{ + const uint b = get_global_id(0); + const uint p = get_global_id(1); + const uint h = get_global_id(2) / HALF_ROTARY_NDIMS; + const uint r = get_global_id(2) % HALF_ROTARY_NDIMS; + +#ifdef ENABLE_SLICE + uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, p, h * HEAD_SIZE, 0); + + input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + p + 1) + + SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + p); +#else + uint input_idx = INPUT0_GET_INDEX(b, p, h * HEAD_SIZE, 0); +#endif + uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; + uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM : 0; + uint cos_sin_h = h < INPUT1_SIZE_Y ? h : 0; + uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0); + + uint output_idx = OUTPUT_GET_INDEX(b, p, h, 0); + + INPUT0_TYPE in1 = input[input_idx + r]; + INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r]; + + output[output_idx + r] = cos[cos_sin_idx + r] * in1 - sin[cos_sin_idx + r] * in2; + + output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in2 + + sin[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in1; +} +#endif diff --git a/src/plugins/intel_gpu/src/kernel_selector/common_types.h b/src/plugins/intel_gpu/src/kernel_selector/common_types.h index 768a0fc3c4f854..e61fdb073d00e8 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/common_types.h +++ b/src/plugins/intel_gpu/src/kernel_selector/common_types.h @@ -97,6 +97,7 @@ enum class KernelType { UNIQUE_GATHER, RMS, SWIGLU, + ROPE }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp new file mode 100644 index 00000000000000..afc5a322a522d1 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp @@ -0,0 +1,114 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "rope_kernel_base.h" +#include "kernel_selector_utils.h" + +namespace kernel_selector { +bool RoPEKernelBase::Validate(const Params& p) const { + return KernelBaseOpenCL::Validate(p); +} + +JitConstants RoPEKernelBase::GetJitConstants(const rope_params& params, RoPEKernelBase::DispatchData) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + jit.AddConstant(MakeJitConstant("HEAD_SIZE", params.head_size)); + jit.AddConstant(MakeJitConstant("ROTARY_NDIMS", params.rotary_ndims)); + jit.AddConstant(MakeJitConstant("HALF_ROTARY_NDIMS", params.rotary_ndims / 2)); + jit.AddConstant(MakeJitConstant("HEAD_COUNT", params.head_cnt)); + + if (params.head_size > 
params.rotary_ndims) { + jit.AddConstant(MakeJitConstant("ENABLE_IO_COPY", true)); + } + + if (params.slice_stop - params.slice_start > 0) { + jit.AddConstant(MakeJitConstant("ENABLE_SLICE", true)); + + auto f = toCodeString(params.inputs[0].Feature(), 1); + auto x = toCodeString(params.inputs[0].X(), 2); + auto y = toCodeString(params.inputs[0].Y(), 3); + auto sliced_y = toCodeString(params.slice_stop - params.slice_start); + + jit.AddConstant(MakeJitConstant("SLICED_INPUT0_X_PITCH", 1)); + jit.AddConstant(MakeJitConstant("SLICED_INPUT0_Y_PITCH", x)); + jit.AddConstant(MakeJitConstant("SLICED_INPUT0_FEATURE_PITCH", x + "*" + sliced_y)); + jit.AddConstant(MakeJitConstant("SLICED_INPUT0_BATCH_PITCH", x + "*" + sliced_y + "*" + f)); + jit.AddConstant(MakeJitConstant("SLICED_INPUT0_OFFSET", 0)); + + jit.AddConstant(MakeJitConstant("SLICED_FROM_START", toCodeString(params.slice_start))); + jit.AddConstant(MakeJitConstant("SLICED_FROM_END", "(" + y + "-" + toCodeString(params.slice_stop) + ")")); + } + + if (params.is_qwen) { + jit.AddConstant(MakeJitConstant("QWEN", true)); + } else if (params.is_chatglm) { + jit.AddConstant(MakeJitConstant("CHATGLM", true)); + } + + return jit; +} + +RoPEKernelBase::DispatchData RoPEKernelBase::SetDefault(const rope_params& params) const { + DispatchData dispatchData; + const auto& input = params.inputs[0]; + const auto& output = params.outputs[0]; + + std::vector> dims_by_gws = {{ Tensor::DataChannelName::BATCH }, + { Tensor::DataChannelName::FEATURE }, + { Tensor::DataChannelName::Y, Tensor::DataChannelName::X }}; + dispatchData.gws = {input.Batch().v, + input.Feature().v, + params.head_cnt * std::max(params.rotary_ndims / 2ul, params.head_size - params.rotary_ndims)}; + dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo, input.GetLayout(), output.GetLayout(), dims_by_gws); + + return dispatchData; +} + +void RoPEKernelBase::GetUpdateDispatchDataFunc(KernelData& kd) const { + kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) { + const auto& prim_params = static_cast(params); + auto dispatchData = SetDefault(prim_params); + OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func"); + kd.kernels[0].params.workGroups.global = dispatchData.gws; + kd.kernels[0].params.workGroups.local = dispatchData.lws; + kd.kernels[0].skip_execution = KernelData::SkipKernelExecution(prim_params); + }; +} + +KernelsData RoPEKernelBase::GetCommonKernelsData(const Params& params) const { + assert(params.GetType() == KernelType::ROPE); + + if (!Validate(params)) + return {}; + + const rope_params& orgParams = static_cast(params); + auto dispatchData = SetDefault(orgParams); + + KernelData kd = KernelData::Default(params); + + auto cldnn_jit = GetJitConstants(orgParams, dispatchData); + auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params); + auto jit = CreateJit(kernelName, cldnn_jit, entry_point); + + GetUpdateDispatchDataFunc(kd); + + auto& kernel = kd.kernels[0]; + FillCLKernelData(kernel, + dispatchData, + params.engineInfo, + kernelName, + jit, + entry_point, + EXE_MODE_DEFAULT, + false, + false, + static_cast(orgParams.num_of_inputs), + GetFusedPrimitiveInputsCount(params), + 1, + orgParams.outputs[0].is_dynamic()); + + return {kd}; +} + +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h new file mode 100644 index 
00000000000000..f4a92183c5d0c0 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h @@ -0,0 +1,45 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "kernel_base_opencl.h" + +namespace kernel_selector { +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// rope_params +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct rope_params : public base_params { + rope_params() : base_params(KernelType::ROPE) {} + size_t head_cnt; + size_t head_size; + size_t rotary_ndims; + + size_t slice_start; + size_t slice_stop; + size_t axis; + size_t num_of_inputs; + + bool is_qwen; + bool is_chatglm; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// RoPEKernelBase +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class RoPEKernelBase : public KernelBaseOpenCL { +public: + using KernelBaseOpenCL::KernelBaseOpenCL; + virtual ~RoPEKernelBase() {} + + struct DispatchData : public CommonDispatchData {}; + +protected: + bool Validate(const Params&) const override; + virtual JitConstants GetJitConstants(const rope_params& params, DispatchData dispatchData) const; + virtual DispatchData SetDefault(const rope_params& params) const; + KernelsData GetCommonKernelsData(const Params& params) const; + void GetUpdateDispatchDataFunc(KernelData& kd) const override; +}; +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp new file mode 100644 index 00000000000000..27d7efeb525db5 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "rope_kernel_ref.h" +#include "kernel_selector_utils.h" +#include + +namespace kernel_selector { +ParamsKey RoPEKernelRef::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDifferentTypes(); + k.EnableDynamicShapesSupport(); + return k; +} + +KernelsData RoPEKernelRef::GetKernelsData(const Params& params) const { + return GetCommonKernelsData(params); +} + +KernelsPriority RoPEKernelRef::GetKernelsPriority(const Params& /*params*/) const { + return FORCE_PRIORITY_9; +} +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.h new file mode 100644 index 00000000000000..ceea2a17720ad1 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.h @@ -0,0 +1,20 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "rope_kernel_base.h" + +namespace kernel_selector { +class RoPEKernelRef : public RoPEKernelBase { +public: + using Parent = RoPEKernelBase; + 
+    RoPEKernelRef() : RoPEKernelBase("rope_ref") {}
+    virtual ~RoPEKernelRef() {}
+
+    KernelsData GetKernelsData(const Params& params) const override;
+    KernelsPriority GetKernelsPriority(const Params& params) const override;
+    ParamsKey GetSupportedKey() const override;
+};
+} // namespace kernel_selector
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.cpp
new file mode 100644
index 00000000000000..e5436971b90c09
--- /dev/null
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.cpp
@@ -0,0 +1,16 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "rope_kernel_selector.h"
+#include "rope_kernel_ref.h"
+
+namespace kernel_selector {
+rope_kernel_selector::rope_kernel_selector() {
+    Attach<RoPEKernelRef>();
+}
+
+KernelsData rope_kernel_selector::GetBestKernels(const Params& params) const {
+    return GetNaiveBestKernel(params, KernelType::ROPE);
+}
+} // namespace kernel_selector
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.h
new file mode 100644
index 00000000000000..819e1fafbf694f
--- /dev/null
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.h
@@ -0,0 +1,23 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "kernel_selector.h"
+
+namespace kernel_selector {
+class rope_kernel_selector : public kernel_selector_base {
+public:
+    static rope_kernel_selector& Instance() {
+        static rope_kernel_selector instance_;
+        return instance_;
+    }
+
+    rope_kernel_selector();
+
+    virtual ~rope_kernel_selector() {}
+
+    KernelsData GetBestKernels(const Params& params) const override;
+};
+} // namespace kernel_selector
diff --git a/src/plugins/intel_gpu/src/plugin/graph.cpp b/src/plugins/intel_gpu/src/plugin/graph.cpp
index cc35d024322538..981a4e4058a9f0 100644
--- a/src/plugins/intel_gpu/src/plugin/graph.cpp
+++ b/src/plugins/intel_gpu/src/plugin/graph.cpp
@@ -198,6 +198,7 @@ std::shared_ptr<ov::Model> Graph::get_runtime_model(std::vector<cldnn::primitive_info>&
diff --git a/src/plugins/intel_gpu/src/plugin/ops/rope.cpp b/src/plugins/intel_gpu/src/plugin/ops/rope.cpp
new file mode 100644
--- /dev/null
+++ b/src/plugins/intel_gpu/src/plugin/ops/rope.cpp
+static void CreateRoPEOp(ProgramBuilder& p, const std::shared_ptr<ov::op::internal::RoPE>& op) {
+    validate_inputs_count(op, {3, 4});
+    auto inputs = p.GetInputInfo(op);
+    const auto& config = op->get_config();
+
+    auto rope = cldnn::rope(layer_type_name_ID(op),
+                            inputs,
+                            config);
+
+    p.add_primitive(*op, rope);
+}
+
+REGISTER_FACTORY_IMPL(internal, RoPE);
+
+} // namespace intel_gpu
+} // namespace ov
diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
index 0c1041b742c0fb..581fb42a109d0d 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -61,14 +61,13 @@
 #include "plugin/transformations/kv_cache_fusion.hpp"
 #include "plugin/transformations/move_fc_reshape_to_weights.hpp"
 #include "plugin/transformations/bcast_and_pad_zp_buffers.hpp"
-#include "transformations/common_optimizations/rms_fusion.hpp"
 #include "plugin/transformations/swiglu_fusion.hpp"
 #include "plugin/transformations/transpose_fusion.hpp"
 #include "plugin/transformations/indirect_kv_cache.hpp"
 #include "plugin/transformations/convert_convolution.hpp"
 #include "plugin/transformations/unsqueeze_broadcast_reshape_matmul_fusion.hpp"
-#include "transformations/common_optimizations/rms_fusion.hpp"
"plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.hpp" +#include "transformations/common_optimizations/rms_fusion.hpp" #include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp" #include "transformations/common_optimizations/broadcast_transition.hpp" #include "transformations/common_optimizations/common_optimizations.hpp" @@ -81,6 +80,7 @@ #include "transformations/common_optimizations/transpose_sinking.hpp" #include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp" #include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp" +#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp" #include "transformations/control_flow/unroll_tensor_iterator.hpp" #include "transformations/convert_pooling_to_reduce.hpp" #include "transformations/convert_precision.hpp" @@ -815,6 +815,14 @@ void TransformationsPipeline::apply(std::shared_ptr func) { const size_t zp_pad_size = device_info.supports_immad ? 16 : 32; manager.register_pass(zp_pad_size); + manager.register_pass(); + pass_config->disable(); + pass_config->disable(); + pass_config->disable(); + pass_config->disable(); + pass_config->disable(); + pass_config->disable(); + // This is supposed to be the last pass to ensure that we don't have name collisions until // GPU plugin stops using friendly names for program creation manager.register_pass(true); diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 2b2c2c92f71d29..4e265bb41c89a9 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -201,6 +201,7 @@ std::vector disabledTestPatterns() { R"(.*smoke_RDFT_5d_last_axis/RDFTLayerTest.Inference/IS=\(10.4.8.2.5\)_modelType=f32_Axes=\(0.1.2.3.4\)_SignalSize=\(\).*)", // Issue: 136862 R"(.*smoke_ConditionGPUTest_static/StaticConditionLayerGPUTest.CompareWithRefs/IS=\(3.6\)_netPRC=i8_ifCond=PARAM_targetDevice=GPU_.*)", + #if defined(_WIN32) // by calc abs_threshold with expected value R"(.*smoke_RemoteTensor/OVRemoteTensorBatched_Test.NV12toBGR_buffer/(num_batch_4|num_batch_2).*)", diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp new file mode 100644 index 00000000000000..bed957cc35fc0d --- /dev/null +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -0,0 +1,22 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "subgraph_tests/rotary_pos_emb.hpp" + +namespace ov { +namespace test { + +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, + RoPETestChatGLM, + ::testing::Values(ov::test::utils::DEVICE_GPU), + RoPETestChatGLM::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b, + RoPETestQwen7b, + ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::test::utils::DEVICE_GPU)), + RoPETestQwen7b::getTestCaseName); + +} // namespace test +} // namespace ov diff --git a/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp b/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp new file mode 100644 index 00000000000000..29b2cc2278613d --- 
diff --git a/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp b/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp
new file mode 100644
index 00000000000000..29b2cc2278613d
--- /dev/null
+++ b/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp
@@ -0,0 +1,56 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "shared_test_classes/subgraph/rotary_pos_emb.hpp"
+
+namespace ov {
+namespace test {
+
+inline void CheckNumberOfNodesWithType(std::shared_ptr<const ov::Model> function,
+                                       const std::unordered_set<std::string>& nodeTypes,
+                                       size_t expectedCount) {
+    ASSERT_NE(nullptr, function);
+    int num_ops = 0;
+    for (const auto& node : function->get_ordered_ops()) {
+        const auto& rt_info = node->get_rt_info();
+        const auto layer_type = rt_info.find("layerType")->second.as<std::string>();
+        if (nodeTypes.count(layer_type)) {
+            num_ops++;
+        }
+    }
+    ASSERT_EQ(num_ops, expectedCount);
+}
+
+TEST_P(RoPETestLlama2, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED();
+    run();
+    auto function = compiledModel.get_runtime_model();
+    CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
+};
+
+TEST_P(RoPETestChatGLM, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED();
+    run();
+    auto function = compiledModel.get_runtime_model();
+    CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
+};
+
+TEST_P(RoPETestQwen7b, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED();
+    run();
+    auto function = compiledModel.get_runtime_model();
+    CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
+};
+
+TEST_P(RoPETestGPTJ, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED();
+    run();
+    auto function = compiledModel.get_runtime_model();
+    CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
+};
+
+} // namespace test
+} // namespace ov
diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp
new file mode 100644
index 00000000000000..c18b57062f0295
--- /dev/null
+++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp
@@ -0,0 +1,67 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "shared_test_classes/base/ov_subgraph.hpp"
+
+namespace ov {
+namespace test {
+
+class RoPETestLlama2 : public SubgraphBaseTest, public testing::WithParamInterface<std::string> {
+private:
+    ov::OutputVector makeCosSinCache(int max_position_embeddings, int rotary_ndims);
+    std::shared_ptr<ov::Model> buildROPE_Llama2(int batch,
+                                                int seq_length,
+                                                int max_position_embeddings,
+                                                int num_head,
+                                                int ndims);
+    ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1);
+protected:
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
+    void SetUp() override;
+
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<std::string>& obj);
+};
+
+class RoPETestChatGLM : public SubgraphBaseTest, public testing::WithParamInterface<std::string> {
+private:
+    std::shared_ptr<ov::Model> buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims);
+    ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1);
+protected:
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
+    void SetUp() override;
+
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<std::string>& obj);
+};
+
+class RoPETestQwen7b : public SubgraphBaseTest, public testing::WithParamInterface<std::tuple<bool, std::string>> {
+private:
+    std::shared_ptr<ov::Model> buildROPE_QWen7b(bool specialReshape);
+protected:
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
+    void SetUp() override;
+
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<std::tuple<bool, std::string>>& obj);
+};
+
+class RoPETestGPTJ : public SubgraphBaseTest, public testing::WithParamInterface<std::tuple<bool, std::string>> {
+private:
+    std::shared_ptr<ov::Model> buildROPE_GPTJ(int num_head,
+                                              int hidden_dims,
+                                              int rotary_dims,
+                                              bool hasShapeOf);
+protected:
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
+    void SetUp() override;
+
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<std::tuple<bool, std::string>>& obj);
+};
+
+} // namespace test
+} // namespace ov
diff --git a/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp b/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp
new file mode 100644
index 00000000000000..a39a15783874e1
--- /dev/null
+++ b/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp
@@ -0,0 +1,590 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "common_test_utils/ov_tensor_utils.hpp"
+#include "shared_test_classes/subgraph/rotary_pos_emb.hpp"
+#include "transformations/utils/gen_pattern.hpp"
+
+using namespace ov::gen_pattern;
+using namespace ov;
+
+namespace ov {
+namespace test {
+
+ov::OutputVector RoPETestLlama2::makeCosSinCache(int max_position_embeddings, int rotary_ndims) {
+    std::vector<float> lut_sin(max_position_embeddings * rotary_ndims, 0.0f);
+    std::vector<float> lut_cos(max_position_embeddings * rotary_ndims, 0.0f);
+
+    // rotate_half style cos/sin table:
+    //   y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2
+    //   y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1
+    //
+    for (int i = 0, k = 0; i < rotary_ndims; i += 2, k++) {
+        auto xita_i = 1.0 / std::pow(10000.0, static_cast<double>(i) / rotary_ndims);
+        float* psin = lut_sin.data();
+        float* pcos = lut_cos.data();
+        for (int m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) {
+            auto vsin = std::sin(xita_i * m);
+            auto vcos = std::cos(xita_i * m);
+            pcos[k] = pcos[k + rotary_ndims / 2] = vcos;
+            psin[k] = psin[k + rotary_ndims / 2] = vsin;
+        }
+    }
+    auto shape = ov::Shape({1, 1, static_cast<size_t>(max_position_embeddings), static_cast<size_t>(rotary_ndims)});
+    auto Cos = makeConst(ov::element::f32, shape, lut_cos);
+    auto Sin = makeConst(ov::element::f32, shape, lut_sin);
+    return {Cos, Sin};
+}
+
+std::shared_ptr<ov::Model> RoPETestLlama2::buildROPE_Llama2(int batch,
+                                                            int seq_length,
+                                                            int max_position_embeddings,
+                                                            int num_head,
+                                                            int ndims) {
+    auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{batch, -1, num_head, ndims});
+    auto pos_id_end = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{});
+    auto pos_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i32, PartialShape{1, -1});
+
+    auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims);
+    auto Constant582 = cos_sin_cache[0];
+    auto Constant585 = cos_sin_cache[1];
+
+    // concat KV length
+    auto transpose_Transpose = makeOP<opset1::Transpose>({input, {0, 2, 1, 3}});
+    auto slice_Unsqueeze_426 = makeOP<opset1::Unsqueeze>({pos_id_end, 0});
+    auto ScatterUpdate_152236 = makeOP<opset3::ScatterUpdate>({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}});
+    auto slice_Slice = makeOP<opset1::StridedSlice>({Constant582, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}},
+                                                    {{"begin_mask", {1, 1, 0}},
+                                                     {"end_mask", {1, 1, 0}},
+                                                     {"new_axis_mask", {}},
+                                                     {"shrink_axis_mask", {}},
+                                                     {"ellipsis_mask", {}}});
+    auto squeeze_Squeeze = makeOP<opset1::Squeeze>({slice_Slice, 1});
+    auto squeeze_Squeeze_435 = makeOP<opset1::Squeeze>({squeeze_Squeeze, 0});
+    auto index_441_Gather = makeOP<opset8::Gather>({squeeze_Squeeze_435, pos_ids, 0}, {{"batch_dims", 0}});
+    auto unsqueeze_Unsqueeze = makeOP<opset1::Unsqueeze>({index_441_Gather, 1});
+    auto mul_Multiply =
+        makeOP<opset1::Multiply>({transpose_Transpose, unsqueeze_Unsqueeze}, {{"auto_broadcast", "numpy"}});
+    auto size_ShapeOf_448 = makeOP<opset3::ShapeOf>({transpose_Transpose}, {{"output_type", "i32"}});
+    auto size_Gather_450 = makeOP<opset8::Gather>({size_ShapeOf_448, 3, 0}, {{"batch_dims", 0}});
+    auto floor_divide_Divide =
+        makeOP<opset1::Divide>({size_Gather_450, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
+    auto floor_divide_Floor = makeOP<opset1::Floor>({floor_divide_Divide});
+    auto slice_Unsqueeze_452 = makeOP<opset1::Unsqueeze>({floor_divide_Floor, 0});
+    auto ScatterUpdate_152312 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}});
+    auto slice_Slice_459 = makeOP<opset1::StridedSlice>(
+        {transpose_Transpose, ScatterUpdate_152312, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}},
+        {{"begin_mask", {1, 1, 1, 0}},
+         {"end_mask", {1, 1, 1, 0}},
+         {"new_axis_mask", {}},
+         {"shrink_axis_mask", {}},
+         {"ellipsis_mask", {}}});
+    auto Constant_182988 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f});
+    auto neg_Multiply = makeOP<opset1::Multiply>({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}});
+    auto ScatterUpdate_152368 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}});
+    auto slice_Slice2 =
+        makeOP<opset1::StridedSlice>({transpose_Transpose, {0, 0, 0, 0}, ScatterUpdate_152368, {1, 1, 1, 1}},
+                                     {{"begin_mask", {1, 1, 1, 0}},
+                                      {"end_mask", {1, 1, 1, 0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto cat_Concat = makeOP<opset1::Concat>({neg_Multiply, slice_Slice2}, {{"axis", -1}});
+    auto ScatterUpdate_152421 = makeOP<opset3::ScatterUpdate>({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}});
+    auto slice_Slice_433 = makeOP<opset1::StridedSlice>({Constant585, {0, 0, 0}, ScatterUpdate_152421, {1, 1, 1}},
+                                                        {{"begin_mask", {1, 1, 0}},
+                                                         {"end_mask", {1, 1, 0}},
+                                                         {"new_axis_mask", {}},
+                                                         {"shrink_axis_mask", {}},
+                                                         {"ellipsis_mask", {}}});
+    auto squeeze_Squeeze_436 = makeOP<opset1::Squeeze>({slice_Slice_433, 1});
+    auto squeeze_Squeeze_437 = makeOP<opset1::Squeeze>({squeeze_Squeeze_436, 0});
+    auto index_446_Gather = makeOP<opset8::Gather>({squeeze_Squeeze_437, pos_ids, 0}, {{"batch_dims", 0}});
+    auto unsqueeze_Unsqueeze_447 = makeOP<opset1::Unsqueeze>({index_446_Gather, 1});
+    auto mul_Multiply_463 =
+        makeOP<opset1::Multiply>({cat_Concat, unsqueeze_Unsqueeze_447}, {{"auto_broadcast", "numpy"}});
+    auto add_Add = makeOP<opset1::Add>({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}});
+
+    return std::make_shared<ov::Model>(ov::NodeVector{add_Add}, ov::ParameterVector{input, pos_id_end, pos_ids});
+}
+
+ov::Tensor RoPETestLlama2::create_i32_tensor(const ov::Shape& shape, int start, int step) {
+    auto tensor = ov::Tensor(ov::element::i32, shape);
+    auto* ptr = static_cast<int32_t*>(tensor.data());
+    for (size_t i = 0; i < tensor.get_size(); i++) {
+        ptr[i] = start;
+        start += step;
+    }
+    return tensor;
+}
+
+void RoPETestLlama2::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
+    const auto& funcInputs = function->inputs();
+
+    const int position_id_start = 15;
+    auto& input_shape = targetInputStaticShapes[0];
+    auto seq_length = input_shape[1];
+
+    ov::test::utils::InputGenerateData in_data;
+    in_data.start_from = -1;
+    in_data.range = 2;
+    in_data.resolution = 32768;
+    ov::Tensor t_input = utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, in_data);
+    ov::Tensor t_position_id_end = create_i32_tensor(ov::Shape({}), position_id_start + seq_length);
+    ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), position_id_start);
+
+    inputs.clear();
+    inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
+    inputs.insert({funcInputs[1].get_node_shared_ptr(), t_position_id_end});
+    inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids});
+}
+
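For readers tracing the Llama2 subgraph above: it implements the classic rotate-half formulation, where the second half of the rotary channels is rotated against the first. A scalar reference of what the fused op computes for one position m, standalone C++ for illustration (toy sizes; the cos/sin table layout follows makeCosSinCache, which duplicates each value into both halves of the row):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int ndims = 8;   // rotary dimensions (illustrative; the tests use 128)
        const int m = 15;      // position id, matching position_id_start in the test
        std::vector<float> x(ndims), y(ndims);
        for (int i = 0; i < ndims; ++i) x[i] = 0.1f * (i + 1);
        for (int k = 0; k < ndims / 2; ++k) {
            const double theta = m / std::pow(10000.0, 2.0 * k / ndims);
            const float c = std::cos(theta), s = std::sin(theta);
            // rotate_half: y1 = cos*x1 - sin*x2, y2 = cos*x2 + sin*x1
            y[k]             = c * x[k]             - s * x[k + ndims / 2];
            y[k + ndims / 2] = c * x[k + ndims / 2] + s * x[k];
        }
        for (int i = 0; i < ndims; ++i) std::printf("%f\n", y[i]);
        return 0;
    }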
+void RoPETestLlama2::SetUp() {
+    targetDevice = this->GetParam();
+
+    const int batch = 2;
+    const int seq_length = 7;
+    const size_t max_position_embeddings = 2048;
+    const size_t ndims = 128;
+    const size_t num_head = 32;
+
+    InputShape inpShape = {{batch, seq_length, num_head, ndims}, {{batch, seq_length, num_head, ndims}}};
+    init_input_shapes({inpShape});
+    function = buildROPE_Llama2(batch, seq_length, max_position_embeddings, num_head, ndims);
+}
+
+std::string RoPETestLlama2::getTestCaseName(const testing::TestParamInfo<std::string>& obj) {
+    std::string targetDevice = obj.param;
+    std::ostringstream result;
+    result << "targetDevice=" << targetDevice;
+    return result.str();
+}
+
+std::shared_ptr<ov::Model> RoPETestChatGLM::buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims) {
+    auto input =
+        std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, batch, 4096 + 256 + 256});
+    auto cos_sin_cache = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{32768, 32, 2});
+    auto position_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i32, PartialShape{-1, -1});
+
+    auto __module_transformer_index_67_Gather =
+        makeOP<opset8::Gather>({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}});
+    auto __module_transformer_transpose_Transpose =
+        makeOP<opset1::Transpose>({__module_transformer_index_67_Gather, {1, 0, 2, 3}});
+    auto size_ShapeOf_110 =
+        makeOP<opset3::ShapeOf>({__module_transformer_transpose_Transpose}, {{"output_type", "i32"}});
+    auto __getitem___Gather = makeOP<opset8::Gather>({size_ShapeOf_110, -2, 0}, {{"batch_dims", 0}});
+    auto mul_Multiply = makeOP<opset1::Multiply>({__getitem___Gather, 2}, {{"auto_broadcast", "numpy"}});
+    auto slice_Unsqueeze_112 = makeOP<opset1::Unsqueeze>({mul_Multiply, 0});
+
+    auto floordiv_Divide =
+        makeOP<opset1::Divide>({mul_Multiply, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
+    auto floordiv_Floor = makeOP<opset1::Floor>({floordiv_Divide});
+    auto ListConstruct_126_Reshape_2 = makeOP<opset1::Reshape>({floordiv_Floor, {-1}}, {{"special_zero", false}});
+
+    auto ListUnpack_321 = makeOP<opset1::VariadicSplit>({input, -1, {4096, 256, 256}});
+    auto view_Reshape =
+        makeOP<opset1::Reshape>({ListUnpack_321->output(0), {0, 0, 32, 128}}, {{"special_zero", true}});
+
+    auto ScatterUpdate_229053 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}});
+    auto slice_Slice_357 =
+        makeOP<opset1::StridedSlice>({view_Reshape, {0, 0, 0, 0}, ScatterUpdate_229053, {1, 1, 1, 1}},
+                                     {{"begin_mask", {1, 1, 1, 0}},
+                                      {"end_mask", {1, 1, 1, 0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto size_ShapeOf_346 = makeOP<opset3::ShapeOf>({view_Reshape}, {{"output_type", "i32"}});
+    auto size_Gather_348 = makeOP<opset8::Gather>({size_ShapeOf_346, 0, 0}, {{"batch_dims", 0}});
+    auto ListConstruct_372_Reshape = makeOP<opset1::Reshape>({size_Gather_348, {-1}}, {{"special_zero", false}});
+    auto size_Gather_351 = makeOP<opset8::Gather>({size_ShapeOf_346, {2}, 0}, {{"batch_dims", 0}});
+    auto ListConstruct_372_Concat =
+        makeOP<opset1::Concat>({ListConstruct_372_Reshape, {-1}, size_Gather_351, ListConstruct_126_Reshape_2, {2}},
+                               {{"axis", 0}});
+    auto reshape_Reshape_373 =
+        makeOP<opset1::Reshape>({slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}});
+    auto select_Gather_381 = makeOP<opset8::Gather>({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}});
+    auto slice_Unsqueeze_367 = makeOP<opset1::Unsqueeze>({size_Gather_348, 0});
+    auto slice_Slice_369 =
+        makeOP<opset1::StridedSlice>({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}},
+                                     {{"begin_mask", {0}},
+                                      {"end_mask", {0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto size_ShapeOf_374 = makeOP<opset3::ShapeOf>({reshape_Reshape_373}, {{"output_type", "i32"}});
+    auto size_Gather_376 = makeOP<opset8::Gather>({size_ShapeOf_374, {3}, 0}, {{"batch_dims", 0}});
+    auto ListConstruct_379_Concat =
+        makeOP<opset1::Concat>({ListConstruct_372_Reshape, {-1}, {1}, size_Gather_376, {2}}, {{"axis", 0}});
+    auto view_Reshape_380 =
+        makeOP<opset1::Reshape>({slice_Slice_369, ListConstruct_379_Concat}, {{"special_zero", false}});
+    auto select_Gather_382 = makeOP<opset8::Gather>({view_Reshape_380, 0, -1}, {{"batch_dims", 0}});
+    auto mul_Multiply_383 =
+        makeOP<opset1::Multiply>({select_Gather_381, select_Gather_382}, {{"auto_broadcast", "numpy"}});
+    auto select_Gather_384 = makeOP<opset8::Gather>({reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}});
+    auto select_Gather_385 = makeOP<opset8::Gather>({view_Reshape_380, 1, -1}, {{"batch_dims", 0}});
+    auto mul_Multiply_386 =
+        makeOP<opset1::Multiply>({select_Gather_384, select_Gather_385}, {{"auto_broadcast", "numpy"}});
+    auto sub_Subtract_389 =
+        makeOP<opset1::Subtract>({mul_Multiply_383, mul_Multiply_386}, {{"auto_broadcast", "numpy"}});
+    auto Unsqueeze_62716 = makeOP<opset1::Unsqueeze>({sub_Subtract_389, -1});
+    auto mul_Multiply_391 =
+        makeOP<opset1::Multiply>({select_Gather_384, select_Gather_382}, {{"auto_broadcast", "numpy"}});
+    auto mul_Multiply_393 =
+        makeOP<opset1::Multiply>({select_Gather_381, select_Gather_385}, {{"auto_broadcast", "numpy"}});
+    auto add_Add_396 = makeOP<opset1::Add>({mul_Multiply_391, mul_Multiply_393}, {{"auto_broadcast", "numpy"}});
+    auto Unsqueeze_62717 = makeOP<opset1::Unsqueeze>({add_Add_396, -1});
+    auto stack_401 = makeOP<opset1::Concat>({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}});
+    auto flatten_ShapeOf_402 = makeOP<opset3::ShapeOf>({stack_401}, {{"output_type", "i32"}});
+    auto flatten_Slice_417 = makeOP<opset1::StridedSlice>({flatten_ShapeOf_402, {0}, {3}, {1}},
+                                                          {{"begin_mask", {0}},
+                                                           {"end_mask", {0}},
+                                                           {"new_axis_mask", {}},
+                                                           {"shrink_axis_mask", {}},
+                                                           {"ellipsis_mask", {}}});
+    auto flatten_Concat_420 = makeOP<opset1::Concat>({flatten_Slice_417, {-1}}, {{"axis", 0}});
+    auto flatten_Reshape_421 = makeOP<opset1::Reshape>({stack_401, flatten_Concat_420}, {{"special_zero", true}});
+    auto ScatterUpdate_229067 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}});
+    auto slice_Slice_363 =
+        makeOP<opset1::StridedSlice>({view_Reshape, ScatterUpdate_229067, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
+                                     {{"begin_mask", {1, 1, 1, 0}},
+                                      {"end_mask", {1, 1, 1, 0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto cat_Concat_425 = makeOP<opset1::Concat>({flatten_Reshape_421, slice_Slice_363}, {{"axis", -1}});
+    return std::make_shared<ov::Model>(ov::NodeVector{cat_Concat_425},
+                                       ov::ParameterVector{input, cos_sin_cache, position_ids});
+}
+
+ov::Tensor RoPETestChatGLM::create_i32_tensor(const ov::Shape& shape, int start, int step) {
+    auto tensor = ov::Tensor(ov::element::i32, shape);
+    auto* ptr = static_cast<int32_t*>(tensor.data());
+    for (size_t i = 0; i < tensor.get_size(); i++) {
+        ptr[i] = start;
+        start += step;
+    }
+    return tensor;
+}
+
+void RoPETestChatGLM::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
+    const auto& funcInputs = function->inputs();
+
+    auto& input_shape = targetInputStaticShapes[0];
+    auto seq_length = input_shape[0];
+
+    ov::Tensor t_input =
+        utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768);
+    ov::Tensor t_cos_sin_cache =
+        utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {32768, 32, 2}, 2, -1.0f, 32768);
+    ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), 15);
+
+    inputs.clear();
+    inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
+    inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache});
+    inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids});
+}
+
+void RoPETestChatGLM::SetUp() {
+    targetDevice = this->GetParam();
+
+    const int batch = 2;
+    const int seq_length = 7;
+    const int num_head = 32;
+    const int rotary_dims = 64;
+
+    InputShape inpShape = {{-1, batch, 4096 + 256 + 256}, {{seq_length, batch, 4096 + 256 + 256}}};
+    init_input_shapes({inpShape});
+    function = buildROPE_ChatGLM(batch, num_head, rotary_dims);
+}
+
+std::string RoPETestChatGLM::getTestCaseName(const testing::TestParamInfo<std::string>& obj) {
+    std::string targetDevice = obj.param;
+    std::ostringstream result;
+    result << "targetDevice=" << targetDevice;
+    return result.str();
+}
+
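Unlike the rotate-half layout used by Llama2, the ChatGLM subgraph above rotates interleaved (even, odd) channel pairs against an interleaved cos/sin cache and passes the tail of each head through untouched. A scalar reference of that per-pair update, standalone C++ with toy sizes (the angles are illustrative; in the test the cache is a model input):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int rotary_dims = 8;  // channels that get rotated (the tests use 64)
        std::vector<float> x(rotary_dims), y(rotary_dims);
        std::vector<float> cs(rotary_dims);  // interleaved cache row: {cos0, sin0, cos1, sin1, ...}
        for (int i = 0; i < rotary_dims; ++i) x[i] = 0.1f * (i + 1);
        for (int k = 0; k < rotary_dims / 2; ++k) {
            cs[2 * k]     = std::cos(0.3 * k);  // illustrative angle
            cs[2 * k + 1] = std::sin(0.3 * k);
        }
        for (int k = 0; k < rotary_dims / 2; ++k) {
            const float c = cs[2 * k], s = cs[2 * k + 1];
            y[2 * k]     = c * x[2 * k] - s * x[2 * k + 1];  // the Subtract branch of the graph
            y[2 * k + 1] = s * x[2 * k] + c * x[2 * k + 1];  // the Add branch of the graph
        }
        for (int i = 0; i < rotary_dims; ++i) std::printf("%f\n", y[i]);
        return 0;
    }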
+std::shared_ptr<ov::Model> RoPETestQwen7b::buildROPE_QWen7b(bool specialReshape) {
+    auto input =
+        std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, 4096 + 4096 + 4096});
+    auto cos_cache = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{1, -1, 1, 128});
+    auto sin_cache = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{1, -1, 1, 128});
+
+    auto ListUnpack_389_VariadicSplit = makeOP<opset1::VariadicSplit>({input, 2, {4096, 4096, -1}});
+    auto view_Reshape = makeOP<opset1::Reshape>({ListUnpack_389_VariadicSplit->output(0), {0, 0, 32, 128}},
+                                                {{"special_zero", true}});
+    auto size_ShapeOf_414 = makeOP<opset3::ShapeOf>({view_Reshape}, {{"output_type", "i32"}});
+    auto size_Gather_416 = makeOP<opset8::Gather>({size_ShapeOf_414, 1, 0}, {{"batch_dims", 0}});
+    auto neg_Multiply = makeOP<opset1::Multiply>({size_Gather_416, -1}, {{"auto_broadcast", "numpy"}});
+    auto slice_Unsqueeze_422 = makeOP<opset1::Unsqueeze>({neg_Multiply, 0});
+    auto ScatterUpdate_261437 = makeOP<opset3::ScatterUpdate>({{0, 0}, {1}, slice_Unsqueeze_422, {0}});
+    auto slice_Slice_425 = makeOP<opset1::StridedSlice>({cos_cache, ScatterUpdate_261437, {0ll, LLONG_MAX}, {1, 1}},
+                                                        {{"begin_mask", {1, 0}},
+                                                         {"end_mask", {1, 0}},
+                                                         {"new_axis_mask", {}},
+                                                         {"shrink_axis_mask", {}},
+                                                         {"ellipsis_mask", {}}});
+    auto slice_Slice_431 =
+        makeOP<opset1::StridedSlice>({slice_Slice_425, {0, 0, 0}, {0ll, 0ll, LLONG_MAX}, {1, 1, 1}},
+                                     {{"begin_mask", {1, 1, 0}},
+                                      {"end_mask", {1, 1, 0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto slice_Slice_437 =
+        makeOP<opset1::StridedSlice>({slice_Slice_431, {0, 0, 0, 0}, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}},
+                                     {{"begin_mask", {1, 1, 1, 0}},
+                                      {"end_mask", {1, 1, 1, 0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto size_ShapeOf_462 = makeOP<opset3::ShapeOf>({slice_Slice_437}, {{"output_type", "i32"}});
+    auto size_Gather_464 = makeOP<opset8::Gather>({size_ShapeOf_462, {3}, 0}, {{"batch_dims", 0}});
+    auto ScatterUpdate_261533 = makeOP<opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, size_Gather_464, {0}});
+    auto slice_Slice_470 =
+        makeOP<opset1::StridedSlice>({view_Reshape, {0, 0, 0, 0}, ScatterUpdate_261533, {1, 1, 1, 1}},
+                                     {{"begin_mask", {1, 1, 1, 0}},
+                                      {"end_mask", {1, 1, 1, 0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto mul_Multiply = makeOP<opset1::Multiply>({slice_Slice_470, slice_Slice_437}, {{"auto_broadcast", "numpy"}});
+    auto size_ShapeOf_478 = makeOP<opset3::ShapeOf>({slice_Slice_470}, {{"output_type", "i32"}});
+    auto Gather_239390 = makeOP<opset8::Gather>({size_ShapeOf_478, {0, 1, 2}, 0}, {{"batch_dims", 0}});
+    auto size_Gather_489 = makeOP<opset8::Gather>({size_ShapeOf_478, 3, 0}, {{"batch_dims", 0}});
+    auto floor_divide_Divide =
+        makeOP<opset1::Divide>({size_Gather_489, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
+    auto floor_divide_Floor = makeOP<opset1::Floor>({floor_divide_Divide});
+    auto ListConstruct_493_Reshape_3 =
+        makeOP<opset1::Reshape>({floor_divide_Floor, {-1}}, {{"special_zero", false}});
+    auto ListConstruct_493_Concat =
+        makeOP<opset1::Concat>({Gather_239390, {2}, ListConstruct_493_Reshape_3}, {{"axis", 0}});
+    std::shared_ptr<ov::Node> reshape_Reshape = nullptr;
+    if (specialReshape) {
+        reshape_Reshape = makeOP<opset1::Reshape>({slice_Slice_470, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
+    } else {
+        reshape_Reshape =
+            makeOP<opset1::Reshape>({slice_Slice_470, ListConstruct_493_Concat}, {{"special_zero", false}});
+    }
+    auto ListUnpack_496_Split = makeOP<opset1::Split>({reshape_Reshape, -2}, {{"num_splits", 2}});
+    auto ListUnpack_496_Squeeze_0 = makeOP<opset1::Squeeze>({ListUnpack_496_Split->output(1), -2});
+    auto Constant_296840_compressed = makeConst(element::f16, ov::Shape({1, 1, 1, 1}), {-1});
+    auto Constant_296840 = makeOP<opset1::Convert>({Constant_296840_compressed}, {{"destination_type", "f32"}});
+    auto neg_Multiply_499 =
+        makeOP<opset1::Multiply>({ListUnpack_496_Squeeze_0, Constant_296840}, {{"auto_broadcast", "numpy"}});
+    auto ListUnpack_496_Squeeze = makeOP<opset1::Squeeze>({ListUnpack_496_Split->output(0), -2});
+    auto cat_Concat = makeOP<opset1::Concat>({neg_Multiply_499, ListUnpack_496_Squeeze}, {{"axis", -1}});
+    auto slice_Slice_449 = makeOP<opset1::StridedSlice>({sin_cache, ScatterUpdate_261437, {0ll, LLONG_MAX}, {1, 1}},
+                                                        {{"begin_mask", {1, 0}},
+                                                         {"end_mask", {1, 0}},
+                                                         {"new_axis_mask", {}},
+                                                         {"shrink_axis_mask", {}},
+                                                         {"ellipsis_mask", {}}});
+    auto slice_Slice_455 =
+        makeOP<opset1::StridedSlice>({slice_Slice_449, {0, 0, 0}, {0ll, 0ll, LLONG_MAX}, {1, 1, 1}},
+                                     {{"begin_mask", {1, 1, 0}},
+                                      {"end_mask", {1, 1, 0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto slice_Slice_461 =
+        makeOP<opset1::StridedSlice>({slice_Slice_455, {0, 0, 0, 0}, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}},
+                                     {{"begin_mask", {1, 1, 1, 0}},
+                                      {"end_mask", {1, 1, 1, 0}},
+                                      {"new_axis_mask", {}},
+                                      {"shrink_axis_mask", {}},
+                                      {"ellipsis_mask", {}}});
+    auto mul_Multiply_503 = makeOP<opset1::Multiply>({cat_Concat, slice_Slice_461}, {{"auto_broadcast", "numpy"}});
+    auto add_Add = makeOP<opset1::Add>({mul_Multiply, mul_Multiply_503}, {{"auto_broadcast", "numpy"}});
+    return std::make_shared<ov::Model>(ov::NodeVector{add_Add}, ov::ParameterVector{input, cos_cache, sin_cache});
+}
+
+void RoPETestQwen7b::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
+    const auto& funcInputs = function->inputs();
+
+    auto& input_shape = targetInputStaticShapes[0];
+
+    ov::Tensor t_input =
+        utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768);
+    ov::Tensor t_cos_cache =
+        utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {1, 4096, 1, 128}, 2, -1.0f, 32768);
+    ov::Tensor t_sin_cache =
+        utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {1, 4096, 1, 128}, 2, -1.0f, 32768);
+
+    inputs.clear();
+    inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
+    inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_cache});
+    inputs.insert({funcInputs[2].get_node_shared_ptr(), t_sin_cache});
+}
+
+void RoPETestQwen7b::SetUp() {
+    bool specialReshape;
+    std::tie(specialReshape, targetDevice) = this->GetParam();
+    const int batch = 2;
+    const int seq_length = 7;
+    InputShape inpShape = {{batch, -1, 4096 + 4096 + 4096}, {{batch, seq_length, 4096 + 4096 + 4096}}};
+    init_input_shapes({inpShape});
+    function = buildROPE_QWen7b(specialReshape);
+}
+
+std::string RoPETestQwen7b::getTestCaseName(const testing::TestParamInfo<std::tuple<bool, std::string>>& obj) {
+    bool specialReshape;
+    std::string targetDevice;
+    std::tie(specialReshape, targetDevice) = obj.param;
+    std::ostringstream result;
+    result << "specialReshape=" << specialReshape << "_"
+           << "targetDevice=" << targetDevice;
+    return result.str();
+}
+
+std::shared_ptr<ov::Model> RoPETestGPTJ::buildROPE_GPTJ(int num_head,
+                                                        int hidden_dims,
+                                                        int rotary_dims,
+                                                        bool hasShapeOf) {
+    auto int32_max = std::numeric_limits<int32_t>::max();
+    auto input =
+        std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, num_head, hidden_dims});
+    auto sincos = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, rotary_dims});
+
{{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + slice_Slice_965->set_friendly_name("slice_Slice_965"); + + auto varsplit = makeOP({sincos, -1, {rotary_dims / 2, -1}}); + varsplit->set_output_size(2); + varsplit->set_friendly_name("varsplit"); + auto unsqueeze_sin = makeOP({varsplit->output(0), 2}); + auto unsqueeze_cos = makeOP({varsplit->output(1), 2}); + std::vector gather_idx(rotary_dims, 1); + int32_t v = 0; + for (size_t i = 0; i < gather_idx.size(); i += 2, v++) { + gather_idx[i] = v; + gather_idx[i + 1] = v; + } + + auto const_idx = makeConst(ov::element::i32, ov::Shape({static_cast(rotary_dims)}), gather_idx); + auto constant_155588 = makeConst(element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + }), + {-1.000000f}); + auto repeat_interleave_sin = makeOP({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}}); + auto repeat_interleave_cos = makeOP({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}}); + repeat_interleave_sin->set_friendly_name("repeat_interleave_sin"); + repeat_interleave_cos->set_friendly_name("repeat_interleave_cos"); + // x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2]) + auto slice_Slice_1174 = + makeOP({slice_Slice_965, {0, 0, 0, 1}, {0, 0, 0, int32_max}, {1, 1, 1, 2}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto neg_Multiply_1177 = + makeOP({slice_Slice_1174, constant_155588}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_65524 = makeOP({neg_Multiply_1177, -1}); + + auto slice_Slice_1168 = + makeOP({slice_Slice_965, {0, 0, 0, 0}, {0, 0, 0, int32_max}, {1, 1, 1, 2}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Unsqueeze_65525 = makeOP({slice_Slice_1168, -1}); + auto stack_1182 = makeOP({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}}); + auto flatten_Reshape_1198 = + makeOP({stack_1182, {0, 0, num_head, rotary_dims}}, {{"special_zero", true}}); + // x*cos [B,L,H,ndims] + auto mul_cos = + makeOP({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}}); + mul_cos->set_friendly_name("mul_cos"); + auto mul_sin = + makeOP({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}}); + // *cos + *sin + auto rotary_emb = makeOP({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}}); + + auto slice_Slice_971 = + makeOP({input, {0, 0, 0, rotary_dims}, {0, 0, 0, int32_max}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat_1211 = makeOP({rotary_emb, slice_Slice_971}, {{"axis", -1}}); + auto permute_Transpose_1213 = makeOP({cat_Concat_1211, {0, 2, 1, 3}}); + ov::NodeVector model_output = {permute_Transpose_1213}; + if (hasShapeOf) { + auto shapeOf = makeOP({rotary_emb}, {{"output_type", "i32"}}); + auto gather = makeOP({shapeOf, {1}, 0}, {{"batch_dims", 0}}); + model_output.push_back(gather); + } + return std::make_shared(model_output, ov::ParameterVector{input, sincos}); +} + +void RoPETestGPTJ::generate_inputs(const std::vector& targetInputStaticShapes) { + const auto& funcInputs = function->inputs(); + + auto& input_shape = targetInputStaticShapes[0]; + auto& sincos_shape = targetInputStaticShapes[1]; + ov::Tensor t_input = + utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768); + 
+    ov::Tensor t_cos_sin_cache =
+        utils::create_and_fill_tensor(funcInputs[1].get_element_type(), sincos_shape, 2, -1.0f, 32768);
+
+    inputs.clear();
+    inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
+    inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache});
+}
+
+std::string RoPETestGPTJ::getTestCaseName(const testing::TestParamInfo<std::tuple<bool, std::string>>& obj) {
+    bool hasShapeOf;
+    std::string targetDevice;
+    std::tie(hasShapeOf, targetDevice) = obj.param;
+    std::ostringstream result;
+    result << "hasShapeOf=" << hasShapeOf << "_"
+           << "targetDevice=" << targetDevice;
+    return result.str();
+}
+
+void RoPETestGPTJ::SetUp() {
+    bool hasShapeOf;
+    std::tie(hasShapeOf, targetDevice) = this->GetParam();
+
+    const int batch = 2;
+    const int seq_length = 7;
+    const int num_head = 16;
+    const int hidden_dims = 256;
+    const int rotary_dims = 64;
+
+    InputShape input = {{batch, seq_length, num_head, hidden_dims}, {{batch, seq_length, num_head, hidden_dims}}};
+    InputShape sincos = {{batch, seq_length, rotary_dims}, {{batch, seq_length, rotary_dims}}};
+    init_input_shapes({input, sincos});
+    function = buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf);
+}
+
+} // namespace test
+} // namespace ov
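Closing note on the ENABLE_SLICE path configured in rope_kernel_base.cpp: when the fused op consumes only the query (or key) slice of a packed QKV tensor, the host emits modified pitch constants so the kernel can read the slice in place instead of materializing it first. A toy recomputation of those constants, plain C++ for illustration (the dimension values are assumptions for a Qwen-style packed input; the real numbers come from rope_params and the input tensor):

    #include <cstdio>

    int main() {
        // Illustrative bfyx dims; the slice selects a sub-range of the Y axis of input0.
        const long b = 2, f = 7, y = 12288, x = 128;
        const long slice_start = 0, slice_stop = 4096;  // e.g. the query portion
        const long sliced_y = slice_stop - slice_start;
        // Mirrors the SLICED_INPUT0_* jit constants in RoPEKernelBase::GetJitConstants.
        const long x_pitch = 1;
        const long y_pitch = x;
        const long feature_pitch = x * sliced_y;
        const long batch_pitch = x * sliced_y * f;
        std::printf("x_pitch=%ld y_pitch=%ld feature_pitch=%ld batch_pitch=%ld from_start=%ld from_end=%ld\n",
                    x_pitch, y_pitch, feature_pitch, batch_pitch, slice_start, y - slice_stop);
        return 0;
    }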