[GPU] Added RoPE support for ChatGLM and Qwen (openvinotoolkit#24756)
### Details:
- Added RoPE support for the ChatGLM and Qwen models
- Moved and refactored the RoPE functional tests

### Tickets:
 - *[119150](https://jira.devtools.intel.com/browse/CVS-119150)*
Lyamin-Roman authored Jun 3, 2024
1 parent ba8d6c5 commit df6a258
Showing 28 changed files with 1,475 additions and 631 deletions.
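For context before the diffs: RoPE (rotary position embedding) rotates pairs of channels inside each attention head by a position-dependent angle, and this commit maps that computation onto a single GPU primitive. A minimal sketch of the standard rotate-half formulation, illustrative only; the ChatGLM and Qwen variants handled below layer interleaved and fused-QKV input layouts on top of this core idea:

```cpp
// Illustrative only: rotate-half RoPE applied to a single head vector.
// cos_tab/sin_tab hold one precomputed value per rotated channel pair.
#include <cstddef>
#include <vector>

std::vector<float> apply_rope(const std::vector<float>& x,
                              const std::vector<float>& cos_tab,
                              const std::vector<float>& sin_tab,
                              size_t rotary_ndims) {
    std::vector<float> y(x);  // channels beyond rotary_ndims pass through unchanged
    const size_t half = rotary_ndims / 2;
    for (size_t i = 0; i < half; ++i) {
        // channel i is paired with channel i + half and rotated by the angle
        // encoded in (cos_tab[i], sin_tab[i]) for the current position
        y[i]        = x[i] * cos_tab[i] - x[i + half] * sin_tab[i];
        y[i + half] = x[i + half] * cos_tab[i] + x[i] * sin_tab[i];
    }
    return y;
}
```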
@@ -1212,7 +1212,8 @@ class PatternValidator {
             return false;
         }

-        if (ele_type == ov::element::i32 || ele_type == ov::element::f32 || ele_type == ov::element::i64) {
+        if (ele_type == ov::element::i32 || ele_type == ov::element::i64 || ele_type == ov::element::f16 ||
+            ele_type == ov::element::f32) {
             auto observed = constop->cast_vector<double>();
             for (size_t i = 0; i < symbols.size(); i++)
                 detail::add_symbol_observed(sov, symbols[i], observed[i]);
@@ -1259,6 +1260,15 @@ class PatternValidator {
             }
         }

+        if (pconst_node->get_output_element_type(0).is_real() &&
+            vconst_node->get_output_element_type(0).is_real()) {
+            auto p_values = pconst_node->cast_vector<float>();
+            auto v_values = vconst_node->cast_vector<float>();
+            if (p_values == v_values) {
+                continue;
+            }
+        }
+
         _VERBOSE_LOG("expecting Constant of type ",
                      pconst_node->get_output_element_type(0),
                      " but got ",
@@ -9,6 +9,7 @@

 #include "itt.hpp"
 #include "openvino/core/rt_info.hpp"
+#include "openvino/op/util/shape_of_base.hpp"
 #include "openvino/opsets/opset1.hpp"
 #include "openvino/opsets/opset6.hpp"
 #include "openvino/opsets/opset8.hpp"
@@ -415,9 +416,9 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
 ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {
     MATCHER_SCOPE(RoPEFusionChatGLM);

-    auto qkv_linear = makePattern("f32[?,?,?]");  // f32[seq_length, batch_size, 4608]
+    auto qkv_linear = makePattern("[?,?,?]");  // [seq_length, batch_size, 4608]
     auto seq_length = makePattern("i32[1]");
-    auto cos_sin_cache = makePattern("f32[?,?,?,?]");  // [max_pos_embeddings, batch_size, 32, 2]
+    auto cos_sin_cache = makePattern("[?,?,?,?]");  // [max_pos_embeddings, batch_size, 32, 2]

     auto ndims = ov::gen_pattern::Symbol("ndims");
     auto head_cnt = ov::gen_pattern::Symbol("head_cnt");
@@ -538,9 +539,9 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     MATCHER_SCOPE(RoPEFusionQwen);

     // rotary_emb_cos & rotary_emb_sin are sliced by present kv-length (past-kv-length + cur_len)
-    auto rotary_emb_cos = makePattern("f32[1,?,1,?]");  // [1,..4096,1,128]
-    auto rotary_emb_sin = makePattern("f32[1,?,1,?]");  // [1,..4096,1,128]
-    auto qkv_proj = makePattern("f32[?,?,?]");  // f32[?,?,12288]
+    auto rotary_emb_cos = makePattern("[1,?,1,?]");  // [1,..4096,1,128]
+    auto rotary_emb_sin = makePattern("[1,?,1,?]");  // [1,..4096,1,128]
+    auto qkv_proj = makePattern("[?,?,?]");  // [?,?,12288]

     auto head_cnt = ov::gen_pattern::Symbol("head_cnt");
     auto head_size = ov::gen_pattern::Symbol("head_size");
@@ -559,8 +560,8 @@
     auto Multiply_567524 = makePattern<opset1::Multiply>({ShapeOf_485735, {-1}}, {{"auto_broadcast", "numpy"}});
     auto Gather_377635 = makePattern<opset8::Gather>({Multiply_567524, {1}, 0}, {{"batch_dims", 0}});

-    auto input_ids = makePattern("i32[?,?]");  // [batch, length]
-    auto ShapeOf_409241 = makePattern<opset1::ShapeOf>({input_ids}, {});
+    auto input_ids = makePattern();  // [batch, length]
+    auto ShapeOf_409241 = makePattern<ov::op::util::ShapeOfBase>({input_ids}, {});
     auto Gather_311651 = makePattern<opset8::Gather>({ShapeOf_409241, {1}, 0}, {{"batch_dims", 0}});
     auto neg_Multiply = makePattern<opset1::Multiply>({Gather_311651, {-1}}, {{"auto_broadcast", "numpy"}});
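Swapping opset1::ShapeOf for ov::op::util::ShapeOfBase in the pattern lets the matcher accept either ShapeOf version, since both v0 and v3 derive from that base. A minimal illustration of the idea; the free function is ours:

```cpp
#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/util/shape_of_base.hpp"

// Accepts any ShapeOf version: v0 (opset1) and v3 (opset3) both derive
// from ov::op::util::ShapeOfBase, so one base-class check covers both.
bool is_any_shape_of(const std::shared_ptr<ov::Node>& node) {
    return ov::as_type_ptr<ov::op::util::ShapeOfBase>(node) != nullptr;
}
```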


This file was deleted.

@@ -371,6 +371,8 @@ std::vector<std::string> disabledTestPatterns() {
     retVector.emplace_back(R"(smoke_VariableState/OVInferRequestVariableStateTest.*)");
     // Issue: 141705
     retVector.emplace_back(R"(.*smoke_arm_Deconv_2D_Planar_FP16/DeconvolutionLayerCPUTest.*INFERENCE_PRECISION_HINT=f16.*)");
+
+    retVector.emplace_back(R"(.*smoke_RoPETest.*)");
 #endif

 #if defined(OPENVINO_ARCH_ARM)
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "subgraph_tests/rotary_pos_emb.hpp"

namespace ov {
namespace test {

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2,
                         RoPETestLlama2,
                         ::testing::Values(ov::test::utils::DEVICE_CPU),
                         RoPETestLlama2::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
                         RoPETestChatGLM,
                         ::testing::Values(ov::test::utils::DEVICE_CPU),
                         RoPETestChatGLM::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b,
                         RoPETestQwen7b,
                         ::testing::Combine(::testing::Values(true, false),
                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         RoPETestQwen7b::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJ,
                         RoPETestGPTJ,
                         ::testing::Combine(::testing::Values(true, false),
                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         RoPETestGPTJ::getTestCaseName);
} // namespace test
} // namespace ov
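The file above instantiates the shared subgraph suites on CPU; since this commit moves the suites to a common location, a GPU-side instantiation would presumably mirror them. A hypothetical sketch, not part of this diff excerpt:

```cpp
#include "subgraph_tests/rotary_pos_emb.hpp"

namespace ov {
namespace test {

// Hypothetical GPU instantiation mirroring the CPU one above.
INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
                         RoPETestChatGLM,
                         ::testing::Values(ov::test::utils::DEVICE_GPU),
                         RoPETestChatGLM::getTestCaseName);

}  // namespace test
}  // namespace ov
```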
@@ -286,3 +286,4 @@ REGISTER_FACTORY(internal, Convolution);
 REGISTER_FACTORY(internal, Placeholder);
 REGISTER_FACTORY(internal, SDPA);
 REGISTER_FACTORY(internal, IndirectSDPA);
+REGISTER_FACTORY(internal, RoPE);
92 changes: 92 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp
@@ -0,0 +1,92 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "primitive.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"

namespace cldnn {
using RoPE = ov::op::internal::RoPE;

/// @brief Rotary Position Embedding primitive
struct rope : public primitive_base<rope> {
    CLDNN_DECLARE_PRIMITIVE(rope);

    rope() : primitive_base("", {}) {}

    /// @brief Constructs rope primitive
    /// @param id This primitive id
    /// @param inputs Input primitive ids
    /// @param config Specific RoPE config
    rope(const primitive_id& id,
         const std::vector<input_info>& inputs,
         const RoPE::Config& config,
         const padding& output_padding = padding())
        : primitive_base(id, inputs, {output_padding}),
          config(config) {}

    RoPE::Config config;

    size_t hash() const override {
        size_t seed = primitive::hash();
        seed = hash_combine(seed, config.gather_position_arg_id);
        seed = hash_combine(seed, config.head_cnt);
        seed = hash_combine(seed, config.head_size);
        seed = hash_combine(seed, config.input_trans0213);
        seed = hash_combine(seed, config.is_chatglm);
        seed = hash_combine(seed, config.is_interleaved);
        seed = hash_combine(seed, config.is_qwen);
        seed = hash_combine(seed, config.rotary_ndims);
        seed = hash_combine(seed, config.slice_start);
        seed = hash_combine(seed, config.slice_stop);
        return seed;
    }

    bool operator==(const primitive& rhs) const override {
        if (!compare_common_params(rhs))
            return false;

        auto rhs_casted = downcast<const rope>(rhs);

        return config.gather_position_arg_id == rhs_casted.config.gather_position_arg_id &&
               config.head_cnt == rhs_casted.config.head_cnt &&
               config.head_size == rhs_casted.config.head_size &&
               config.input_trans0213 == rhs_casted.config.input_trans0213 &&
               config.is_chatglm == rhs_casted.config.is_chatglm &&
               config.is_interleaved == rhs_casted.config.is_interleaved &&
               config.is_qwen == rhs_casted.config.is_qwen &&
               config.rotary_ndims == rhs_casted.config.rotary_ndims &&
               config.slice_start == rhs_casted.config.slice_start &&
               config.slice_stop == rhs_casted.config.slice_stop;
    }

    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<rope>::save(ob);
        ob << config.gather_position_arg_id;
        ob << config.head_cnt;
        ob << config.head_size;
        ob << config.input_trans0213;
        ob << config.is_chatglm;
        ob << config.is_interleaved;
        ob << config.is_qwen;
        ob << config.rotary_ndims;
        ob << config.slice_start;
        ob << config.slice_stop;
    }

    void load(BinaryInputBuffer& ib) override {
        primitive_base<rope>::load(ib);
        ib >> config.gather_position_arg_id;
        ib >> config.head_cnt;
        ib >> config.head_size;
        ib >> config.input_trans0213;
        ib >> config.is_chatglm;
        ib >> config.is_interleaved;
        ib >> config.is_qwen;
        ib >> config.rotary_ndims;
        ib >> config.slice_start;
        ib >> config.slice_stop;
    }
};
}  // namespace cldnn
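For reference, a minimal sketch of how this primitive might be created and added to a cldnn topology; the input names and config values are illustrative, not taken from the commit:

```cpp
#include "intel_gpu/graph/topology.hpp"
#include "intel_gpu/primitives/rope.hpp"

// Illustrative only: configure a ChatGLM-style RoPE over two hypothetical
// inputs (fused QKV projection and a combined cos/sin cache).
void add_rope_example(cldnn::topology& topology) {
    ov::op::internal::RoPE::Config config;
    config.is_chatglm = true;     // ChatGLM input layout
    config.head_cnt = 32;         // hypothetical head count
    config.head_size = 128;       // hypothetical head size
    config.rotary_ndims = 64;     // number of channels actually rotated

    topology.add(cldnn::rope("rope",
                             {cldnn::input_info("qkv_proj"), cldnn::input_info("cos_sin_cache")},
                             config));
}
```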
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp
@@ -94,6 +94,7 @@ void register_implementations() {
     REGISTER_OCL(unique_count);
     REGISTER_OCL(unique_gather);
     REGISTER_OCL(scaled_dot_product_attention);
+    REGISTER_OCL(rope);
 }

} // namespace ocl
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp
@@ -75,6 +75,7 @@
 #include "intel_gpu/primitives/unique.hpp"
 #include "intel_gpu/primitives/kv_cache.hpp"
 #include "intel_gpu/primitives/scaled_dot_product_attention.hpp"
+#include "intel_gpu/primitives/rope.hpp"

namespace cldnn {
namespace ocl {
@@ -174,6 +175,7 @@ REGISTER_OCL(eye);
 REGISTER_OCL(unique_count);
 REGISTER_OCL(unique_gather);
 REGISTER_OCL(scaled_dot_product_attention);
+REGISTER_OCL(rope);

#undef REGISTER_OCL

88 changes: 88 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp
@@ -0,0 +1,88 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "primitive_base.hpp"

#include "rope_inst.h"
#include "rope/rope_kernel_selector.h"
#include "rope/rope_kernel_ref.h"

namespace cldnn {
namespace ocl {

struct rope_impl : typed_primitive_impl_ocl<rope> {
    using parent = typed_primitive_impl_ocl<rope>;
    using parent::parent;
    using kernel_selector_t = kernel_selector::rope_kernel_selector;
    using kernel_params_t = kernel_selector::rope_params;

    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::rope_impl);

    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<rope_impl>(*this);
    }

    void load(BinaryInputBuffer& ib) override {
        parent::load(ib);
        if (is_dynamic()) {
            auto& kernel_selector = kernel_selector_t::Instance();
            auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
            kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
        }
    }

    static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
        const auto& primitive = impl_param.typed_desc<rope>();
        auto params = get_default_params<kernel_selector::rope_params>(impl_param, is_shape_agnostic);

        params.head_cnt = primitive->config.head_cnt;
        params.head_size = primitive->config.head_size;
        params.rotary_ndims = primitive->config.rotary_ndims;

        params.slice_start = primitive->config.slice_start;
        params.slice_stop = primitive->config.slice_stop;

        // Qwen and ChatGLM layouts keep the rotated dimension at axis 2; other layouts at axis 3.
        params.axis = primitive->config.is_qwen || primitive->config.is_chatglm ? 2 : 3;
        // ChatGLM and interleaved (GPT-J style) variants consume a combined cos/sin cache,
        // so they carry 2 inputs; the remaining variants take separate cos and sin (3 inputs).
        params.num_of_inputs = primitive->config.is_chatglm || primitive->config.is_interleaved ? 2 : 3;

        params.is_qwen = primitive->config.is_qwen;
        params.is_chatglm = primitive->config.is_chatglm;

        for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) {
            params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(i)));
        }
        return params;
    }

    void update_dispatch_data(const kernel_impl_params& impl_param) override {
        auto kernel_params = get_kernel_params(impl_param, true);
        (_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data);
    }
};

namespace detail {

attach_rope_impl::attach_rope_impl() {
    auto types = {
        data_types::f32,
        data_types::f16
    };

    auto formats = {
        format::bfyx
    };

    implementation_map<rope>::add(impl_types::ocl,
                                  shape_types::any,
                                  typed_primitive_impl_ocl<rope>::create<rope_impl>,
                                  types,
                                  formats);
}

}  // namespace detail
}  // namespace ocl
}  // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::rope_impl)
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::rope)
39 changes: 39 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/rope_inst.h
@@ -0,0 +1,39 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "intel_gpu/primitives/rope.hpp"
#include "primitive_inst.h"

#include <string>

namespace cldnn {

template <>
struct typed_program_node<rope> : public typed_program_node_base<rope> {
    using parent = typed_program_node_base<rope>;

public:
    using parent::parent;

    program_node& input(size_t idx = 0) const { return get_dependency(idx); }
    std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
};

using rope_node = typed_program_node<rope>;

template <>
class typed_primitive_inst<rope> : public typed_primitive_inst_base<rope> {
    using parent = typed_primitive_inst_base<rope>;
    using parent::parent;

public:
    template <typename ShapeType>
    static std::vector<layout> calc_output_layouts(const rope_node& /*node*/, const kernel_impl_params& impl_param);
    static layout calc_output_layout(rope_node const& node, kernel_impl_params const& impl_param);
    static std::string to_string(rope_node const& node);
};

using rope_inst = typed_primitive_inst<rope>;
}  // namespace cldnn