diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp index aa75c202c112a1..a4e8b140415e24 100644 --- a/src/plugins/intel_cpu/src/cpu_types.cpp +++ b/src/plugins/intel_cpu/src/cpu_types.cpp @@ -214,6 +214,7 @@ static const TypeToNameMap& get_type_to_name_tbl() { { "Unique", Type::Unique}, { "Ngram", Type::Ngram}, { "ScaledDotProductAttention", Type::ScaledDotProductAttention}, + { "RoPE", Type::RoPE}, }; return type_to_name_tbl; } @@ -328,6 +329,7 @@ std::string NameFromType(const Type type) { CASE(Unique); CASE(Ngram); CASE(ScaledDotProductAttention); + CASE(RoPE); CASE(Unknown); } #undef CASE diff --git a/src/plugins/intel_cpu/src/cpu_types.h b/src/plugins/intel_cpu/src/cpu_types.h index f7f40d2c1fca0e..cf214542b1b604 100644 --- a/src/plugins/intel_cpu/src/cpu_types.h +++ b/src/plugins/intel_cpu/src/cpu_types.h @@ -114,6 +114,7 @@ enum class Type { Unique, Ngram, ScaledDotProductAttention, + RoPE, }; enum class Algorithm { diff --git a/src/plugins/intel_cpu/src/nodes/rope.cpp b/src/plugins/intel_cpu/src/nodes/rope.cpp new file mode 100644 index 00000000000000..5ec1aaa2183104 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/rope.cpp @@ -0,0 +1,201 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "rope.h" + +#include +#include +#include +#include +#include +#include + +#include "common/bfloat16.hpp" +#include "common/cpu_memcpy.h" +#include "utils/plain_tensor.hpp" + +using namespace InferenceEngine; + +namespace ov { +namespace intel_cpu { +namespace node { + +RoPE::RoPE(const std::shared_ptr& op, const GraphContext::CPtr context) + : Node(op, context, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { + std::string errorMessage; + if (!isSupportedOperation(op, errorMessage)) { + OPENVINO_THROW("CPU: " + errorMessage); + } + + const auto node = std::dynamic_pointer_cast(op); + m_config = node->get_config(); +} + +template +struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor { + void execute(dnnl::stream strm, + const RoPENode::Config& config, + const std::vector& inputs, + const std::vector& outputs) override { + ov::intel_cpu::PlainTensor t_src(inputs[0]); + ov::intel_cpu::PlainTensor t_cos(inputs[1]); + ov::intel_cpu::PlainTensor t_sin(inputs[2]); + ov::intel_cpu::PlainTensor t_dst(outputs[0]); + ov::intel_cpu::PlainTensor gather; + + if (config.slice_stop - config.slice_start > 0) { + t_src = t_src.slice(3, config.slice_start, config.slice_stop); + } + if (config.input_trans0213) { + t_src = t_src.permute({0, 2, 1, 3}); + } + if (config.gather_position_arg_id > 0) { + gather.reset(inputs[config.gather_position_arg_id]); + } + + auto batch_size = t_src.size(0); + auto head_cnt = t_src.size(1); + auto seq_len = t_src.size(2); + auto feature_size = t_src.size(3); + + auto rotary_dims = config.rotary_ndims; + auto half_rotary_dims = rotary_dims / 2; + + parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) { + auto cos_pos = p; + if (gather) { + if (gather.m_rank == 4) + cos_pos = gather.at({b, h, p, 0}, true); + else + cos_pos = gather.at({b, p}, true); + } + auto* src = &t_src.at({b, h, p, 0}); + auto* cos = &t_cos.at({b, h, cos_pos, 0}, true); + auto* sin = &t_sin.at({b, h, cos_pos, 0}, true); + auto* dst = &t_dst.at({b, h, p, 0}); + + size_t i = 0; + for (; i < half_rotary_dims; i++) { + dst[i] = cos[i] * src[i] + sin[i] * (-src[i + half_rotary_dims]); + } + for (; i < rotary_dims; i++) { + dst[i] = cos[i] * src[i] + sin[i] * (src[i - half_rotary_dims]); + 
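                // (second half of the rotate-half pairing: y[i] = x[i]*cos[i] + x[i - d/2]*sin[i],
                // i.e. the sign of the sin term flips relative to the first loop; the loop that
                // follows copies the non-rotary tail of the feature vector through unchanged)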
} + for (; i < feature_size; i++) { + dst[i] = src[i]; + } + }); + } +}; + +template +struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor { + void execute(dnnl::stream strm, + const RoPENode::Config& config, + const std::vector& inputs, + const std::vector& outputs) override { + ov::intel_cpu::PlainTensor t_src(inputs[0]); + ov::intel_cpu::PlainTensor t_sin_cos(inputs[1]); + ov::intel_cpu::PlainTensor t_dst(outputs[0]); + + auto batch_size = t_src.size(0); + auto seq_len = t_src.size(1); + auto head_cnt = t_src.size(2); + auto head_dims = t_src.size(3); + + auto rotary_dims = config.rotary_ndims; + auto half_rotary_dims = rotary_dims / 2; + parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) { + auto* x = &t_src.at({b, p, h, 0}); + float* sin = &t_sin_cos.at({b, p, 0}, true); + float* cos = &t_sin_cos.at({b, p, half_rotary_dims}, true); + auto* dst = &t_dst.at({b, h, p, 0}); + + size_t i = 0; + for (size_t j = 0; i < rotary_dims; i += 2, j++) { + dst[i] = cos[j] * x[i] - sin[j] * x[i + 1]; + dst[i + 1] = cos[j] * x[i + 1] + sin[j] * x[i]; + } + for (; i < head_dims; i++) { + dst[i] = x[i]; + } + }); + } +}; + +void RoPE::initSupportedPrimitiveDescriptors() { + if (!supportedPrimitiveDescriptors.empty()) + return; + auto srcPrecision = getOriginalInputPrecisionAtPort(0); + + auto rtPrecision = srcPrecision; + auto CosSinPrecision = ov::element::f32; + + if (m_config.is_interleaved) { + OPENVINO_ASSERT(m_config.input_trans0213 == false); + OPENVINO_ASSERT(m_config.slice_start == 0); + OPENVINO_ASSERT(m_config.slice_stop == 0); + OPENVINO_ASSERT(m_config.gather_position_arg_id == 0); + if (rtPrecision == ov::element::bf16) { + m_executor = std::make_shared>(); + } else { + m_executor = std::make_shared>(); + rtPrecision = ov::element::f32; + } + } else { + if (rtPrecision == ov::element::bf16) { + m_executor = std::make_shared>(); + } else { + m_executor = std::make_shared>(); + rtPrecision = ov::element::f32; + } + } + + // initialize input ports + std::vector inPortConfigs; + inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(0), false, -1); + inPortConfigs.emplace_back(LayoutType::ncsp, CosSinPrecision, getInputShapeAtPort(1), false, -1); + inPortConfigs.emplace_back(LayoutType::ncsp, CosSinPrecision, getInputShapeAtPort(2), false, -1); + if (m_config.gather_position_arg_id > 0) { + inPortConfigs.emplace_back(LayoutType::ncsp, + ov::element::i32, + getInputShapeAtPort(m_config.gather_position_arg_id), + false, + -1); + } + + // initialize output port + std::vector outPortConfigs; + outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1); + + addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any); +} + +void RoPE::execute(dnnl::stream strm) { + std::vector inputs(getParentEdges().size()), outputs(getChildEdges().size()); + for (size_t i = 0; i < inputs.size(); i++) { + inputs[i] = getParentEdgeAt(i)->getMemoryPtr(); + } + for (size_t i = 0; i < outputs.size(); i++) { + outputs[i] = getChildEdgeAt(i)->getMemoryPtr(); + } + m_executor->execute(strm, m_config, inputs, outputs); +} + +bool RoPE::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { + try { + const auto node = std::dynamic_pointer_cast(op); + if (!node) { + errorMessage = "Only RoPENode operation is supported"; + return false; + } + } catch (...) 
{ + return false; + } + return true; +} + +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/rope.h b/src/plugins/intel_cpu/src/nodes/rope.h new file mode 100644 index 00000000000000..c1b2bbda3b3c0f --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/rope.h @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include +#include + +#include +#include +#include + +#include "transformations/cpu_opset/common/op/rope.hpp" + +namespace ov { +namespace intel_cpu { +namespace node { + +class RoPE : public Node { +public: + RoPE(const std::shared_ptr& op, const GraphContext::CPtr context); + + void getSupportedDescriptors() override {} + bool created() const override { + return getType() == Type::RoPE; + } + bool needPrepareParams() const override { + return false; + }; + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + } + void initSupportedPrimitiveDescriptors() override; + void execute(dnnl::stream strm) override; + static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + +private: + struct Executor { + virtual void execute(dnnl::stream strm, + const RoPENode::Config& config, + const std::vector& inputs, + const std::vector& outputs) = 0; + }; + template + struct RoPEExecutorRotateHalf; + template + struct RoPEExecutorInterleaved; + RoPENode::Config m_config; + std::shared_ptr m_executor; +}; + +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes_factory.cpp b/src/plugins/intel_cpu/src/nodes_factory.cpp index 963218314faaa3..bead297d033e79 100644 --- a/src/plugins/intel_cpu/src/nodes_factory.cpp +++ b/src/plugins/intel_cpu/src/nodes_factory.cpp @@ -94,6 +94,7 @@ #include "nodes/unique.hpp" #include "nodes/ngram.h" #include "nodes/scaled_attn.h" +#include "nodes/rope.h" namespace ov { namespace intel_cpu { @@ -181,6 +182,7 @@ Node::NodesFactory::NodesFactory() INTEL_CPU_NODE(Eye, Type::Eye); INTEL_CPU_NODE(Unique, Type::Unique); INTEL_CPU_NODE(Ngram, Type::Ngram); + INTEL_CPU_NODE(RoPE, Type::RoPE); INTEL_CPU_NODE(Interpolate, Type::Interpolate); INTEL_CPU_NODE(RandomUniform, Type::RandomUniform); INTEL_CPU_NODE(Reduce, Type::Reduce); diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.cpp new file mode 100644 index 00000000000000..8b4461b479ee7b --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.cpp @@ -0,0 +1,50 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "rope.hpp" + +#include + +#include "transformations/itt.hpp" + +ov::intel_cpu::RoPENode::RoPENode(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) { + constructor_validate_and_infer_types(); +} + +std::shared_ptr ov::intel_cpu::RoPENode::clone_with_new_inputs( + const ngraph::OutputVector& new_args) const { + INTERNAL_OP_SCOPE(RoPENode_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args, m_config); +} + +void ov::intel_cpu::RoPENode::validate_and_infer_types() { + INTERNAL_OP_SCOPE(RoPENode_validate_and_infer_types); + auto input_pshape = get_input_partial_shape(0); + auto input_slice_size = m_config.slice_stop - m_config.slice_start; + if (input_slice_size > 0) { + input_pshape[3] = input_slice_size; + } + if (m_config.input_trans0213) { 
+        // transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens before RoPE
+        std::swap(input_pshape[2], input_pshape[1]);
+    } else if (m_config.is_interleaved) {
+        // transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens after RoPE
+        std::swap(input_pshape[2], input_pshape[1]);
+    }
+
+    set_output_type(0, get_input_element_type(0), input_pshape);
+}
+
+bool ov::intel_cpu::RoPENode::visit_attributes(ngraph::AttributeVisitor& visitor) {
+    INTERNAL_OP_SCOPE(RoPENode_visit_attributes);
+    visitor.start_structure("config");
+    visitor.on_attribute("slice_start", m_config.slice_start);
+    visitor.on_attribute("slice_stop", m_config.slice_stop);
+    visitor.on_attribute("input_trans0213", m_config.input_trans0213);
+    visitor.on_attribute("is_interleaved", m_config.is_interleaved);
+    visitor.on_attribute("rotary_ndims", m_config.rotary_ndims);
+    visitor.on_attribute("gather_position_arg_id", m_config.gather_position_arg_id);
+    visitor.finish_structure();
+    return true;
+}
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.hpp
new file mode 100644
index 00000000000000..cc6df7ec2b107f
--- /dev/null
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/rope.hpp
@@ -0,0 +1,98 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include
+#include
+
+namespace ov {
+namespace intel_cpu {
+
+/**
+ * The operation performs the rotary positional embedding described in:
+ * ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING by Jianlin Su
+ *
+ * The core computation is the application of a 2x2 rotation matrix to a pair
+ * of input states x[i0] & x[i1], producing the rotary-embedded pair of output
+ * states y[i0] and y[i1]:
+ *
+ * suppose the dimension of the hidden states (of each attention head) is N,
+ * d of which are to be embedded (d <= N); the non-embedded part is copied to the output.
+ *
+ *   for i in 0...(d/2)
+ *      if (is_interleaved) {
+ *          // interleaving style of indexing
+ *          i0 = i*2
+ *          i1 = i*2 + 1
+ *      } else {
+ *          // rotate-half style of indexing
+ *          i0 = i
+ *          i1 = i + (d/2)
+ *      }
+ *      y[i0] = x[i0]*cos(m * xita[i]) - x[i1]*sin(m * xita[i])
+ *      y[i1] = x[i1]*cos(m * xita[i]) + x[i0]*sin(m * xita[i])
+ *   Note: m is the token position of the current input
+ *
+ * Based on the configuration, additional preprocessing steps may be performed as well:
+ *   - slicing the last dimension of the input tensor
+ *     (when q/k/v are merged and only the q or k part is to be extracted & embedded)
+ *   - transposing the input tensor
+ *     (when q/k comes from a fully-connected layer with layout [batch, seq_len, head_cnt, head_dim]
+ *      but the output of RoPE is required to be of layout [batch, head_cnt, seq_length, head_dims])
+ *   - gathering sin/cos from input tensors 2 & 3 using the position index tensor passed through input 4
+ *
+ * Inputs:
+ * 1. Input hidden states tensor of type T1 - shape:
+ *      [batch, seq_length, head_cnt, head_dims] when input_trans0213 == false OR
+ *      [batch, head_cnt, seq_length, head_dims] when input_trans0213 == true
+ * 2. Pre-calculated cos(m*xita[n]) tensor of type T2 - shape [1, 1, max_position_embeddings, d].
+ * 3. Pre-calculated sin(m*xita[n]) tensor of type T2 - shape [1, 1, max_position_embeddings, d].
+ *    Input 3 is combined with input 2 when is_interleaved is true.
+ * 4. (optional) Position index tensor of type T3 - shape [batch, 1, seq_length, 1 or d] OR [batch, seq_length]
+ * Outputs:
+ * 1. New embedding tensor of type T1 and of shape [batch, head_cnt, seq_length, head_dims]
+ * Types:
+ *   T1 - FP32 or BF16
+ *   T2 - FP32
+ *   T3 - I32
+ */
+class RoPENode : public ngraph::op::Op {
+public:
+    OPENVINO_OP("RoPE", "cpu_plugin_opset");
+
+    RoPENode() = default;
+
+    struct Config {
+        size_t slice_start = 0;  // slice the innermost dimension of the input
+        size_t slice_stop = 0;
+        bool input_trans0213 = false;  // transpose input dims 1 & 2
+        bool is_interleaved = false;   // interleaved mode, implies trans0213 happens after RoPE
+        size_t rotary_ndims = 0;       // dimensions to be embedded (d in the description)
+        int gather_position_arg_id =
+            0;  // input index of the position tensor; == 3 when sin/cos must be gathered from inputs 2 & 3 by position
+    };
+
+    RoPENode(const OutputVector& args, const Config& cfg);
+
+    bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
+
+    void validate_and_infer_types() override;
+
+    std::shared_ptr clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
+
+    const Config& get_config() const {
+        return m_config;
+    }
+
+    Config& get_config() {
+        return m_config;
+    }
+
+private:
+    Config m_config;
+};
+
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp
new file mode 100644
index 00000000000000..fca105c07a475b
--- /dev/null
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp
@@ -0,0 +1,444 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "rope_fusion.hpp"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "itt.hpp"
+#include "ov_ops/type_relaxed.hpp"
+#include "transformations/cpu_opset/common/op/rope.hpp"
+#include "utils/gen_pattern.hpp"
+
+using namespace ov::gen_pattern;
+
+ov::intel_cpu::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() {
+    MATCHER_SCOPE(RoPEFusionGPTNEOX);
+
+    // RoPE pattern matching exposes a small design flaw of the matcher:
+    //     y1 = mul(x, cos)
+    //     y2 = mul(x, sin)
+    //     y  = add(y1, y2)
+    // there is a chance that in the 'y1' branch the pattern x is mapped to the actual
+    // value of cos (Multiply is commutative); this makes matching of the 'y2' branch
+    // fail, because cos does not appear in that branch.
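+    // (for example, Multiply(x, cos) can match with the pattern's x bound to the graph's
+    // cos input, since the two inputs of a commutative op may be matched in either order)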
+    // so here we use a workaround: only match the rotate_half(x)*sin path, and check the
+    // x*cos path in the callback
+    auto x = makePattern(ov::Rank(4));
+    auto x_or_cos1 = makePattern(ov::Rank(4));
+    auto x_or_cos2 = makePattern(ov::Rank(4));
+    auto t_sin = makePattern(ov::Rank(4));
+
+    x->set_friendly_name("x");
+
+    auto half_ndims = Symbol("half_ndims");
+    auto int32_max = std::numeric_limits::max();
+
+    // rotate half : [-x2, x1]
+    auto x2 = GenSlice(x, half_ndims, int32_max, 1, 3);
+    auto x2neg = makePattern({x2, -1.0f}, {{"auto_broadcast", "numpy"}});
+    auto x1 = GenSlice(x, 0, half_ndims, 1, 3);
+    auto x_rotate_half = makePattern({x2neg, x1}, {{"axis", -1}});
+
+    auto mul_cos = makePattern({x_or_cos1, x_or_cos2}, {{"auto_broadcast", "numpy"}});
+    auto mul_sin = makePattern({x_rotate_half, t_sin}, {{"auto_broadcast", "numpy"}});
+
+    // [x1, x2]*cos + [-x2, x1]*sin
+    auto result = makePattern({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
+
+    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
+        PatternValidator validator(m);
+        if (!validator) {
+            return false;
+        }
+
+        const auto& pattern_map = m.get_pattern_value_map();
+        auto root = m.get_match_root();
+
+        // check that mul(x, cos) exists
+        Output v_cos;
+        if (pattern_map.at(x_or_cos1) == pattern_map.at(x)) {
+            v_cos = pattern_map.at(x_or_cos2);
+        } else if (pattern_map.at(x_or_cos2) == pattern_map.at(x)) {
+            v_cos = pattern_map.at(x_or_cos1);
+        } else {
+            // not a RoPE
+            return false;
+        }
+
+        RoPENode::Config config;
+        OutputVector new_args;
+        config.rotary_ndims = 2 * validator["half_ndims"];
+
+        new_args.push_back(pattern_map.at(x));
+        new_args.push_back(v_cos);
+        new_args.push_back(pattern_map.at(t_sin));
+
+        auto old_node = root;
+        auto new_node = std::make_shared(new_args, config);
+        new_node->set_friendly_name(old_node->get_friendly_name());
+        ov::replace_node(old_node, new_node);
+
+        // this new node may match the following additional matchers
+        register_new_node(new_node);
+
+        return true;
+    };
+
+    auto m = std::make_shared(result, matcher_name);
+    this->register_matcher(m, callback);
+}
+
+ov::intel_cpu::RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() {
+    MATCHER_SCOPE(RoPEFusionCosSinPreprocess);
+
+    auto cos_const = makePattern({});  // "f32[1,1,2048,24]"
+    auto sin_const = makePattern({});  // "f32[1,1,2048,24]"
+
+    auto node_batch_size = makePattern("i32[1]");
+    auto tile_batch = makePattern("i32[1]");
+    auto gather_positions = makePattern("i32[?,?,?,?]");
+
+    auto prepare_cos_sin_gptneox = [&](std::shared_ptr const_tab) {
+        auto slice1 = makePattern({const_tab, {0}, node_batch_size, {1}},
+                                  {{"begin_mask", {0}},
+                                   {"end_mask", {0}},
+                                   {"new_axis_mask", {}},
+                                   {"shrink_axis_mask", {}},
+                                   {"ellipsis_mask", {}}});
+        return makePattern({slice1, gather_positions}, {{"axis", 2}});
+    };
+
+    auto seq_len = makePattern("i32[1]");
+    auto gather_positions_2d = makePattern("i32[?,?]");
+
+    auto head_dims = Symbol("head_dims");
+    auto prepare_cos_sin_llama = [&](std::shared_ptr const_tab) {
+        auto ScatterUpdate = makePattern({{0, 0, 0}, 2, seq_len, 0});
+        auto slice_Slice = makePattern({const_tab, {0, 0, 0}, ScatterUpdate, {1, 1, 1}},
+                                       {{"begin_mask", {1, 1, 0}},
+                                        {"end_mask", {1, 1, 0}},
+                                        {"new_axis_mask", {}},
+                                        {"shrink_axis_mask", {}},
+                                        {"ellipsis_mask", {}}});
+        auto squeeze = makePattern({slice_Slice, {-1, head_dims}});
+        auto index_Gather = makePattern({squeeze, gather_positions_2d, 0}, {{"batch_dims", 0}});
+        auto unsqueeze = makePattern({index_Gather, {1, 1, -1, head_dims}});
+        return unsqueeze;
+    };
+
+    auto cos_tab = prepare_cos_sin_gptneox(cos_const) | prepare_cos_sin_llama(cos_const);
+    auto sin_tab = prepare_cos_sin_gptneox(sin_const) | prepare_cos_sin_llama(sin_const);
+
+    auto x = makePattern(ov::Rank(4));
+    auto rope = makePattern({x, cos_tab, sin_tab});
+
+    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
+        PatternValidator validator(m);
+        if (!validator) {
+            return false;
+        }
+        const auto& pattern_map = m.get_pattern_value_map();
+        auto root = m.get_match_root();
+        auto rope_node = as_type_ptr(pattern_map.at(rope).get_node_shared_ptr());
+        if (!rope_node)
+            return false;
+
+        if (pattern_map.count(cos_const)) {
+            rope_node->set_argument(1, pattern_map.at(cos_const));
+        }
+        if (pattern_map.count(sin_const)) {
+            rope_node->set_argument(2, pattern_map.at(sin_const));
+        }
+
+        auto& config = rope_node->get_config();
+        if (pattern_map.count(gather_positions)) {
+            auto arg_id = rope_node->get_input_size();
+            rope_node->set_argument(arg_id, pattern_map.at(gather_positions));
+            config.gather_position_arg_id = arg_id;
+        } else if (pattern_map.count(gather_positions_2d)) {
+            auto arg_id = rope_node->get_input_size();
+            rope_node->set_argument(arg_id, pattern_map.at(gather_positions_2d));
+            config.gather_position_arg_id = arg_id;
+        }
+        rope_node->validate_and_infer_types();
+        register_new_node(rope_node);
+        return true;
+    };
+    auto m = std::make_shared(rope, matcher_name);
+    this->register_matcher(m, callback);
+}
+
+// only a fraction of head_size is rotary-embedded
+ov::intel_cpu::RoPEFusionIOSlicing::RoPEFusionIOSlicing() {
+    MATCHER_SCOPE(RoPEFusionIOSlicing);
+    auto int32_max = std::numeric_limits::max();
+    auto data = makePattern(ov::Rank(4));
+
+    auto ndims = Symbol("ndims");
+    auto x = GenSlice(data, 0, ndims, 1, 3);
+    auto y = GenSlice(data, ndims, int32_max, 1, 3);
+    auto x_emb = makePattern({x, {}, {}}) | makePattern({x, {}, {}, {}});
+    auto result = makePattern({x_emb, y}, {{"axis", -1}});
+
+    matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
+        const auto& pattern_map = m.get_pattern_value_map();
+        auto root = m.get_match_root();
+
+        auto rope_node = as_type_ptr(root->input_value(0).get_node_shared_ptr());
+        if (!rope_node)
+            return false;
+
+        PatternValidator validator(m);
+        if (!validator) {
+            return false;
+        }
+        auto ndims = validator["ndims"];
+
+        auto& config = rope_node->get_config();
+        if (config.rotary_ndims != ndims)
+            return false;
+
+        // remove the slice & concat
+        rope_node->set_argument(0, pattern_map.at(data));
+        rope_node->set_friendly_name(root->get_friendly_name());
+        ov::replace_node(root, rope_node);
+
+        rope_node->validate_and_infer_types();
+        register_new_node(rope_node);
+        return true;
+    };
+    auto m = std::make_shared(result, matcher_name);
+    this->register_matcher(m, callback);
+}
+
+ov::intel_cpu::RoPEFusionPreprocess::RoPEFusionPreprocess() {
+    MATCHER_SCOPE(RoPEFusionPreprocess);
+
+    // gptneox-style preprocessing of the input data
+    auto input_to_slice = makePattern(ov::Rank(4));
+    auto input_to_trans = makePattern(ov::Rank(4));  // no need to slice from 3S
+
+    // in some models the q/k/v projection is combined and
+    // needs to be sliced before RoPE
+    auto slice_start = Symbol("slice_start");
+    auto slice_stop = Symbol("slice_stop");
+    auto input_slice = GenSlice(input_to_slice, slice_start, slice_stop, 1, 3);
+
+    // some models transpose from [B,L,H,S] to [B,H,L,S] before RoPE
+    auto x = makePattern({input_slice | input_to_trans, {0, 2, 1, 3}});
+    auto result = makePattern({x, {}, {}}) | makePattern({x, {}, {}, {}});
+
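+    // 'result' is an already-fused RoPENode taking either 3 inputs or 4 (with the gathered
+    // positions appended by RoPEFusionCosSinPreprocess); the callback below folds the
+    // preceding slice and/or 0213 transpose into the node's config and rewires input 0.
+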
matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + PatternValidator validator(m); + if (!validator) { + return false; + } + + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + auto rope_node = as_type_ptr(root); + if (!rope_node) + return false; + + auto& config = rope_node->get_config(); + + if (pattern_map.count(input_to_slice)) { + config.slice_start = validator["slice_start"]; + config.slice_stop = validator["slice_stop"]; + config.input_trans0213 = true; + rope_node->set_argument(0, pattern_map.at(input_to_slice)); + } else if (pattern_map.count(input_to_trans)) { + config.input_trans0213 = true; + rope_node->set_argument(0, pattern_map.at(input_to_trans)); + } else { + return false; + } + rope_node->validate_and_infer_types(); + register_new_node(rope_node); + return true; + }; + auto m = std::make_shared(result, matcher_name); + this->register_matcher(m, callback); +} + +// remove stridedslice from 0 to int32_max with stride 1 +ov::intel_cpu::EliminateStridedSlice::EliminateStridedSlice() { + MATCHER_SCOPE(EliminateStridedSlice); + auto data = ov::pass::pattern::any_input(ngraph::pattern::has_static_rank()); + auto begin = ov::pass::pattern::wrap_type(ngraph::pattern::type_matches(ov::element::i32)); + auto end = ov::pass::pattern::wrap_type(ngraph::pattern::type_matches(ov::element::i32)); + auto stride = ov::pass::pattern::wrap_type(ngraph::pattern::type_matches(ov::element::i32)); + + auto strided_slice = + ov::pass::pattern::wrap_type({data, begin, end, stride}, [](const Output& value) { + auto s1 = as_type_ptr(value.get_node_shared_ptr()); + if (!s1->get_new_axis_mask().empty() || !s1->get_shrink_axis_mask().empty() || + !s1->get_ellipsis_mask().empty()) { + return false; + } + + auto inputs = s1->input_values(); + + auto begin = as_type_ptr(inputs[1].get_node_shared_ptr()); + auto end = as_type_ptr(inputs[2].get_node_shared_ptr()); + auto stride = as_type_ptr(inputs[3].get_node_shared_ptr()); + + if (!begin) + return false; + if (!end) + return false; + if (!stride) + return false; + + // stride is all 1 + auto v_stride = stride->cast_vector(); + for (auto& v : v_stride) { + if (v != 1) + return false; + } + + auto v_begin = begin->cast_vector(); + auto v_end = end->cast_vector(); + if (v_begin.size() != v_end.size()) { + return false; + } + + auto& begin_mask = s1->get_begin_mask(); + auto& end_mask = s1->get_end_mask(); + auto mask_size = begin_mask.size(); + if (begin_mask.size() != end_mask.size()) { + return false; + } + + auto int32_max = std::numeric_limits::max(); + size_t i = 0; + for (; i < mask_size; i++) { + if (begin_mask[i] != end_mask[i]) + return false; + // all valid [begin, end] are [0, int32_max] + if (begin_mask[i] == 0 && end_mask[i] == 0) { + if (v_begin[i] != 0 || v_end[i] != int32_max) + return false; + } + } + // the non-masked part + for (; i < v_begin.size(); i++) { + if (v_begin[i] != 0 || v_end[i] != int32_max) + return false; + } + return true; + }); + + matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto root = m.get_match_root(); + return replace_output_update_name(root->output(0), root->input_value(0)); + }; + + auto m = std::make_shared(strided_slice, matcher_name); + this->register_matcher(m, callback); +} + +ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() { + MATCHER_SCOPE(RoPEFusionGPTJ); + + auto int32_max = std::numeric_limits::max(); + auto ndims = Symbol("ndims"); + + auto view_Reshape = makePattern(ov::Rank(4)); + + // view_Reshape : B,L,H,S + auto 
slice_Slice_965 = GenSlice(view_Reshape, 0, ndims, 1, 3); + + auto gather_sin_cos = makePattern("f32"); + + auto varsplit = makePattern({gather_sin_cos, -1, {ndims / 2, -1}}); + varsplit->set_output_size(2); + auto unsqueeze_sin = makePattern({varsplit->output(0), {1, -1, 1, 32}}); + auto unsqueeze_cos = makePattern({varsplit->output(1), {1, -1, 1, 32}}); + // repeate cos/sin table + auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) { + const auto& vec = node.get_vector(); + int32_t v = 0; + for (size_t i = 0; i < vec.size(); i += 2, v++) { + if (vec[i] != v || vec[i + 1] != v) + return false; + } + return true; + }); + auto repeat_interleave_sin = makePattern({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}}); + auto repeat_interleave_cos = makePattern({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}}); + + auto t_cos = makePattern(ov::Rank(4)); + auto t_sin = makePattern(ov::Rank(4)); + + // x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2]) + auto slice_Slice_1174 = GenSlice(slice_Slice_965, 1, int32_max, 2, 3); + + auto neg_Multiply_1177 = makePattern({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_65524 = makePattern({neg_Multiply_1177, -1}); + + auto slice_Slice_1168 = GenSlice(slice_Slice_965, 0, int32_max, 2, 3); + auto Unsqueeze_65525 = makePattern({slice_Slice_1168, -1}); + auto stack_1182 = makePattern({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}}); + + auto ShapeOf_169068 = makePattern({stack_1182}); + auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0); + auto flatten_Concat_1197 = makePattern({flatten_Slice_1194, {-1}}, {{"axis", 0}}); + auto flatten_Reshape_1198 = makePattern({stack_1182, flatten_Concat_1197}); + + // x*cos [B,L,H,ndims] + auto mul_cos = + makePattern({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}}); + auto mul_sin = + makePattern({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}}); + + // *cos + *sin + auto rotary_emb = makePattern({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}}); + + auto slice_Slice_971 = GenSlice(view_Reshape, ndims, int32_max, 1, 3); + auto cat_Concat_1211 = makePattern({rotary_emb, slice_Slice_971}, {{"axis", -1}}); + auto permute_Transpose_1213 = makePattern({cat_Concat_1211, {0, 2, 1, 3}}); + + auto result = permute_Transpose_1213; + + matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + PatternValidator validator(m); + if (!validator) { + return false; + } + + RoPENode::Config config; + OutputVector new_args; + config.rotary_ndims = validator["ndims"]; + + config.is_interleaved = true; + + // input is [B,L,H,S] + new_args.push_back(pattern_map.at(view_Reshape)); + // sin_cos table (gathered with positions) [1, L, 64] + new_args.push_back(pattern_map.at(gather_sin_cos)); + new_args.push_back(pattern_map.at(gather_sin_cos)); + + auto old_node = root; + + auto new_node = std::make_shared(new_args, config); + new_node->set_friendly_name(old_node->get_friendly_name()); + ov::replace_node(old_node, new_node); + return true; + }; + + auto m = std::make_shared(result, matcher_name); + this->register_matcher(m, callback); +} \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.hpp new file mode 100644 index 00000000000000..58bab527504096 --- 
/dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.hpp @@ -0,0 +1,63 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace intel_cpu { + +class RoPEFusionGPTNEOX : public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionGPTNEOX", "0"); + RoPEFusionGPTNEOX(); +}; + +class RoPEFusionGPTJ : public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionGPTJ", "0"); + RoPEFusionGPTJ(); +}; + +class RoPEFusionIOSlicing : public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionIOSlicing", "0"); + RoPEFusionIOSlicing(); +}; + +class RoPEFusionPreprocess : public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionPreprocess", "0"); + RoPEFusionPreprocess(); +}; + +class RoPEFusionCosSinPreprocess : public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionCosSinPreprocess", "0"); + RoPEFusionCosSinPreprocess(); +}; + +class EliminateStridedSlice : public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("EliminateStridedSlice", "0"); + EliminateStridedSlice(); +}; + +class RoPEFusion : public ngraph::pass::GraphRewrite { +public: + OPENVINO_RTTI("RoPEFusion", "0"); + RoPEFusion() { + add_matcher(); + add_matcher(); + // optional heads & tails are fused in separate matcher pass, + // after RoPENode has been created. + add_matcher(); + add_matcher(); + add_matcher(); + } +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 2cd69ffdee2644..a034acb257266b 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -83,6 +83,7 @@ #include "transformations/smart_reshape/matmul_sr.hpp" #include "transformations/init_node_info.hpp" #include "utils/ngraph_transformation.hpp" +#include "utils/print_model.hpp" // LPT transformations #include "low_precision/add.hpp" @@ -110,6 +111,7 @@ #include "transformations/cpu_opset/common/pass/insert_convert_after_extension.hpp" #include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp" #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp" +#include "transformations/cpu_opset/common/pass/rope_fusion.hpp" // Snippets #include "snippets/pass/tokenization.hpp" @@ -654,6 +656,10 @@ void Transformations::PostLpt() { // Execute before snippets. 
Otherwise FQ will be converted to Subgraph CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn); + + CPU_REGISTER_PASS_X64(postLPTPassManager, EliminateStridedSlice); + CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion); + postLPTPassManager.run_passes(model); } diff --git a/src/plugins/intel_cpu/src/utils/gen_pattern.hpp b/src/plugins/intel_cpu/src/utils/gen_pattern.hpp new file mode 100644 index 00000000000000..c562190494ed4e --- /dev/null +++ b/src/plugins/intel_cpu/src/utils/gen_pattern.hpp @@ -0,0 +1,1304 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/opsets/opset1.hpp" +#include "openvino/opsets/opset2.hpp" +#include "openvino/opsets/opset3.hpp" +#include "openvino/opsets/opset4.hpp" +#include "openvino/opsets/opset5.hpp" +#include "openvino/opsets/opset6.hpp" +#include "openvino/opsets/opset7.hpp" +#include "openvino/opsets/opset8.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +namespace ov { +namespace gen_pattern { + +static bool force_matcher_verbose = false; + +#ifdef CPU_DEBUG_CAPS + +template +static inline void _verbose_log(Args&&... args) { + std::stringstream ss; + int dummy[] = {(ss << std::forward(args) << " ", 0)...}; + (void)(dummy); + ss << std::endl; + std::cout << ss.str(); +} + +static int matcher_verbose_enabled() { + static const int enabled = std::getenv("GENP_VERBOSE") ? (atoi(std::getenv("GENP_VERBOSE"))) : 0; + return enabled; +} + +# define _VERBOSE_LOG(...) \ + if (matcher_verbose_enabled() || force_matcher_verbose) \ + _verbose_log(__VA_ARGS__) +#else +# define _VERBOSE_LOG(...) 
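+// (when CPU_DEBUG_CAPS is defined, set the GENP_VERBOSE environment variable to a
+// non-zero value to trace pattern matching; otherwise _VERBOSE_LOG expands to nothing)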
+#endif
+
+namespace detail {
+inline std::vector split_string(const std::string& s, const std::string& delimiter) {
+    std::vector ret;
+    size_t pos = 0, pos_next;
+    std::string token;
+    while ((pos_next = s.find(delimiter, pos)) != std::string::npos) {
+        token = s.substr(pos, pos_next - pos);
+        ret.push_back(token);
+        pos = pos_next + 1;
+    }
+    // return the whole string if no delimiter is found
+    token = s.substr(pos, pos_next);
+    ret.push_back(token);
+    return ret;
+}
+
+template
+std::string vec2str(const std::vector& vec, int cnt_limit = 9) {
+    std::stringstream ss;
+    ss << "{";
+    const char* sep = "";
+    for (auto& v : vec) {
+        cnt_limit--;
+        if (cnt_limit == 0) {
+            ss << sep << "...";
+            break;
+        }
+        ss << sep << v;
+        sep = ",";
+    }
+    ss << "}";
+    return ss.str();
+}
+}  // namespace detail
+
+struct values_info {
+    values_info(const char* pattern_list = nullptr) {
+        if (pattern_list == nullptr || pattern_list[0] == 0) {
+            all_type_pshape.clear();
+            return;
+        }
+        auto pattern_vector = detail::split_string(pattern_list, " ");
+        for (auto& pattern : pattern_vector) {
+            if (pattern[0] == '[') {
+                all_type_pshape.emplace_back(ov::element::dynamic, ov::PartialShape(pattern));
+            } else {
+                auto sep = pattern.find("[");
+                if (sep != std::string::npos) {
+                    // ele_type[p_shape]
+                    all_type_pshape.emplace_back(ov::element::Type(pattern.substr(0, sep)),
+                                                 ov::PartialShape(pattern.substr(sep)));
+                } else {
+                    // ele_type
+                    all_type_pshape.emplace_back(ov::element::Type(pattern), ov::PartialShape::dynamic());
+                }
+            }
+        }
+    }
+
+    size_t size() {
+        return all_type_pshape.size();
+    }
+    const std::pair& operator[](int index) {
+        return all_type_pshape[index];
+    }
+
+    //-------------------------------------------------------------
+    bool predicate(const ov::Output& value) const {
+        if (all_type_pshape.empty())
+            return true;
+        auto index = value.get_index();
+        auto& item = all_type_pshape[index];
+        if (!item.first.compatible(value.get_element_type()) || !item.second.compatible(value.get_partial_shape())) {
+            _VERBOSE_LOG("* mismatched vtype between value & pattern : ",
+                         value.get_element_type(),
+                         value.get_partial_shape(),
+                         "vs",
+                         item.first,
+                         item.second);
+            return false;
+        }
+        return true;
+    }
+
+    std::string to_string() {
+        std::stringstream ss;
+        const char* sep = "";
+        for (auto& t : all_type_pshape) {
+            ss << sep << t.first << t.second;
+            sep = ";";
+        }
+        return ss.str();
+    }
+
+    std::vector> all_type_pshape;
+};
+
+// Symbol : a constant that is unknown at the pattern's building time
+// but is collected and validated after the pattern has been matched
+// against some sub-graph values.
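+//
+// A usage sketch (mirroring RoPEFusionGPTNEOX in this PR): build the pattern with a
+// named Symbol, then read the concrete value back once the match succeeds:
+//
+//     auto half_ndims = Symbol("half_ndims");
+//     auto x2 = GenSlice(x, half_ndims, int32_max, 1, 3);   // begin index is symbolic
+//     ...
+//     PatternValidator validator(m);                        // inside the matcher callback
+//     if (validator)
+//         config.rotary_ndims = 2 * validator["half_ndims"];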
+class Symbol { +private: + struct Entity { + const char* name = "?"; + char op; + double literal_const_value; + std::shared_ptr lhs; + std::shared_ptr rhs; + // _,+,-,*,/ + // l : literal const + // n : named symbol + double eval(const std::map& value_map) const { + switch (op) { + case 'l': + return literal_const_value; + case 'n': + return value_map.at(this); + case '+': + return lhs->eval(value_map) + rhs->eval(value_map); + case '-': + return lhs->eval(value_map) - rhs->eval(value_map); + case '*': + return lhs->eval(value_map) * rhs->eval(value_map); + case '/': + return lhs->eval(value_map) / rhs->eval(value_map); + case '_': + return -lhs->eval(value_map); + case 'r': + return std::sqrt(lhs->eval(value_map)); + default: + assert(false); + return std::numeric_limits::quiet_NaN(); + } + } + }; + std::shared_ptr entity; + +public: + Symbol() { + entity = std::make_shared(); + entity->op = 'n'; + } + Symbol(const char* name) { + entity = std::make_shared(); + entity->op = 'n'; + entity->name = name; + } + Symbol(const int value) { + entity = std::make_shared(); + entity->op = 'l'; + entity->literal_const_value = value; + } + Symbol(char op, const Symbol& lhs, const Symbol& rhs) { + entity = std::make_shared(); + entity->op = op; + entity->lhs = lhs.entity; + entity->rhs = rhs.entity; + } + double eval(const std::map& value_map) const { + return entity->eval(value_map); + } + bool is_independent_var() const { + return entity->op == 'n'; + } + int is_literal_const() const { + return entity->op == 'l'; + } + char get_op() const { + return entity->op; + } + void* get_id() const { + return entity.get(); + } + const char* get_name() const { + return entity->name; + } + bool operator<(const Symbol& rhs) const { + return get_id() < rhs.get_id(); + } +}; + +inline Symbol operator-(const Symbol& lhs) { + return Symbol('_', lhs, lhs); +} +inline Symbol operator+(const Symbol& lhs, const Symbol& rhs) { + return Symbol('+', lhs, rhs); +} +inline Symbol operator-(const Symbol& lhs, const Symbol& rhs) { + return Symbol('-', lhs, rhs); +} +inline Symbol operator*(const Symbol& lhs, const Symbol& rhs) { + return Symbol('*', lhs, rhs); +} +inline Symbol operator/(const Symbol& lhs, const Symbol& rhs) { + return Symbol('/', lhs, rhs); +} +inline Symbol sqrt(Symbol lhs) { + return Symbol('r', lhs, lhs); +} + +namespace detail { + +// AttrAny is simple wrapper of Any to provide some constructor +// to take advantage of C++ implicit conversion to allow: +// - attribute expressed using initializer_list. 
+// - symbol to be used as attributes +struct AttrAny { + ov::Any any; + + // empty attribute, means empty vector, and error for scalar + AttrAny() {} + + AttrAny(const Symbol& v) : any(v) {} + AttrAny(const ov::element::Type& v) : any(v) {} + AttrAny(const ov::PartialShape& v) : any(v) {} + AttrAny(const ov::Dimension& v) : any(v) {} + AttrAny(bool v) : any(v) {} + AttrAny(int v) : any(v) {} + AttrAny(float v) : any(v) {} + AttrAny(double v) : any(v) {} + AttrAny(long v) : any(static_cast(v)) {} + AttrAny(long long v) : any(static_cast(v)) {} + AttrAny(const char* v) : any(v) {} + AttrAny(const std::string& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(const T& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(const std::vector& v) : any(v) {} + + AttrAny(const std::vector& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values.begin(), values.end())) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values.begin(), values.end())) {} + + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + + std::string as_string() { + if (any.is()) + return any.as(); + return any.as(); + } + bool as_bool() { + if (any.is()) + return any.as(); + return any.as(); + } + double as_double() { + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + return any.as(); + } + int64_t as_int64_t() { + if (any.is()) + return any.as(); + return any.as(); + } + + template + std::vector as_vector() { + if (any.empty()) + return {}; + if (!std::is_same::value) { + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + } + if (!std::is_same::value) { + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + } + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + return any.as>(); + } + + template + std::vector as_T_vector() { + if (any.empty()) + return {}; + if (any.is()) { + auto to_vec = [](std::initializer_list v) { + return std::vector(v); + }; + return to_vec({any.as()}); + } + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + return any.as>(); + } + + std::vector as_str_vector() { + if (any.empty()) + return {}; + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + return any.as>(); + } + + template + T cast_to() { + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + return any.as(); + } + + template + bool equal_to(const std::vector& rhs) { + if (any.empty() && rhs.empty()) + return true; + auto& vec = any.as>(); + return 
std::equal(vec.begin(), vec.end(), rhs.begin()); + } + + template + bool equal_to(const std::vector& rhs) { + if (any.empty() && rhs.empty()) + return true; + + if (any.is>()) { + auto& vec = any.as>(); + return vec.size() == rhs.size() && std::equal(vec.begin(), vec.end(), rhs.begin()); + } + return equal_to(rhs); + } + + template + typename std::enable_if::value, bool>::type equal_to(const T& rhs) { + return rhs == any.as(); + } + + template + typename std::enable_if::value, bool>::type equal_to(const T& rhs) { + if (any.is()) { + auto& value = any.as(); + return rhs == static_cast(value); + } + return equal_to(rhs); + } +}; + +using AttrMap = std::map; + +class AttrSetter : public ov::AttributeVisitor { +public: + AttrMap& m_attr_map; + std::vector m_missing_attrs; + + AttrSetter(AttrMap& attrs) : m_attr_map(attrs) {} + + const std::vector& get_missing_attrs() { + return m_missing_attrs; + } + + bool should_skip(const std::string& name) { + if (m_attr_map.count(name) == 0) { + // attributes not specified is recorded as missing + m_missing_attrs.push_back(name); + return true; + } + + if (m_attr_map[name].any.is()) { + m_missing_attrs.push_back(name); + return true; + } + + if (m_attr_map[name].any.empty()) { + // input is set to empty, meaning default value is used. + return true; + } + return false; + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_string()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_bool()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& adapter) override { + if (should_skip(name)) + return; + auto& any = m_attr_map[name].any; + if (auto a = ov::as_type>(&adapter)) { + static_cast(*a) = any.as(); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>>(&adapter)) { +#if defined(__APPLE__) || defined(__EMSCRIPTEN__) + static_cast&>(*a) = m_attr_map[name].as_vector(); +#else + a->set(m_attr_map[name].as_vector()); +#endif + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + //} else if (auto a = ov::as_type>(&adapter)) { + // a->set(m_attr_map[name].as_string()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_string()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_T_vector()); + } else { + OPENVINO_THROW("unsupported AttributeAdapter for attribute : ", name); + } + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_double()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_int64_t()); + } + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_vector()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + 
value.set(m_attr_map[name].as_vector());
+    }
+
+    void on_adapter(const std::string& name, ov::ValueAccessor>& value) override {
+        if (should_skip(name))
+            return;
+        value.set(m_attr_map[name].as_vector());
+    }
+
+    void on_adapter(const std::string& name, ov::ValueAccessor>& value) override {
+        if (should_skip(name))
+            return;
+        value.set(m_attr_map[name].as_str_vector());
+    }
+};
+
+class GenericPattern : public ov::pass::pattern::op::Pattern {
+public:
+    OPENVINO_RTTI("GenericPattern");
+
+    explicit GenericPattern(const OutputVector& args = {}, const detail::AttrMap& attrs = {})
+        : ov::pass::pattern::op::Pattern(args) {
+        set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
+        m_attrs = attrs;
+    }
+
+    // this allows code inside pred to access the pattern node itself
+    void set_predicate(ov::pass::pattern::op::ValuePredicate pred) {
+        m_predicate = pred;
+    }
+
+    bool match_value(ov::pass::pattern::Matcher* matcher,
+                     const Output& pattern_value,
+                     const Output& graph_value) override {
+        if (m_predicate(graph_value)) {
+            auto& pattern_map = matcher->get_pattern_value_map();
+            pattern_map[shared_from_this()] = graph_value;
+            matcher->add_node(graph_value);
+            return (get_input_size() == 0
+                        ? true
+                        : matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr()));
+        }
+        return false;
+    }
+
+    detail::AttrMap& get_attrs() {
+        return m_attrs;
+    }
+
+private:
+    detail::AttrMap m_attrs;
+};
+
+// A glue/syntax-sugar type which allows more types to be used as input to makePattern()
+struct PatternNode {
+    std::shared_ptr node;
+    int output_port = -1;
+
+    operator ov::Output() const {
+        return get_output();
+    }
+
+    ov::Output get_output() const {
+        if (output_port >= 0)
+            return node->output(output_port);
+        return node->get_default_output();
+    }
+
+    PatternNode(const Output& out) : node(out.get_node_shared_ptr()), output_port(out.get_index()) {}
+
+    PatternNode() {
+        node = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
+    }
+    PatternNode(ov::Rank rank) {
+        node = ov::pass::pattern::any_input([rank](const Output& value) {
+            if (!rank.compatible(value.get_partial_shape().rank())) {
+                _VERBOSE_LOG("*mismatched PatternNode rank ", value, " expecting ", rank);
+                return false;
+            }
+            return true;
+        });
+    }
+
+    PatternNode(values_info vt) {
+        node = ov::pass::pattern::any_input([vt](const Output& value) {
+            if (!vt.predicate(value)) {
+                _VERBOSE_LOG("*mismatched PatternNode ", value);
+                return false;
+            }
+            _VERBOSE_LOG(" matched PatternNode ", value);
+            return true;
+        });
+    }
+    PatternNode(const std::shared_ptr& node) : node(node) {}
+    PatternNode(const std::shared_ptr& node) : node(node) {}
+    PatternNode(const std::shared_ptr& pattern)
+        : node(std::dynamic_pointer_cast(pattern)) {}
+
+    // 1D-vector & scalar of symbol
+    PatternNode(std::initializer_list v) {
+        // an initializer_list of Symbol is special: it needs to be recorded
+        // and evaluated/checked in the callback after the whole match is complete,
+        // when all observed actual constant values are known. First we go over
+        // all named symbols and collect the actual value of each; then we go over
+        // all derived symbols, evaluate their predicted values, compare them
+        // against what was observed, and check that they all match.
+        // node = ConstVector(std::vector(v), nullptr);
+        node = ov::pass::pattern::wrap_type();
+
+        auto& rt_info = node->get_rt_info();
+        rt_info["symbolic_const_value"] = std::vector(v);
+    }
+    PatternNode(const std::vector& v) {
+        node = ov::pass::pattern::wrap_type();
+        auto& rt_info = node->get_rt_info();
+        rt_info["symbolic_const_value"] = v;
+    }
+
+    PatternNode(Symbol v) {
+        node = ov::pass::pattern::wrap_type();
+        auto& rt_info = node->get_rt_info();
+        rt_info["symbolic_const_value"] = std::vector({v});
+    }
+
+    // scalar constant (treated as wildcard for single-element-constant with any rank)
+    PatternNode(int v) : node(std::make_shared(element::from(), Shape({}), v)) {}
+    PatternNode(float v) : node(std::make_shared(element::from(), Shape({}), v)) {}
+
+    PatternNode(std::initializer_list v, values_info vi = nullptr) {
+        node = ConstVector(std::vector(v), vi);
+    }
+    PatternNode(std::initializer_list v, values_info vi = nullptr) {
+        node = ConstVector(std::vector(v), vi);
+    }
+    PatternNode(std::initializer_list v, values_info vi = nullptr) {
+        node = ConstVector(std::vector(v.begin(), v.end()), vi);
+    }
+    PatternNode(std::initializer_list v, values_info vi = nullptr) {
+        node = ConstVector(std::vector(v.begin(), v.end()), vi);
+    }
+
+    // 1d const tensor or scalar
+    template ::value, bool>::type = true>
+    static std::shared_ptr ConstVector(const std::vector& vec, values_info vi = nullptr) {
+        if (vi.size() > 0)
+            return std::make_shared(vi[0].first, vi[0].second.to_shape(), vec);
+        // an initializer_list w/o values_info means a normal 1D vector
+        return std::make_shared(element::from(), Shape({vec.size()}), vec);
+    }
+};
+
+using SymbolObservationVector = std::vector>;
+
+template
+void add_symbol_observed(SymbolObservationVector& sov, const Symbol& sym, const T& value) {
+    auto v = static_cast(value);
+    OPENVINO_ASSERT(static_cast(v) == value);  // ensure there is no precision loss in double
+    sov.push_back(std::make_pair(sym, v));
+}
+/*
+template
+static bool vector_equal_to_any(const std::vector& v0, detail::AttrAny& any) {
+    auto v1 = any.cast_to_vector();
+    if (v0.size() != v1.size())
+        return false;
+    return std::equal(v0.begin(), v0.end(), v1.begin());
+}
+
+template
+static bool scalar_equal_to_any(const T& v0, detail::AttrAny& any) {
+    if (any.is()) {
+        return v0 == any.as();
+    } else if (any.is()) {
+        return v0 == any.as();
+    }
+    return v0 == any.as();
+}
+*/
+// for arithmetic data types, the Attr matcher will succeed as long as the actual attributes
+// are equal to the casted attributes from the pattern, w/o requiring an exact type match.
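+// Integer attributes may also be given as a Symbol in the pattern; instead of being
+// compared directly they are recorded into a SymbolObservationVector (see the int64_t
+// and double on_adapter overloads below), so that consistency across all symbol
+// observations can be verified once the whole match has completed.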
+class AttrMatcher : public ov::AttributeVisitor { +public: + AttrMap& m_attr_map; + std::vector m_missing_attrs; + SymbolObservationVector* m_psov; + bool m_all_matched; + + AttrMatcher(AttrMap& attrs, SymbolObservationVector* psov = nullptr) + : m_attr_map(attrs), + m_psov(psov), + m_all_matched(true) {} + + bool matched() { + return m_all_matched; + } + + const std::vector& get_missing_attrs() { + return m_missing_attrs; + } + + bool should_skip(const std::string& name, bool allow_symbol = false) { + if (m_attr_map.count(name) == 0) { + m_missing_attrs.push_back(name); + return true; + } + + if (!allow_symbol) { + OPENVINO_ASSERT(!m_attr_map[name].any.is(), "Symbol is not allowed."); + } + return false; + } + + void add_match_result(const std::string& name, bool is_matched) { + if (!is_matched) { + _VERBOSE_LOG(" attribute '", name, "' mismatch."); + } + m_all_matched = m_all_matched && is_matched; + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, value.get() == m_attr_map[name].as_string()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + // only integer is allowed to be of symbol type + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name, true)) + return; + auto& any = m_attr_map[name].any; + if (any.is()) { + if (m_psov) { + // collect symbol reference and do comparison later + add_symbol_observed(*m_psov, any.as(), value.get()); + } + return; + } + add_match_result(name, m_attr_map[name].cast_to() == value.get()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name, true)) + return; + auto& any = m_attr_map[name].any; + if (any.is()) { + if (m_psov) { + // collect symbol reference and do comparison later + add_symbol_observed(*m_psov, any.as(), value.get()); + } + return; + } + add_match_result(name, m_attr_map[name].cast_to() == value.get()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor& adapter) override { + if (should_skip(name)) + return; + OPENVINO_ASSERT(m_attr_map.count(name) > 0); + auto& any = m_attr_map[name].any; + bool is_matched = true; + if (auto a = ov::as_type>(&adapter)) { + is_matched = (static_cast(*a) == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); 
+ } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>>(&adapter)) { +#if defined(__APPLE__) || defined(__EMSCRIPTEN__) + is_matched = m_attr_map[name].equal_to(static_cast&>(*a)); +#else + is_matched = m_attr_map[name].equal_to(a->get()); +#endif + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else { + OPENVINO_THROW("AttrSetter met unsupported AttributeAdapter"); + } + add_match_result(name, is_matched); + } +}; +} // namespace detail + +//================================================================================================== + +inline std::shared_ptr GenInput(values_info vt = nullptr) { + return ov::pass::pattern::any_input([vt](const Output& value) { + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched GenInput ", value); + return false; + } + _VERBOSE_LOG(" matched GenInput ", value); + return true; + }); +} + +inline std::shared_ptr makePattern() { + detail::PatternNode g; + return g.node; +} + +inline std::shared_ptr makePattern(ov::Rank rank) { + detail::PatternNode g(rank); + return g.node; +} + +inline std::shared_ptr makePattern(values_info vt) { + detail::PatternNode g(vt); + return g.node; +} + +// unknown const +inline std::shared_ptr makeConst(const ov::element::Type& type, + const ov::PartialShape& pshape, + std::function pred) { + return ov::pass::pattern::wrap_type([type, pshape, pred](const Output& value) { + auto cnode = ov::as_type_ptr(value.get_node_shared_ptr()); + if (!cnode) + return false; + + if (!type.compatible(value.get_element_type()) || !pshape.compatible(value.get_partial_shape())) { + return false; + } + if (pred && !pred(*cnode)) { + return false; + } + return true; + }); +} + +template +std::shared_ptr makeConst(const ov::element::Type& type, + const ov::Shape& shape, + std::initializer_list values) { + return std::make_shared(type, shape, std::vector(values)); +} + +template +std::shared_ptr makeConst(const ov::element::Type& type, const ov::Shape& shape, const std::vector& values) { + return std::make_shared(type, shape, values); +} + +template +std::shared_ptr makePattern(const std::vector& inputs, + detail::AttrMap attrmap = {}, + values_info vt = nullptr, + const char* friendly_name = nullptr) { + auto* p_type_info = &(T::get_type_info_static()); + OutputVector args; + for (auto& in : inputs) + args.push_back(in.get_output()); + + // pattern nodes are better for pattern matching because + // - it can be generic/incomplete, so normal OP node is not working properly + // - it has predicate to correctly decide which branch to take (in Or pattern) + auto pattern_node = std::make_shared(args, attrmap); + + if (friendly_name) { + pattern_node->set_friendly_name(friendly_name); + } else { + std::stringstream ss; + ss << p_type_info->get_version() << "::" << p_type_info->name; + ss << "("; + const char* sep = ""; + for (auto& i : args) { + ss << sep << i.get_node()->get_name(); + sep = ","; + } + ss << ")"; + pattern_node->set_friendly_name(ss.str()); + } + + auto* pnode = pattern_node.get(); + pnode->set_predicate([p_type_info, vt, pnode, 
+        (void)friendly_name;
+        auto value_node = value.get_node_shared_ptr();
+        if (!value_node->get_type_info().is_castable(*p_type_info)) {
+            _VERBOSE_LOG("*mismatched makePattern OP type: ", pnode->get_friendly_name(), " vs ", value);
+            return false;
+        }
+
+        if (!vt.predicate(value)) {
+            _VERBOSE_LOG("*mismatched makePattern value info: ", pnode->get_friendly_name(), " vs ", value);
+            return false;
+        }
+
+        auto& attr_map = pnode->get_attrs();
+        if (!attr_map.empty()) {
+            detail::AttrMatcher visitor(attr_map);
+            value_node->visit_attributes(visitor);
+            if (!visitor.matched()) {
+                _VERBOSE_LOG("*mismatched attributes : ",
+                             pnode->get_friendly_name(),
+                             " vs ",
+                             value_node->get_friendly_name());
+                return false;
+            }
+        }
+
+        _VERBOSE_LOG(" matched makePattern ", pnode->get_friendly_name(), " == ", value);
+        return true;
+    });
+
+    return pattern_node;
+}
+
+template
+std::shared_ptr makeOP(const std::vector& inputs,
+                       detail::AttrMap attrmap = {},
+                       const char* friendly_name = nullptr) {
+    std::shared_ptr node = std::make_shared();
+
+    OutputVector args;
+    for (auto& in : inputs)
+        args.push_back(in.get_output());
+    node->set_arguments(args);
+
+    detail::AttrSetter visitor(attrmap);
+    node->visit_attributes(visitor);
+
+    auto missing_attrs = visitor.get_missing_attrs();
+
+    // when some attribute is missing or given as a Symbol, the returned
+    // node is suitable for pattern matching only.
+    OPENVINO_ASSERT(missing_attrs.size() == 0,
+                    "missing ",
+                    missing_attrs.size(),
+                    " attributes : ",
+                    missing_attrs[0],
+                    "...");
+
+    if (friendly_name)
+        node->set_friendly_name(friendly_name);
+    node->constructor_validate_and_infer_types();
+    return node;
+}
+
+template
+std::shared_ptr GenConst_tril(values_info vt) {
+    return ov::pass::pattern::wrap_type([vt](const Output& value) {
+        auto s1 = as_type_ptr(value.get_node_shared_ptr());
+        if (!s1) {
+            _VERBOSE_LOG("*mismatched GenConst_tril op type: opset1::Constant vs ", value);
+            return false;
+        }
+
+        if (!vt.predicate(value)) {
+            _VERBOSE_LOG("*mismatched GenConst_tril values_info: ", value);
+            return false;
+        }
+
+        // ignore higher dimensions; require the lowest 2D slice to be a lower-triangular square matrix
+        auto shape = s1->get_output_shape(0);
+        auto rank = shape.size();
+        if (rank < 2) {
+            _VERBOSE_LOG("*mismatched GenConst_tril rank < 2 (rank=", rank, ")");
+            return false;
+        }
+        if (shape[rank - 1] != shape[rank - 2]) {
+            _VERBOSE_LOG("*mismatched GenConst_tril shape[-1] != shape[-2] : ",
+                         shape[rank - 1],
+                         " != ",
+                         shape[rank - 2]);
+            return false;
+        }
+        // NxN const matrix
+        auto N = shape[rank - 1];
+        std::vector output_vector = s1->cast_vector();
+        // check whether it is a unit lower-triangular matrix (ones on and below the diagonal)
+        for (size_t i = 0; i < N; i++) {
+            for (size_t j = 0; j < N; j++) {
+                if (static_cast(output_vector[i * N + j]) != static_cast(j <= i))
+                    return false;
+            }
+        }
+        return true;
+    });
+}
+
+inline std::shared_ptr operator|(const Output& lhs, const Output& rhs) {
+    return std::make_shared(OutputVector{lhs, rhs});
+}
+
+inline std::shared_ptr operator|(const std::shared_ptr& lhs, const std::shared_ptr& rhs) {
+    return std::make_shared(
+        OutputVector{lhs->get_default_output(), rhs->get_default_output()});
+}
+
+inline std::shared_ptr GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) {
+    auto opt1 = makePattern({data, {start}, {stop}, {step}, {static_cast(axis)}});
+
+    std::vector vbegin(axis + 1, Symbol(0));
+    std::vector vend(axis + 1, Symbol(0));
+    std::vector vstride(axis + 1, Symbol(1));
+
+    vbegin[axis] = start;
+    vend[axis] = stop;
+    vstride[axis] = step;
+
+    detail::PatternNode begin(vbegin);
+    detail::PatternNode end(vend);
+    detail::PatternNode stride(vstride);
+
+    std::vector begin_mask(axis + 1, 1);
+    std::vector end_mask(axis + 1, 1);
+    std::vector new_axis_mask;
+    std::vector shrink_axis_mask;
+    std::vector ellipsis_mask;
+
+    begin_mask[axis] = 0;
+    end_mask[axis] = 0;
+
+    auto opt2 = makePattern({data, begin, end, stride},
+                            {{"begin_mask", begin_mask},
+                             {"end_mask", end_mask},
+                             {"new_axis_mask", new_axis_mask},
+                             {"shrink_axis_mask", shrink_axis_mask},
+                             {"ellipsis_mask", ellipsis_mask}});
+    return opt1 | opt2;
+}
+
+//==================================================================================================
+class PatternValidator {
+public:
+    PatternValidator(ov::pass::pattern::Matcher& m, bool force_verbose = false) {
+        auto saved_force_matcher_verbose = force_matcher_verbose;
+        force_matcher_verbose = force_verbose;
+        m_is_valid = validate(m);
+        force_matcher_verbose = saved_force_matcher_verbose;
+    }
+
+    double& operator[](const char* symbol_name) {
+        return m_symbol_values[symbol_name];
+    }
+
+    operator bool() {
+        if (!m_is_valid) {
+            _VERBOSE_LOG("PatternValidator failed.");
+        }
+        return m_is_valid;
+    }
+
+    bool validate(ov::pass::pattern::Matcher& m) {
+        detail::SymbolObservationVector sov;
+
+        auto& pvmap = m.get_pattern_value_map();
+        for (auto& pv : pvmap) {
+            auto pnode = pv.first;
+            auto value_node = pv.second.get_node_shared_ptr();
+            auto& rt_info = pnode->get_rt_info();
+
+            if (auto pattern_node = std::dynamic_pointer_cast(pnode)) {
+                // a pattern_node carries no attributes; it has already been matched by its own predicate
+                if (rt_info.count("symbolic_const_value")) {
+                    // symbolic constant node: a symbol reference is observed
+                    auto& symbols = rt_info["symbolic_const_value"].as>();
+                    auto constop = std::dynamic_pointer_cast(value_node);
+                    if (!constop) {
+                        _VERBOSE_LOG("symbolic_const_value unexpected OP: ", value_node->get_friendly_name());
+                        return false;
+                    }
+                    auto ele_cnt = shape_size(constop->get_shape());
+                    auto ele_type = constop->get_element_type();
+
+                    if (ele_cnt != symbols.size()) {
+                        _VERBOSE_LOG("symbolic_const_value expects ",
+                                     symbols.size(),
+                                     " but got ",
+                                     ele_cnt,
+                                     " from ",
+                                     value_node->get_friendly_name());
+                        return false;
+                    }
+
+                    if (ele_type == ov::element::i32 || ele_type == ov::element::f32 || ele_type == ov::element::i64) {
+                        auto observed = constop->cast_vector();
+                        for (size_t i = 0; i < symbols.size(); i++)
+                            detail::add_symbol_observed(sov, symbols[i], observed[i]);
+                    } else {
+                        _VERBOSE_LOG("Unexpected element type ", ele_type, " from ", value_node->get_friendly_name());
+                        return false;
+                    }
+                }
+                continue;
+            }
+            if (auto pconst_node = std::dynamic_pointer_cast(pnode)) {
+                // a const_node needs to match in type/shape/value
+                auto vconst_node = std::dynamic_pointer_cast(value_node);
+                if (!vconst_node) {
+                    _VERBOSE_LOG("expecting Constant op, but got ", value_node);
+                    return false;
+                }
+                if (pconst_node->get_output_element_type(0) != vconst_node->get_output_element_type(0)) {
+                    _VERBOSE_LOG("expecting Constant of type ",
+                                 pconst_node->get_output_element_type(0),
+                                 " but got ",
+                                 vconst_node);
+                    return false;
+                }
+                // for a constant node matched in the pattern, a scalar constant is considered
+                // compatible with any shape holding a single element, like {}, {1,1}, {1,1,...}
+                const auto& expected_shape = pconst_node->get_output_shape(0);
+                if (expected_shape.size() == 0) {
+                    if (shape_size(vconst_node->get_output_shape(0)) != 1) {
_VERBOSE_LOG("expecting a single element const, but got ", vconst_node); + return false; + } + } else { + if (expected_shape != vconst_node->get_output_shape(0)) { + _VERBOSE_LOG("expecting Constant of shape ", expected_shape, " but got ", vconst_node); + return false; + } + } + auto byte_size = + shape_size(vconst_node->get_output_shape(0)) * vconst_node->get_output_element_type(0).size(); + if (std::memcmp(pconst_node->get_data_ptr(), vconst_node->get_data_ptr(), byte_size) != 0) { + _VERBOSE_LOG("Constant value mismatch."); + return false; + } + continue; + } + + // compare attributes between them + // assume that there is no Symbol in the attributes, we need to fetch each attributes + // from + if (rt_info.count("__attrs__") == 0) { + _VERBOSE_LOG(" attr compare failed: __attrs__ not found for ", pnode->get_friendly_name()); + return false; + } + + // attr not specified is treated as not-care and ignored + // attr with symbol + + detail::AttrMap& attr_map = rt_info["__attrs__"].as(); + detail::AttrMatcher visitor(attr_map, &sov); + value_node->visit_attributes(visitor); + if (!visitor.matched()) { + _VERBOSE_LOG(" attr compare failed: ", + pnode->get_friendly_name(), + " vs ", + value_node->get_friendly_name()); + return false; + } + } + + // check symbol consistency & return independent symbols + // assign independent symbols & check literals + std::map symbol_value_map; + for (auto& ref : sov) { + auto& sym = ref.first; + auto& value = ref.second; + + if (sym.is_independent_var()) { + auto id = sym.get_id(); + if (symbol_value_map.count(id)) { + if (symbol_value_map[id] != value) { + _VERBOSE_LOG(" in-consistency between multiple references of same symbol : ", + symbol_value_map[id], + " != ", + value); + return false; + } + } else { + symbol_value_map[id] = value; + m_symbol_values[sym.get_name()] = value; + _VERBOSE_LOG("Independent Symbol: ", sym.get_name(), " = ", value); + } + } + + if (sym.is_literal_const()) { + auto literal = sym.eval(symbol_value_map); + if (literal != value) { + _VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value); + return false; + } + // no need to put literal into value map to eval them. 
+ } + } + + // derive/eval dependent symbol's value and check against observed + for (auto& ref : sov) { + auto& sym = ref.first; + if (!sym.is_literal_const() && !sym.is_independent_var()) { + auto derived = sym.eval(symbol_value_map); + auto value = ref.second; + bool is_match; + + if (std::trunc(value) == value) { + // observed integer + is_match = (derived == value); + } else { + auto abs_diff = std::abs(derived - value); + auto avg = 0.5f * std::abs(derived + value); + if (avg != 0) { + is_match = abs_diff < avg * 1e-7; // relative error less than threshold + } else { + is_match = (derived == value); + } + } + if (!is_match) { + _VERBOSE_LOG(" mismatch between derived & value : ", + std::setprecision(std::numeric_limits::max_digits10), + derived, + " != ", + std::setprecision(std::numeric_limits::max_digits10), + value); + return false; + } + } + } + return true; + } + +private: + std::map m_symbol_values; + bool m_is_valid; +}; + +} // namespace gen_pattern +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/utils/print_model.hpp b/src/plugins/intel_cpu/src/utils/print_model.hpp new file mode 100644 index 00000000000000..6b4eb01180a264 --- /dev/null +++ b/src/plugins/intel_cpu/src/utils/print_model.hpp @@ -0,0 +1,415 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/core/attribute_visitor.hpp" +#include "openvino/core/model.hpp" +#include "openvino/core/node.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace pass { + +namespace detail { + +// to_code convert value into literal/constexpr/initializer_list/factory_calls in C++ source code +inline std::string to_code(bool value) { + return value ? 
"true" : "false"; +} +inline std::string to_code(const std::string& value) { + return std::string("\"") + value + "\""; +} +inline std::string to_code(const element::Type& value) { + return std::string("element::") + value.to_string(); +} +inline std::string to_code(const ov::Shape& value) { + std::stringstream ss; + ss << "ov::Shape({"; + for (auto& d : value) + ss << d << ","; + ss << "})"; + return ss.str(); +} +inline std::string to_code(int value) { + if (INT_MAX == value) { + return "INT_MAX"; + } + if (INT_MIN == value) { + return "INT_MIN"; + } + return std::to_string(value); +} +inline std::string to_code(int64_t value) { + if (LLONG_MAX == value) { + return "LLONG_MAX"; + } + if (LLONG_MIN == value) { + return "LLONG_MIN"; + } + const char* suffix = "LL"; + if (value == static_cast(static_cast(value))) { + // save suffix since most values can be expressed as int + // this produces more readable code + suffix = ""; + } + return std::to_string(value) + suffix; +} +inline std::string to_code(uint64_t value) { + if (ULLONG_MAX == value) { + return "ULLONG_MAX"; + } + const char* suffix = "uLL"; + if (value == static_cast(static_cast(value))) { + // save suffix since most values can be expressed as int + // this produces more readable code + suffix = ""; + } + return std::to_string(value) + suffix; +} +inline std::string to_code(int8_t value) { + return std::to_string(static_cast(value)); +} +inline std::string to_code(uint8_t value) { + return std::to_string(static_cast(value)); +} + +template +std::string to_code_float(T value) { + if (std::isnan(value)) { + return "NAN"; + } else if (std::isinf(value)) { + return (value > 0 ? "INFINITY" : "-INFINITY"); + } else if (value == FLT_MIN) { + return "FLT_MIN"; + } else if (value == -FLT_MIN) { + return "-FLT_MIN"; + } else if (value == FLT_MAX) { + return "FLT_MAX"; + } else if (value == -FLT_MAX) { + return "-FLT_MAX"; + } + auto strv = std::to_string(value); + if (strv.find(".") == std::string::npos && strv.find("e") == std::string::npos) + strv += ".0"; + if (std::is_same::value) + strv += "f"; + return strv; +} + +inline std::string to_code(float value) { + return to_code_float(value); +} +inline std::string to_code(double value) { + return to_code_float(value); +} +template +std::string to_code(const std::vector& values, bool no_braces = false, int maxsize = 80) { + std::stringstream ss; + if (!no_braces) + ss << "{"; + const char* sep = ""; + for (auto& v : values) { + if (ss.tellp() > maxsize) { + ss << "... 
(" << values.size() << " in total)"; + break; + } + ss << sep << to_code(v); + sep = ","; + } + if (!no_braces) + ss << "}"; + return ss.str(); +} + +template +std::string to_code(std::shared_ptr constop) { + bool no_braces = (constop->get_shape().size() == 0); + auto ele_type = constop->get_element_type(); + if (ele_type == element::Type_t::f32) { + return to_code(constop->get_vector(), no_braces); + } else if (ele_type == element::Type_t::i8) { + return to_code(constop->get_vector(), no_braces); + } else if (ele_type == element::Type_t::u8) { + return to_code(constop->get_vector(), no_braces); + } else if (ele_type == element::Type_t::i32) { + return to_code(constop->get_vector(), no_braces); + } else if (ele_type == element::Type_t::i64) { + return to_code(constop->get_vector(), no_braces); + } + + // general case + std::stringstream ss; + if (!no_braces) + ss << "{"; + auto ele_size = shape_size(constop->get_shape()); + if (ele_size < 9) { + const char* sep = ""; + for (auto v : constop->get_value_strings()) { + ss << sep << v; + sep = ", "; + } + } else { + ss << "..."; + } + if (!no_braces) + ss << "}"; + return ss.str(); +} + +class OstreamAttributeVisitor : public ngraph::AttributeVisitor { + std::ostream& os; + const char* sep = ""; + +public: + OstreamAttributeVisitor(std::ostream& os) : os(os) {} + + void append_attribute(const std::string& name, const std::string& value) { + os << sep << "{\"" << name << "\", " << value << "}"; + sep = ", "; + } + + void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { + if (auto a = ov::as_type>>(&adapter)) { + const auto& strset = a->get(); + std::vector values(strset.begin(), strset.end()); + append_attribute(name, to_code(values)); + } else if (auto a = ov::as_type>>(&adapter)) { + append_attribute(name, to_code(a->get())); + } else if (auto a = ov::as_type>(&adapter)) { + const auto& value = a->get(); + append_attribute(name, value.to_string()); + } else if (auto a = ov::as_type>>(&adapter)) { + const auto& vinfo = a->get()->get_info(); + std::stringstream ss; + ss << vinfo.variable_id << vinfo.data_shape << vinfo.data_type; + append_attribute(name, ss.str()); + } else { + append_attribute(name, "?"); + } + } + + void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { + append_attribute(name, to_code(adapter.get())); + } + void on_adapter(const std::string& name, 
ngraph::ValueAccessor>& adapter) override {
+        append_attribute(name, to_code(adapter.get()));
+    }
+    void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override {
+        append_attribute(name, to_code(adapter.get()));
+    }
+    void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override {
+        append_attribute(name, "Model");
+    }
+};
+
+template
+void dump_cpp_style(std::ostream& os, const std::shared_ptr& model) {
+    const ov::Model& f = *model;
+    std::string prefix = "";
+    std::string tag = "";
+    std::string sep = "";
+    os << prefix;
+    for (auto op : f.get_results()) {
+        os << sep << op->get_name();
+        sep = ",";
+    }
+    os << " " << f.get_friendly_name() << "(\n" << prefix;
+    for (auto op : f.get_parameters()) {
+        os << " " << tag << op->get_friendly_name() << ",\n" << prefix;
+    }
+    os << ") {\n";
+
+    // collect all scalars & short 1D vectors for literal-style display
+    std::map, std::string> literal_consts;
+    for (auto op : f.get_ordered_ops()) {
+        if (auto constop = std::dynamic_pointer_cast(op)) {
+            // only i32/i64/f32 const literals can be parsed back by a C++ compiler
+            if (constop->get_output_element_type(0) != ov::element::i32 &&
+                constop->get_output_element_type(0) != ov::element::i64 &&
+                constop->get_output_element_type(0) != ov::element::f32)
+                continue;
+            auto shape = constop->get_shape();
+            if (shape.size() > 1)
+                continue;
+            if (shape_size(constop->get_shape()) > 64)
+                continue;
+            literal_consts[op] = to_code(constop);
+        }
+    }
+
+    auto get_output_values_info = [](std::shared_ptr& op) {
+        std::stringstream ss;
+        const char* sep = "";
+        for (size_t i = 0; i < op->get_output_size(); i++) {
+            ss << sep << op->get_output_element_type(i) << op->get_output_partial_shape(i);
+            sep = " ";
+        }
+        return ss.str();
+    };
+
+    // normalize friendly names into valid, unique C++ identifiers (the dump's naming convention)
+    std::map opname;
+    std::map opname_count;
+    for (auto op : f.get_ordered_ops()) {
+        auto name = op->get_friendly_name();
+        std::replace(name.begin(), name.end(), '\\', '_');
+        std::replace(name.begin(), name.end(), '/', '_');
+        std::replace(name.begin(), name.end(), '.', '_');
+        std::replace(name.begin(), name.end(), '[', '_');
+        std::replace(name.begin(), name.end(), ']', '_');
+        std::replace(name.begin(), name.end(), '-', 'n');
+        if (name[0] >= '0' && name[0] <= '9') {
+            const auto& type_info = op->get_type_info();
+            name.insert(0, type_info.name);
+        }
+        int idx = 0;
+        if (opname_count.count(name)) {
+            idx = opname_count[name] + 1;
+        }
+        opname_count[name] = idx;
+
+        if (idx)
+            name += std::to_string(idx);
+
+        opname[op.get()] = name;
+    }
+
+    for (auto op : f.get_ordered_ops()) {
+        if (literal_consts.count(op))
+            continue;
+
+        const auto& type_info = op->get_type_info();
+        auto version_info = std::string(type_info.get_version());
+        auto type = version_info + "::" + type_info.name;
+        auto name = opname[op.get()];
+        os << prefix << " ";
+
+        if (auto constop = std::dynamic_pointer_cast(op)) {
+            os << "auto " << name << " = makeConst(" << to_code(op->get_output_element_type(0)) << ", "
+               << to_code(op->get_output_shape(0)) << ", " << to_code(constop) << ");" << std::endl;
+        } else {
+            os << "auto " << name << " = makeOP<" << type << ">({";
+            // input args
+            sep = "";
+            for (size_t i = 0; i < op->get_input_size(); i++) {
+                auto vout = op->get_input_source_output(i);
+                auto iop = vout.get_node_shared_ptr();
+                if (iop->get_output_size() > 1) {
+                    auto out_port = vout.get_index();
+                    os << sep << tag << opname[iop.get()] << "->output(" << out_port << ")";
+                } else {
+                    if (literal_consts.count(iop))
+                        os << sep << tag << 
literal_consts[iop]; + else + os << sep << tag << opname[iop.get()]; + } + sep = ", "; + } + os << "}"; + + // attributes as AnyMap + std::stringstream ss2; + OstreamAttributeVisitor osvis(ss2); + op->visit_attributes(osvis); + auto str_attr = ss2.str(); + if (str_attr.size()) + os << ", {" << str_attr << "}"; + os << "); // tensor_array<" << get_output_values_info(op) << "> " << op->get_friendly_name(); + + os << "("; + sep = ""; + for (size_t i = 0; i < op->get_input_size(); i++) { + auto vout = op->get_input_source_output(i); + auto iop = vout.get_node_shared_ptr(); + os << sep << tag << iop->get_friendly_name(); + if (iop->get_output_size() > 1) { + auto out_port = vout.get_index(); + os << "[" << out_port << "]"; + } + sep = ", "; + } + os << ")" << std::endl; + } + + // recursively output subgraphs + if (auto msubgraph = std::dynamic_pointer_cast(op)) { + auto cnt = msubgraph->get_internal_subgraphs_size(); + for (size_t i = 0; i < cnt; i++) { + os << " MultiSubGraphOp " << tag << msubgraph->get_friendly_name() << "[" << i << "]" << std::endl; + dump_cpp_style(os, msubgraph->get_function(i)); + } + } + } + os << prefix << "}\n"; +} + +} // namespace detail + +class OPENVINO_API PrintModel : public ov::pass::ModelPass { +public: + OPENVINO_RTTI("ov::pass::PrintModel"); + + PrintModel(std::string file_name) { + static int dump_index = 0; + m_file_name = std::string("modelprint_") + std::to_string(dump_index) + "_" + file_name; + dump_index++; + } + ~PrintModel() {} + + bool run_on_model(const std::shared_ptr& model) override { + if (m_file_name.empty()) + return false; + + std::ofstream ofs(m_file_name); + if (!ofs) { + // OPENVINO_WARN << "Error opening file " << m_file_name << " for output" << std::endl; + return false; + } + detail::dump_cpp_style(ofs, model); + ofs.close(); + return true; + } + +protected: + std::string m_file_name; +}; +} // namespace pass +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/rotary_pos_emb.cpp new file mode 100644 index 00000000000000..edfce4dcc9520e --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/rotary_pos_emb.cpp @@ -0,0 +1,184 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/common_utils.hpp" +#include "functional_test_utils/skip_tests_config.hpp" +#include "ie_precision.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "test_utils/cpu_test_utils.hpp" +#include "test_utils/fusing_test_utils.hpp" +#include "utils/gen_pattern.hpp" + +using namespace CPUTestUtils; +using namespace ov::gen_pattern; +using namespace ov::test; +using namespace ov; + +static ov::OutputVector makeCosSinCache(int max_position_embeddings, int rotary_ndims) { + std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); + std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f); + + // rotate_half style cos/sin table: + // y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2 + // y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1 + // + for (int i = 0, k = 0; i < rotary_ndims; i += 2, k++) { + auto xita_i = 1.0 / std::pow(10000.0, static_cast(i) / rotary_ndims); + float* psin = lut_sin.data(); + float* pcos = lut_cos.data(); + for (int m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) { + 
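+            // note: each table row stores the same value in both halves, because the
+            // rotate_half kernel indexes cos/sin once for x1 and once for x2; e.g. with
+            // rotary_ndims = 4, row m is [cos(m), cos(m/100), cos(m), cos(m/100)]
+            // (theta_k = 10000^(-2k/rotary_ndims)), and likewise for sin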
auto vsin = std::sin(xita_i * m); + auto vcos = std::cos(xita_i * m); + pcos[k] = pcos[k + rotary_ndims / 2] = vcos; + psin[k] = psin[k + rotary_ndims / 2] = vsin; + } + } + auto shape = ov::Shape({1, 1, static_cast(max_position_embeddings), static_cast(rotary_ndims)}); + auto Cos = makeConst(ov::element::f32, shape, lut_cos); + auto Sin = makeConst(ov::element::f32, shape, lut_sin); + return {Cos, Sin}; +} + +static std::shared_ptr buildROPE_Llama2(const int batch, + const int seq_length, + const int max_position_embeddings, + const int num_head, + const int ndims) { + auto input = std::make_shared(ov::element::f32, PartialShape{batch, -1, num_head, ndims}); + auto pos_id_end = std::make_shared(ov::element::i32, ov::Shape{}); + auto pos_ids = std::make_shared(ov::element::i32, PartialShape{1, -1}); + + auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); + auto Constant582 = cos_sin_cache[0]; + auto Constant585 = cos_sin_cache[1]; + + // concat KV length + auto transpose_Transpose = makeOP({input, {0, 2, 1, 3}}); + auto slice_Unsqueeze_426 = makeOP({pos_id_end, 0}); + auto ScatterUpdate_152236 = makeOP({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}}); + auto slice_Slice = makeOP({Constant582, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto squeeze_Squeeze = makeOP({slice_Slice, 1}); + auto squeeze_Squeeze_435 = makeOP({squeeze_Squeeze, 0}); + auto index_441_Gather = makeOP({squeeze_Squeeze_435, pos_ids, 0}, {{"batch_dims", 0}}); + auto unsqueeze_Unsqueeze = makeOP({index_441_Gather, 1}); + auto mul_Multiply = + makeOP({transpose_Transpose, unsqueeze_Unsqueeze}, {{"auto_broadcast", "numpy"}}); + auto size_ShapeOf_448 = makeOP({transpose_Transpose}, {{"output_type", "i32"}}); + auto size_Gather_450 = makeOP({size_ShapeOf_448, 3, 0}, {{"batch_dims", 0}}); + auto floor_divide_Divide = + makeOP({size_Gather_450, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto floor_divide_Floor = makeOP({floor_divide_Divide}); + auto slice_Unsqueeze_452 = makeOP({floor_divide_Floor, 0}); + auto ScatterUpdate_152312 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}}); + auto slice_Slice_459 = makeOP( + {transpose_Transpose, ScatterUpdate_152312, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Constant_182988 = makeConst(element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + }), + {-1.000000f}); + auto neg_Multiply = makeOP({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}}); + auto ScatterUpdate_152368 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}}); + auto slice_Slice2 = + makeOP({transpose_Transpose, {0, 0, 0, 0}, ScatterUpdate_152368, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat = makeOP({neg_Multiply, slice_Slice2}, {{"axis", -1}}); + auto ScatterUpdate_152421 = makeOP({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}}); + auto slice_Slice_433 = makeOP({Constant585, {0, 0, 0}, ScatterUpdate_152421, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto squeeze_Squeeze_436 = makeOP({slice_Slice_433, 1}); + auto squeeze_Squeeze_437 = 
makeOP({squeeze_Squeeze_436, 0}); + auto index_446_Gather = makeOP({squeeze_Squeeze_437, pos_ids, 0}, {{"batch_dims", 0}}); + auto unsqueeze_Unsqueeze_447 = makeOP({index_446_Gather, 1}); + auto mul_Multiply_463 = + makeOP({cat_Concat, unsqueeze_Unsqueeze_447}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}}); + + return std::make_shared(ov::NodeVector{add_Add}, ov::ParameterVector{input, pos_id_end, pos_ids}); +} + +namespace CPULayerTestsDefinitions { + +class RoPECPUTest : public SubgraphBaseTest { +public: + ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) { + auto tensor = ov::Tensor(ov::element::i32, shape); + auto* ptr = static_cast(tensor.data()); + for (size_t i = 0; i < tensor.get_size(); i++) { + ptr[i] = start; + start += step; + } + return tensor; + } + + void generate_inputs(const std::vector& targetInputStaticShapes) override { + const auto& funcInputs = function->inputs(); + + const int position_id_start = 15; + auto& input_shape = targetInputStaticShapes[0]; + auto seq_length = input_shape[1]; + + ov::Tensor t_input = + utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768); + ov::Tensor t_position_id_end = create_i32_tensor(ov::Shape({}), position_id_start + seq_length); + ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), position_id_start); + + inputs.clear(); + inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); + inputs.insert({funcInputs[1].get_node_shared_ptr(), t_position_id_end}); + inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids}); + } + +protected: + void SetUp() override { + targetDevice = ov::test::utils::DEVICE_CPU; + + const int batch = 2; + const int seq_length = 7; + const size_t max_position_embeddings = 2048; + const size_t ndims = 128; + const size_t num_head = 32; + + InputShape inpShape = {{batch, seq_length, num_head, ndims}, {{batch, seq_length, num_head, ndims}}}; + init_input_shapes({inpShape}); + function = buildROPE_Llama2(batch, seq_length, max_position_embeddings, num_head, ndims); + } +}; + +TEST_F(RoPECPUTest, smoke_CompareWithRefs) { + run(); +} + +} // namespace CPULayerTestsDefinitions diff --git a/src/plugins/intel_cpu/tests/unit/transformations/convert_to_rope.cpp b/src/plugins/intel_cpu/tests/unit/transformations/convert_to_rope.cpp new file mode 100644 index 00000000000000..91fae33253a494 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/transformations/convert_to_rope.cpp @@ -0,0 +1,452 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "utils/gen_pattern.hpp" +#include "utils/print_model.hpp" + +using namespace testing; +using namespace ov::intel_cpu; +using namespace ov::gen_pattern; + +static ov::OutputVector makeCosSinCache(size_t max_position_embeddings, size_t rotary_ndims) { + std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); + std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f); + + // rotate_half style cos/sin table: + // y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2 + // y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1 + // + for (size_t i = 0, k = 0; i < rotary_ndims; i += 2, k++) { + auto xita_i = 1.0 / std::pow(10000.0, static_cast(i) / rotary_ndims); + float* psin = lut_sin.data(); + float* pcos = 
lut_cos.data(); + for (size_t m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) { + auto vsin = std::sin(xita_i * m); + auto vcos = std::cos(xita_i * m); + pcos[k] = pcos[k + rotary_ndims / 2] = vcos; + psin[k] = psin[k + rotary_ndims / 2] = vsin; + } + } + auto Cos = makeConst(ov::element::f32, ov::Shape({1, 1, max_position_embeddings, rotary_ndims}), lut_cos); + auto Sin = makeConst(ov::element::f32, ov::Shape({1, 1, max_position_embeddings, rotary_ndims}), lut_sin); + + return {Cos, Sin}; +} + +static std::shared_ptr buildROPE_Llama2(const size_t batch, + const size_t seq_length, + const size_t max_position_embeddings, + const size_t ndims, + bool sin_cos_preprocessing) { + auto input = std::make_shared(ov::element::f32, ov::Shape{batch, seq_length, 32, ndims}); + auto param_cos = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); + auto param_sin = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); + + auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); + auto gather_id = std::make_shared(ov::element::i32, ov::Shape{1, seq_length}); + + auto gather_from_sin_cos = [&](const ov::Output& const_tab) { + auto ScatterUpdate_152236 = makeOP({{0, 0, 0}, {2}, seq_len, {0}}); + auto slice_Slice = makeOP({const_tab, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto squeeze_Squeeze_435 = + makeOP({slice_Slice, {-1, static_cast(ndims)}}, {{"special_zero", false}}); + auto index_441_Gather = makeOP({squeeze_Squeeze_435, gather_id, {0}}, {{"batch_dims", 0}}); + return makeOP({index_441_Gather, {1, 1, -1, static_cast(ndims)}}, + {{"special_zero", false}}); + }; + + ov::OutputVector cos_sin(2); + ov::ParameterVector parameters; + if (sin_cos_preprocessing) { + auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); + cos_sin[0] = gather_from_sin_cos(cos_sin_cache[0]); + cos_sin[1] = gather_from_sin_cos(cos_sin_cache[1]); + parameters = ov::ParameterVector{input, seq_len, gather_id}; + } else { + cos_sin[0] = param_cos; + cos_sin[1] = param_sin; + parameters = ov::ParameterVector{input, param_cos, param_sin}; + } + + auto transpose_Transpose = makeOP({input, {0, 2, 1, 3}}); + auto mul_Multiply = makeOP({transpose_Transpose, cos_sin[0]}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_459 = + makeOP({transpose_Transpose, {0, 0, 0, 64}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Constant_182988 = makeConst(ov::element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + }), + {-1.000000f}); + auto neg_Multiply = makeOP({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice = + makeOP({transpose_Transpose, {0, 0, 0, 0}, {0, 0, 0, 64}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat = makeOP({neg_Multiply, slice_Slice}, {{"axis", -1}}); + auto mul_Multiply_463 = makeOP({cat_Concat, cos_sin[1]}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}}); + + return std::make_shared(ov::NodeVector{add_Add}, parameters); +} + +TEST_F(TransformationTestsF, ConvertToROPE_LLama2_no_gather) { + disable_rt_info_check(); + const 
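+    // the reference model below collapses the whole rotate_half subgraph into a single
+    // RoPENode; in its config, slice_start/slice_stop select the rotated slice of the
+    // input (0/0 means no slicing), input_trans0213 requests the [0,2,1,3] transpose,
+    // rotary_ndims is the number of rotated channels, gather_position_arg_id selects
+    // the optional position-id input (0 means unused), and is_interleaved marks the
+    // interleaved cos/sin layout exercised by the GPT-J case further below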
int batch = 2; + const int seq_length = 16; + const size_t max_position_embeddings = 2048; + const size_t ndims = 128; + const size_t num_head = 32; + + model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, false); + manager.register_pass(); + + { + auto hidden_states = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_length, num_head, ndims}); + auto param_cos = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); + auto param_sin = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); + auto add_Add = makeOP({hidden_states, param_cos, param_sin}, + {{"config.slice_start", 0}, + {"config.slice_stop", 0}, + {"config.input_trans0213", true}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", static_cast(ndims)}, + {"config.gather_position_arg_id", 0}}); + + model_ref = std::make_shared(ov::NodeVector{add_Add}, + ov::ParameterVector{hidden_states, param_cos, param_sin}); + } +} + +TEST_F(TransformationTestsF, ConvertToROPE_LLama2_with_gather) { + disable_rt_info_check(); + const int batch = 2; + const int seq_length = 16; + const size_t max_position_embeddings = 2048; + const size_t ndims = 128; + const size_t num_head = 32; + + model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, true); + manager.register_pass(); + + { + auto hidden_states = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_length, num_head, ndims}); + auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); + auto gather_id = std::make_shared(ov::element::i32, ov::Shape{1, seq_length}); + auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); + + auto add_Add = makeOP({hidden_states, cos_sin_cache[0], cos_sin_cache[1], gather_id}, + {{"config.slice_start", 0}, + {"config.slice_stop", 0}, + {"config.input_trans0213", true}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", static_cast(ndims)}, + {"config.gather_position_arg_id", 3}}); + + model_ref = std::make_shared(ov::NodeVector{add_Add}, + ov::ParameterVector{hidden_states, seq_len, gather_id}); + } +} + +static std::shared_ptr buildROPE_GPTNEOX(const int batch, + const int seq_length, + const int max_position_embeddings, + const int ndims, + const int num_heads, + const int rotary_ndims, + bool sin_cos_preprocessing) { + auto batch_s = static_cast(batch); + auto seq_length_s = static_cast(seq_length); + auto ndims_s = static_cast(ndims); + auto rotary_ndims_s = static_cast(rotary_ndims); + auto num_heads_s = static_cast(num_heads); + + auto input = std::make_shared(ov::element::f32, + ov::Shape{batch_s, seq_length_s, num_heads_s, ndims_s * 3}); + auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); + auto gather_idx = + std::make_shared(ov::element::i32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); + auto batch_limit = std::make_shared(ov::element::i32, ov::Shape{1}); + + ov::ParameterVector parameters; + ov::OutputVector cos_sin(2); + if (sin_cos_preprocessing) { + auto cos_sin_lut = makeCosSinCache(max_position_embeddings, rotary_ndims); + auto ro_slice_Slice = makeOP({cos_sin_lut[0], {0}, batch_limit, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + cos_sin[0] = makeOP({ro_slice_Slice, gather_idx}, {{"axis", 2}}); + + auto ro_slice_Slice_385 = makeOP({cos_sin_lut[1], {0}, batch_limit, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + cos_sin[1] 
= makeOP({ro_slice_Slice_385, gather_idx}, {{"axis", 2}}); + parameters = ov::ParameterVector{input, gather_idx, batch_limit}; + } else { + auto param_cos = + std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); + auto param_sin = + std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); + parameters = ov::ParameterVector{input, param_cos, param_sin}; + cos_sin[0] = param_cos; + cos_sin[1] = param_sin; + } + + auto slice_Slice = makeOP({input, {0, 0, 0, 0}, {0, 0, 0, ndims}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto permute_Transpose = makeOP({slice_Slice, {0, 2, 1, 3}}); + auto slice_Slice_351 = + makeOP({permute_Transpose, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto mul_Multiply = makeOP({slice_Slice_351, cos_sin[0]}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_420 = makeOP( + {slice_Slice_351, {0, 0, 0, rotary_ndims / 2}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Constant_396096 = makeConst(ov::element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + }), + {-1.000000f}); + auto neg_Multiply = makeOP({slice_Slice_420, Constant_396096}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_414 = + makeOP({slice_Slice_351, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims / 2}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat = makeOP({neg_Multiply, slice_Slice_414}, {{"axis", -1}}); + auto mul_Multiply_424 = makeOP({cat_Concat, cos_sin[1]}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_424}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_357 = + makeOP({permute_Transpose, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat_458 = makeOP({add_Add, slice_Slice_357}, {{"axis", -1}}); + + return std::make_shared(ov::NodeVector{cat_Concat_458}, parameters); +} + +TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_no_gather) { + disable_rt_info_check(); + const int batch = 2; + const int seq_len = 16; + const int ndims = 80; + const int num_heads = 32; + const int rotary_ndims = 20; + const int max_position_embeddings = 2048; + + model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, false); + manager.register_pass(); + { + auto input = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims * 3}); + auto param_cos = + std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_len, rotary_ndims}); + auto param_sin = + std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_len, rotary_ndims}); + auto rope = makeOP({input, param_cos, param_sin}, + {{"config.slice_start", 0}, + {"config.slice_stop", ndims}, + {"config.input_trans0213", true}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", rotary_ndims}, + {"config.gather_position_arg_id", 0}}); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, 
param_cos, param_sin}); + } +} + +TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_with_gather) { + disable_rt_info_check(); + const int batch = 2; + const int seq_len = 16; + const int ndims = 80; + const int rotary_ndims = 20; + const int num_heads = 32; + const int max_position_embeddings = 2048; + + model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, true); + manager.register_pass(); + { + auto cos_sin = makeCosSinCache(max_position_embeddings, rotary_ndims); + auto input = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims * 3}); + auto gather_idx = + std::make_shared(ov::element::i32, ov::Shape{1, 1, seq_len, rotary_ndims}); + auto batch_limit = std::make_shared(ov::element::i32, ov::Shape{1}); + + auto rope = makeOP({input, cos_sin[0], cos_sin[1], gather_idx}, + {{"config.slice_start", 0}, + {"config.slice_stop", ndims}, + {"config.input_trans0213", true}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", rotary_ndims}, + {"config.gather_position_arg_id", 3}}); + model_ref = + std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, gather_idx, batch_limit}); + } +} + +TEST_F(TransformationTestsF, ConvertToROPE_GPTJ) { + disable_rt_info_check(); + const int batch = 2; + const int seq_len = 7; + const int num_heads = 16; + const int ndims = 256; + const int rotary_ndims = 64; + { + std::vector rpi_idx(rotary_ndims); + for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) { + rpi_idx[i] = index; + rpi_idx[i + 1] = index; + } + auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx); + + auto input = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims}); + auto gather_sin_cos = + std::make_shared(ov::element::f32, ov::Shape{1, seq_len, rotary_ndims}); + + auto split = makeOP({gather_sin_cos, {-1}, {rotary_ndims / 2, -1}}); + auto sin_tab = + makeOP({split->output(0), {1, -1, 1, rotary_ndims / 2}}, {{"special_zero", false}}); + auto cos_tab = + makeOP({split->output(1), {1, -1, 1, rotary_ndims / 2}}, {{"special_zero", false}}); + + auto slice_Slice_576 = + makeOP({input, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto repeat_interleave_Cos = + makeOP({cos_tab, repeat_interleave_index, {3}}, {{"batch_dims", 0}}); + auto mul_Multiply_757 = + makeOP({slice_Slice_576, repeat_interleave_Cos}, {{"auto_broadcast", "numpy"}}); + + auto slice_Slice_787 = + makeOP({slice_Slice_576, {0, 0, 0, 1}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Constant_191672 = makeConst(ov::element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + }), + {-1.000000f}); + auto neg_Multiply_790 = + makeOP({slice_Slice_787, Constant_191672}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_61918 = makeOP({neg_Multiply_790, {-1}}); + auto slice_Slice_781 = + makeOP({slice_Slice_576, {0, 0, 0, 0}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Unsqueeze_61919 = makeOP({slice_Slice_781, {-1}}); + auto stack_795 = makeOP({Unsqueeze_61918, Unsqueeze_61919}, {{"axis", -1}}); + auto ShapeOf_165368 = makeOP>( + {stack_795}, + 
{{"type_relax", true}, {"input_data_types", {}}, {"output_data_types", {ov::element::i32}}}); + auto flatten_Slice_811 = makeOP({ShapeOf_165368, {0}, {3}, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto flatten_Concat_814 = makeOP({flatten_Slice_811, {-1}}, {{"axis", 0}}); + auto flatten_Reshape_815 = + makeOP({stack_795, flatten_Concat_814}, {{"special_zero", true}}); + auto repeat_interleave_Sin = + makeOP({sin_tab, repeat_interleave_index, {3}}, {{"batch_dims", 0}}); + auto mul_Multiply_816 = + makeOP({flatten_Reshape_815, repeat_interleave_Sin}, {{"auto_broadcast", "numpy"}}); + auto add_Add_819 = makeOP({mul_Multiply_757, mul_Multiply_816}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_582 = + makeOP({input, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat_826 = makeOP({add_Add_819, slice_Slice_582}, {{"axis", -1}}); + auto permute_Transpose_828 = makeOP({cat_Concat_826, {0, 2, 1, 3}}); + model = std::make_shared(ov::NodeVector{permute_Transpose_828}, + ov::ParameterVector{input, gather_sin_cos}); + } + manager.register_pass(); + { + auto input = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims}); + auto cos_sin = std::make_shared(ov::element::f32, ov::Shape{1, seq_len, rotary_ndims}); + auto rope = makeOP({input, cos_sin, cos_sin}, + {{"config.slice_start", 0}, + {"config.slice_stop", 0}, + {"config.input_trans0213", false}, + {"config.is_interleaved", true}, + {"config.rotary_ndims", rotary_ndims}, + {"config.gather_position_arg_id", 0}}); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin}); + } +} \ No newline at end of file