forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GPU] Added RoPE support for ChatGLM and Qwen (openvinotoolkit#24756)
### Details: - Added RoPE support for ChatGLM and Qwen models - Moved and refactored RoPE functional tests ### Tickets: - *[119150](https://jira.devtools.intel.com/browse/CVS-119150)*
- Loading branch information
1 parent
ba8d6c5
commit df6a258
Showing
28 changed files
with
1,475 additions
and
631 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
621 changes: 0 additions & 621 deletions
621
src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
...ugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "subgraph_tests/rotary_pos_emb.hpp" | ||
|
||
namespace ov { | ||
namespace test { | ||
|
||
// Instantiations of the shared RoPE subgraph test suites for the CPU device.
// Each suite uses its own getTestCaseName as the test-name generator.

// Llama2 pattern: parameterized by target device only.
INSTANTIATE_TEST_SUITE_P(
    smoke_RoPETestLlama2,
    RoPETestLlama2,
    ::testing::Values(ov::test::utils::DEVICE_CPU),
    RoPETestLlama2::getTestCaseName);

// ChatGLM pattern: parameterized by target device only.
INSTANTIATE_TEST_SUITE_P(
    smoke_RoPETestChatGLM,
    RoPETestChatGLM,
    ::testing::Values(ov::test::utils::DEVICE_CPU),
    RoPETestChatGLM::getTestCaseName);

// Qwen7b pattern: every combination of a boolean flag and the target device.
// NOTE(review): the meaning of the boolean is defined by the shared test class - confirm there.
INSTANTIATE_TEST_SUITE_P(
    smoke_RoPETestQwen7b,
    RoPETestQwen7b,
    ::testing::Combine(
        ::testing::Values(true, false),
        ::testing::Values(ov::test::utils::DEVICE_CPU)),
    RoPETestQwen7b::getTestCaseName);

// GPTJ pattern: every combination of a boolean flag and the target device.
// NOTE(review): the meaning of the boolean is defined by the shared test class - confirm there.
INSTANTIATE_TEST_SUITE_P(
    smoke_RoPETestGPTJ,
    RoPETestGPTJ,
    ::testing::Combine(
        ::testing::Values(true, false),
        ::testing::Values(ov::test::utils::DEVICE_CPU)),
    RoPETestGPTJ::getTestCaseName);
} // namespace test | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
#include "primitive.hpp" | ||
#include "ov_ops/rotary_positional_embeddings.hpp" | ||
|
||
namespace cldnn { | ||
using RoPE = ov::op::internal::RoPE; | ||
|
||
/// @brief Rotary Position Embedding primitive
/// @details Thin wrapper that carries a RoPE::Config alongside the common
/// primitive data. Hashing, equality and (de)serialization cover every
/// config field listed below, always in the same order.
struct rope : public primitive_base<rope> {
    CLDNN_DECLARE_PRIMITIVE(rope);

    /// @brief Default constructor: empty id, no inputs, default-initialized config.
    rope() : primitive_base("", {}) {}

    /// @brief Constructs rope primitive
    /// @param id This primitive id
    /// @param inputs Inputs primitive ids
    /// @param config Specific RoPE config
    /// @param output_padding Padding applied to the primitive output
    rope(const primitive_id& id,
         const std::vector<input_info>& inputs,
         const RoPE::Config& config,
         const padding& output_padding = padding())
        : primitive_base(id, inputs, {output_padding}),
          config(config) {}

    /// @brief RoPE configuration copied from the constructor argument
    RoPE::Config config;

    /// @brief Combines the base primitive hash with every config field.
    /// @return Hash seed after folding in all config fields.
    size_t hash() const override {
        size_t seed = primitive::hash();
        seed = hash_combine(seed, config.gather_position_arg_id);
        seed = hash_combine(seed, config.head_cnt);
        seed = hash_combine(seed, config.head_size);
        seed = hash_combine(seed, config.input_trans0213);
        seed = hash_combine(seed, config.is_chatglm);
        seed = hash_combine(seed, config.is_interleaved);
        seed = hash_combine(seed, config.is_qwen);
        seed = hash_combine(seed, config.rotary_ndims);
        seed = hash_combine(seed, config.slice_start);
        seed = hash_combine(seed, config.slice_stop);
        return seed;
    }

    /// @brief Compares common primitive parameters and then every config field.
    /// @param rhs Primitive to compare with
    /// @return true when the common parameters and all config fields match.
    bool operator==(const primitive& rhs) const override {
        if (!compare_common_params(rhs))
            return false;

        // Bind by reference: plain `auto` would deduce a value type and copy
        // the whole primitive just to read a handful of config fields.
        const auto& rhs_casted = downcast<const rope>(rhs);
        const auto& rhs_config = rhs_casted.config;

        return config.gather_position_arg_id == rhs_config.gather_position_arg_id &&
               config.head_cnt == rhs_config.head_cnt &&
               config.head_size == rhs_config.head_size &&
               config.input_trans0213 == rhs_config.input_trans0213 &&
               config.is_chatglm == rhs_config.is_chatglm &&
               config.is_interleaved == rhs_config.is_interleaved &&
               config.is_qwen == rhs_config.is_qwen &&
               config.rotary_ndims == rhs_config.rotary_ndims &&
               config.slice_start == rhs_config.slice_start &&
               config.slice_stop == rhs_config.slice_stop;
    }

    /// @brief Serializes the base primitive followed by every config field.
    /// @param ob Output buffer to write to
    /// @note The field order must stay in sync with load().
    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<rope>::save(ob);
        ob << config.gather_position_arg_id;
        ob << config.head_cnt;
        ob << config.head_size;
        ob << config.input_trans0213;
        ob << config.is_chatglm;
        ob << config.is_interleaved;
        ob << config.is_qwen;
        ob << config.rotary_ndims;
        ob << config.slice_start;
        ob << config.slice_stop;
    }

    /// @brief Deserializes the base primitive followed by every config field.
    /// @param ib Input buffer to read from
    /// @note The field order must stay in sync with save().
    void load(BinaryInputBuffer& ib) override {
        primitive_base<rope>::load(ib);
        ib >> config.gather_position_arg_id;
        ib >> config.head_cnt;
        ib >> config.head_size;
        ib >> config.input_trans0213;
        ib >> config.is_chatglm;
        ib >> config.is_interleaved;
        ib >> config.is_qwen;
        ib >> config.rotary_ndims;
        ib >> config.slice_start;
        ib >> config.slice_stop;
    }
};
} // namespace cldnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "primitive_base.hpp" | ||
|
||
#include "rope_inst.h" | ||
#include "rope/rope_kernel_selector.h" | ||
#include "rope/rope_kernel_ref.h" | ||
|
||
namespace cldnn { | ||
namespace ocl { | ||
|
||
/// @brief OCL implementation of the rope primitive.
/// @details Translates the primitive's RoPE config into kernel-selector
/// parameters and refreshes dispatch data when update_dispatch_data() is called.
struct rope_impl : typed_primitive_impl_ocl<rope> {
    using parent = typed_primitive_impl_ocl<rope>;
    using parent::parent;
    using kernel_selector_t = kernel_selector::rope_kernel_selector;
    using kernel_params_t = kernel_selector::rope_params;

    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::rope_impl);

    /// @brief Returns a copy of this implementation.
    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<rope_impl>(*this);
    }

    /// @brief Restores the implementation from a binary buffer.
    /// @details For dynamic implementations, the kernel implementation is
    /// looked up by the stored kernel name and asked to set up the
    /// dispatch-data update function on _kernel_data.
    void load(BinaryInputBuffer& ib) override {
        parent::load(ib);
        if (!is_dynamic())
            return;

        auto& selector = kernel_selector_t::Instance();
        auto impl = selector.GetImplementation(_kernel_data.kernelName);
        impl->GetUpdateDispatchDataFunc(_kernel_data);
    }

    /// @brief Builds kernel parameters from the primitive description.
    /// @param impl_param Implementation parameters holding the rope descriptor and input layouts
    /// @param is_shape_agnostic Forwarded to get_default_params
    /// @return Kernel parameters filled from the RoPE config and the extra inputs.
    static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
        const auto& desc = impl_param.typed_desc<rope>();
        const auto& cfg = desc->config;
        auto params = get_default_params<kernel_selector::rope_params>(impl_param, is_shape_agnostic);

        params.head_cnt = cfg.head_cnt;
        params.head_size = cfg.head_size;
        params.rotary_ndims = cfg.rotary_ndims;

        params.slice_start = cfg.slice_start;
        params.slice_stop = cfg.slice_stop;

        // Qwen and ChatGLM use axis 2, every other variant uses axis 3.
        const bool qwen_or_chatglm = cfg.is_qwen || cfg.is_chatglm;
        params.axis = qwen_or_chatglm ? 2 : 3;

        // ChatGLM and interleaved variants take 2 inputs, the others take 3.
        const bool two_inputs = cfg.is_chatglm || cfg.is_interleaved;
        params.num_of_inputs = two_inputs ? 2 : 3;

        params.is_qwen = cfg.is_qwen;
        params.is_chatglm = cfg.is_chatglm;

        // Append inputs starting from index 1.
        // NOTE(review): presumably input 0 is already added by get_default_params - confirm.
        for (size_t idx = 1; idx < impl_param.input_layouts.size(); ++idx) {
            params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(idx)));
        }
        return params;
    }

    /// @brief Rebuilds shape-agnostic kernel parameters and passes them to the
    /// stored dispatch-data update function.
    void update_dispatch_data(const kernel_impl_params& impl_param) override {
        auto shape_agnostic_params = get_kernel_params(impl_param, true);
        (_kernel_data.update_dispatch_data_func)(shape_agnostic_params, _kernel_data);
    }
};
|
||
namespace detail { | ||
|
||
// Registers rope_impl in the implementation map as the OCL implementation
// for any shape type, for f32/f16 data in bfyx format.
attach_rope_impl::attach_rope_impl() {
    auto supported_types = {
        data_types::f32,
        data_types::f16
    };

    auto supported_formats = {
        format::bfyx
    };

    implementation_map<rope>::add(
        impl_types::ocl,
        shape_types::any,
        typed_primitive_impl_ocl<rope>::create<rope_impl>,
        supported_types,
        supported_formats);
}
|
||
} // namespace detail | ||
} // namespace ocl | ||
} // namespace cldnn | ||
|
||
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::rope_impl) | ||
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::rope) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "intel_gpu/primitives/rope.hpp" | ||
#include "primitive_inst.h" | ||
|
||
#include <string> | ||
|
||
namespace cldnn { | ||
template <> | ||
struct typed_program_node<rope> : public typed_program_node_base<rope> { | ||
using parent = typed_program_node_base<rope>; | ||
|
||
public: | ||
using parent::parent; | ||
|
||
program_node& input(size_t idx = 0) const { return get_dependency(idx); } | ||
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; } | ||
}; | ||
|
||
using rope_node = typed_program_node<rope>; | ||
|
||
template <> | ||
class typed_primitive_inst<rope> : public typed_primitive_inst_base<rope> { | ||
using parent = typed_primitive_inst_base<rope>; | ||
using parent::parent; | ||
|
||
public: | ||
template<typename ShapeType> | ||
static std::vector<layout> calc_output_layouts(const rope_node& /*node*/, const kernel_impl_params& impl_param); | ||
static layout calc_output_layout(rope_node const& node, kernel_impl_params const& impl_param); | ||
static std::string to_string(rope_node const& node); | ||
}; | ||
|
||
using rope_inst = typed_primitive_inst<rope>; | ||
} // namespace cldnn |
Oops, something went wrong.