Skip to content

Commit

Permalink
[GPU] Moved RMSFusion higher in pipeline and added output type fuse
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Nov 14, 2024
1 parent a661f0d commit c563ce1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 12 deletions.
6 changes: 6 additions & 0 deletions src/common/transformations/include/ov_ops/rms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
m_epsilon = epsilon;
}

/// \brief Overrides the element type reported on this op's output.
/// Allows ConvertPrecision to fuse a precision change directly into the
/// RMS op (see fuse_type_to_rms) instead of inserting a Convert node.
void set_output_type(const element::Type& output_type) {
m_output_type = output_type;
}
// The setter above hides the inherited Node::set_output_type overload(s);
// re-expose them so callers of the base-class signature still compile.
using Node::set_output_type;

private:
double m_epsilon{0};
ov::element::Type m_output_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/reference/convert.hpp"
#include "ov_ops/rms.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/fp16_compression/align_mixed_fp32_fp16_types.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
Expand Down Expand Up @@ -59,6 +60,7 @@ bool fuse_type_to_maxpool(const std::shared_ptr<ov::Node>& node, const precision
bool fuse_type_to_nonzero(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_bucketize(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_ctc_greedy_decoder_seq_len(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
bool fuse_type_to_rms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);

bool fuse_type_to_random_uniform_v8(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);

Expand Down Expand Up @@ -465,7 +467,8 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
{ov::op::v0::PriorBox::get_type_info_static(), fuse_type_to_prior_box<ov::op::v0::PriorBox>},
{ov::op::v8::PriorBox::get_type_info_static(), fuse_type_to_prior_box<ov::op::v8::PriorBox>},
{ov::op::v0::PriorBoxClustered::get_type_info_static(), fuse_type_to_prior_box<ov::op::v0::PriorBoxClustered>},
{ov::op::v15::SearchSorted::get_type_info_static(), fuse_type_to_search_sorted_v15}};
{ov::op::v15::SearchSorted::get_type_info_static(), fuse_type_to_search_sorted_v15},
{ov::op::internal::RMS::get_type_info_static(), fuse_type_to_rms}};

for (const auto& it : m_additional_type_to_fuse_map) {
type_to_fuse[it.first] = it.second;
Expand Down Expand Up @@ -858,6 +861,20 @@ bool fuse_type_to_nms_rotated(const std::shared_ptr<ov::Node>& node, const preci
return res;
}

// Fuses a requested precision change into an internal RMS op by rewriting
// its stored output element type (via RMS::set_output_type) rather than
// appending an explicit Convert.
//
// \param node       Candidate node; only ov::op::internal::RMS is handled.
// \param precisions Map from current element type to the desired one.
// \return true when the RMS output type was updated, false otherwise
//         (not an RMS op, no mapping for its type, or a non-real target).
bool fuse_type_to_rms(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
    const auto rms = ov::as_type_ptr<ov::op::internal::RMS>(node);
    if (!rms)
        return false;

    const auto entry = precisions.find(node->get_output_element_type(0));
    if (entry == precisions.end())
        return false;

    // RMS normalization only makes sense for floating-point outputs;
    // refuse integral targets so the generic conversion path handles them.
    const auto& target = entry->second;
    if (!target.is_real())
        return false;

    rms->set_output_type(target);
    return true;
}

namespace {

bool update_type(size_t idx,
Expand Down
25 changes: 14 additions & 11 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
#include "transformations/op_conversions/convert_broadcast3.hpp"
#include "transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp"
#include "transformations/op_conversions/convert_depth_to_space.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
#include "transformations/op_conversions/convert_gather_0d.hpp"
#include "transformations/op_conversions/convert_gather_downgrade.hpp"
#include "transformations/op_conversions/convert_gelu.hpp"
Expand Down Expand Up @@ -338,6 +339,19 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
});
}

// Run the Divide decompositions before RMSFusion — presumably so the
// fusion pattern can match the decomposed form; TODO confirm against the
// RMSFusion matcher pattern.
manager.register_pass<ov::pass::ConvertDivideWithConstant>();
manager.register_pass<ov::pass::ConvertDivide>();

// NOTE(review): pass_config callbacks returning true typically disable the
// transformation for that match — confirm this is intended to *skip* RMS
// fusion when the (last_dim / vec_size) work items would exceed the
// device's max work-group size.
pass_config->set_callback<ov::pass::RMSFusion>([=](const_node_ptr& root) -> bool {
if (!root->get_input_partial_shape(0).is_static()) {
return false;
}
// NOTE(review): this reads input 0 of the matched root, i.e. the data
// input, although the variable is named gamma_shape; the pre-change code
// looked one node upstream (get_input_node_ptr(0)). Verify the intended
// input is being checked here.
const auto& gamma_shape = root->get_input_partial_shape(0).to_shape();
const int32_t vec_size = 8;
return static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
});
// `false` — argument forwarded to RMSFusion's constructor; from this view
// its meaning (likely a squared/decomposed-pattern toggle) can't be
// confirmed — check the RMSFusion declaration.
manager.register_pass<ov::pass::RMSFusion>(false);

const bool keep_precision_sensitive_in_fp32_1 = true;
const bool convert_input_output_precision = false;
const bool store_original_precision_as_rt_attribute = true;
Expand Down Expand Up @@ -855,16 +869,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {

manager.register_pass<ov::pass::ConvertGatherToGatherCompressed>();
auto pass_config = manager.get_pass_config();
pass_config->set_callback<ov::pass::RMSFusion>([=](const_node_ptr& root) -> bool {
if (!root->get_input_node_ptr(0)->get_input_partial_shape(0).is_static()) {
return false;
}
const auto& gamma_shape = root->get_input_node_ptr(0)->get_input_partial_shape(0).to_shape();
const int32_t vec_size = 8;
return static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
});

manager.register_pass<ov::pass::RMSFusion>();
manager.register_pass<ov::intel_gpu::KVCacheFusion>();
manager.register_pass<ov::intel_gpu::FullyConnectedConvertFusion>();
manager.register_pass<ov::intel_gpu::TransposeFusion>(device_info.supports_immad);
Expand Down Expand Up @@ -930,7 +934,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
GPU_DEBUG_IF(cldnn::debug_configuration::get_instance()->verbose >= 1) {
manager.register_pass<ov::intel_gpu::PrintModelStatistics>();
}

manager.run_passes(func);
}
}
Expand Down

0 comments on commit c563ce1

Please sign in to comment.