diff --git a/src/common/transformations/include/ov_ops/rms.hpp b/src/common/transformations/include/ov_ops/rms.hpp
index 4e22e505819a10..a9a90f817cc6bd 100644
--- a/src/common/transformations/include/ov_ops/rms.hpp
+++ b/src/common/transformations/include/ov_ops/rms.hpp
@@ -26,8 +26,7 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
     /// \param output_type Output element type
     RMS(const Output<Node>& data,
         const Output<Node>& gamma,
-        double epsilson,
-        const ov::element::Type output_type = ov::element::undefined);
+        double epsilson);
 
     bool visit_attributes(ov::AttributeVisitor& visitor) override;
 
@@ -45,7 +44,6 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
 
 private:
     double m_epsilon{0};
-    ov::element::Type m_output_type;
 };
 
 }  // namespace internal
diff --git a/src/common/transformations/src/ov_ops/rms.cpp b/src/common/transformations/src/ov_ops/rms.cpp
index 885494336a1c45..688dad3ac30b72 100644
--- a/src/common/transformations/src/ov_ops/rms.cpp
+++ b/src/common/transformations/src/ov_ops/rms.cpp
@@ -8,27 +8,24 @@ namespace ov {
 namespace op {
 namespace internal {
 
-RMS::RMS(const Output<Node>& data, const Output<Node>& gamma, double epsilson, const ov::element::Type output_type)
+RMS::RMS(const Output<Node>& data, const Output<Node>& gamma, double epsilson)
     : Op({data, gamma}),
-      m_epsilon(epsilson),
-      m_output_type(output_type) {
+      m_epsilon(epsilson) {
     validate_and_infer_types();
 }
 
 bool RMS::visit_attributes(ov::AttributeVisitor& visitor) {
     visitor.on_attribute("epsilon", m_epsilon);
-    visitor.on_attribute("output_type", m_output_type);
     return true;
 }
 
 void RMS::validate_and_infer_types() {
-    auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;
-    set_output_type(0, output_type, get_input_partial_shape(0));
+    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
 }
 
 std::shared_ptr<Node> RMS::clone_with_new_inputs(const ov::OutputVector& new_args) const {
     check_new_args_count(this, new_args);
-    return std::make_shared<RMS>(new_args.at(0), new_args.at(1), m_epsilon, m_output_type);
+    return std::make_shared<RMS>(new_args.at(0), new_args.at(1), m_epsilon);
 }
 
 }  // namespace internal
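With `output_type` gone, the op's output element type is always inherited from the `data` input, and `epsilon` is the only serialized attribute. A minimal sketch of the new behavior (the shapes and epsilon value below are illustrative, not taken from this PR):

    #include <memory>

    #include "openvino/op/parameter.hpp"
    #include "ov_ops/rms.hpp"

    int main() {
        auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, -1, 4096});
        auto gamma = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{4096});

        // No output_type argument anymore: validate_and_infer_types() now sets
        // the output type to the element type of `data` (f16 here) unconditionally.
        auto rms = std::make_shared<ov::op::internal::RMS>(data, gamma, 1e-6);

        return rms->get_output_element_type(0) == ov::element::f16 ? 0 : 1;
    }

Since `visit_attributes` no longer visits `output_type`, serialized IRs produced after this change carry only `epsilon` for this op.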
diff --git a/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp
index 27dbf6bcf737fe..5969504a9f5561 100644
--- a/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp
@@ -16,6 +16,8 @@
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "ov_ops/rms.hpp"
 #include "transformations/utils/utils.hpp"
+#include "openvino/pass/pattern/op/optional.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
 
 namespace ov {
 namespace pass {
@@ -43,7 +45,8 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
 
     // x^2
     auto const_power = wrap_type<ov::op::v0::Constant>(constant_value(2));
-    auto power = wrap_type<ov::op::v1::Power>({x, const_power});
+    const auto optional_convert_power = pattern::optional<ov::op::v0::Convert>(const_power);
+    auto power = wrap_type<ov::op::v1::Power>({x, optional_convert_power});
 
     // ReduceMean(x^2,axes)
     auto mean_axes = wrap_type<ov::op::v0::Constant>(constant_value(-1));
@@ -51,20 +54,29 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
 
     // ReduceMean(x^2,axes)+eps
     auto eps = wrap_type<ov::op::v0::Constant>();
-    auto add_eps = wrap_type<ov::op::v1::Add>({mean, eps});
+    const auto optional_convert_eps = pattern::optional<ov::op::v0::Convert>(eps);
+    auto add_eps = wrap_type<ov::op::v1::Add>({mean, optional_convert_eps});
 
     // Sqrt(ReduceMean(x^2,axes)+eps)
     auto sqrt = wrap_type<ov::op::v0::Sqrt>({add_eps});
 
     // 1/Sqrt(ReduceMean(x^2,axes)+eps)
-    auto const_div = wrap_type<ov::op::v0::Constant>(constant_value(-1));
-    auto div = wrap_type<ov::op::v1::Power>({sqrt, const_div});
+    auto const_div1 = wrap_type<ov::op::v0::Constant>(constant_value(-1));
+    auto div1 = wrap_type<ov::op::v1::Power>({sqrt, const_div1});
+
+    auto const_div2 = wrap_type<ov::op::v0::Constant>(constant_value(1));
+    const auto optional_convert_div2 = pattern::optional<ov::op::v0::Convert>(const_div2);
+    auto div2 = wrap_type<ov::op::v1::Divide>({optional_convert_div2, sqrt});
+    auto div = std::make_shared<pattern::op::Or>(OutputVector{div1, div2});
 
     // x * 1/Sqrt(ReduceMean(x^2,axes)+eps)
     auto mul1 = wrap_type<ov::op::v1::Multiply>({x, div});
 
     // x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
-    auto gamma = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
+    auto gamma1 = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
+    auto gamma2 = wrap_type<ov::op::v0::Constant>(type_matches(element::f16));
+    const auto optional_convert_gamma = pattern::optional<ov::op::v0::Convert>(gamma2);
+    auto gamma = std::make_shared<pattern::op::Or>(OutputVector{gamma1, optional_convert_gamma});
     auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma, mul1});
 
     std::shared_ptr<Node> comp = mul2;
@@ -88,7 +100,8 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
             return false;
         }
 
-        const auto& gamma_node = pattern_map.at(gamma).get_node_shared_ptr();
+        const auto& mul2_node = pattern_map.at(mul2).get_node_shared_ptr();
+        const auto& gamma_node = mul2_node->input_values()[0].get_node_shared_ptr();
         const auto& mean_node = pattern_map.at(mean).get_node_shared_ptr();
         const auto& axes = pattern_map.at(mean_axes).get_node_shared_ptr();
 
@@ -100,8 +113,7 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
             return false;
         }
 
-        auto output_type = m.get_match_root()->get_output_element_type(0);
-        auto rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value, output_type);
+        auto rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value);
         rms->set_friendly_name(m.get_match_root()->get_friendly_name());
         ov::copy_runtime_info(m.get_matched_nodes(), rms);
         ov::replace_node(m.get_match_root(), rms);
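The `pattern::optional<Convert>` wrappers, the `Divide` branch for `1/Sqrt`, and the f16 `gamma` branch widen the pattern to mixed-precision graphs. The sketch below hand-builds the kind of decomposed subgraph the widened pattern is meant to catch; it is not code from this PR, and all shapes and constant values are illustrative:

    #include <memory>
    #include <vector>

    #include "openvino/core/model.hpp"
    #include "openvino/op/add.hpp"
    #include "openvino/op/constant.hpp"
    #include "openvino/op/convert.hpp"
    #include "openvino/op/divide.hpp"
    #include "openvino/op/multiply.hpp"
    #include "openvino/op/parameter.hpp"
    #include "openvino/op/power.hpp"
    #include "openvino/op/reduce_mean.hpp"
    #include "openvino/op/sqrt.hpp"

    std::shared_ptr<ov::Model> make_decomposed_rms() {
        using namespace ov;
        auto x = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, 4096});

        // x^2: the exponent is stored as f16 behind a Convert (optional_convert_power branch)
        auto pow_const = op::v0::Constant::create(element::f16, Shape{}, {2.0f});
        auto pow_cvt = std::make_shared<op::v0::Convert>(pow_const, element::f32);
        auto power = std::make_shared<op::v1::Power>(x, pow_cvt);

        // ReduceMean(x^2, -1)
        auto axes = op::v0::Constant::create(element::i64, Shape{}, {-1});
        auto mean = std::make_shared<op::v1::ReduceMean>(power, axes, true);

        // ReduceMean(...) + eps: eps is also f16 behind a Convert (optional_convert_eps branch)
        auto eps = op::v0::Constant::create(element::f16, Shape{}, {1e-5f});
        auto eps_cvt = std::make_shared<op::v0::Convert>(eps, element::f32);
        auto add_eps = std::make_shared<op::v1::Add>(mean, eps_cvt);

        // 1/Sqrt(...): the Divide form, matched by the new div2 branch
        auto sqrt = std::make_shared<op::v0::Sqrt>(add_eps);
        auto one = op::v0::Constant::create(element::f16, Shape{}, {1.0f});
        auto one_cvt = std::make_shared<op::v0::Convert>(one, element::f32);
        auto div = std::make_shared<op::v1::Divide>(one_cvt, sqrt);

        // x * 1/Sqrt(...) * gamma: f16 gamma behind a Convert (gamma2 branch)
        auto mul1 = std::make_shared<op::v1::Multiply>(x, div);
        auto gamma = op::v0::Constant::create(element::f16, Shape{4096}, std::vector<float>(4096, 1.0f));
        auto gamma_cvt = std::make_shared<op::v0::Convert>(gamma, element::f32);
        auto mul2 = std::make_shared<op::v1::Multiply>(gamma_cvt, mul1);

        return std::make_shared<Model>(mul2, ParameterVector{x});
    }

Running `ov::pass::RMSFusion` on such a model should now collapse the whole subgraph into a single `ov::op::internal::RMS` node; previously the extra `Convert`s and the `Divide` form broke the match. Note also that the callback now takes gamma positionally from `mul2`'s first input rather than from the pattern map, so when a `Convert` feeds `mul2`, that `Convert` (not the bare constant) becomes the gamma input of the fused RMS.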
diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
index 40c7ab48c486cb..b1387872f0db11 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -290,6 +290,19 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         // fuse softmax, MVN patterns, so that they will not be marked as precision sensitive in ConvertPrecision
         manager.register_pass<ov::pass::SoftmaxFusion>();
         manager.register_pass<ov::pass::MVNFusion>();
+
+        // fuse RMS patterns, so that they will not be marked as precision sensitive in ConvertPrecision
+        manager.register_pass<ov::pass::RMSFusion>(false);
+        pass_config->set_callback<ov::pass::RMSFusion>([=](const_node_ptr& root) -> bool {
+            if (!root->get_input_node_ptr(0)->get_input_partial_shape(0).is_static()) {
+                return false;
+            }
+            const auto& gamma_shape = root->get_input_node_ptr(0)->get_input_partial_shape(0).to_shape();
+            const int32_t vec_size = 8;
+            auto ret = static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
+            return ret;
+        });
+
         // decompose MVNs that are not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
         manager.register_pass<ov::pass::MVN6Decomposition>();
         // Run these broadcast optimizations earlier to ensure that those are executed before NopElimination/ConstantFolding
@@ -831,16 +844,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         manager.register_pass();
         auto pass_config = manager.get_pass_config();
 
-        pass_config->set_callback<ov::pass::RMSFusion>([=](const_node_ptr& root) -> bool {
-            if (!root->get_input_node_ptr(0)->get_input_partial_shape(0).is_static()) {
-                return false;
-            }
-            const auto& gamma_shape = root->get_input_node_ptr(0)->get_input_partial_shape(0).to_shape();
-            const int32_t vec_size = 8;
-            return static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
-        });
-
-        manager.register_pass<ov::pass::RMSFusion>();
         manager.register_pass();
         manager.register_pass();
         manager.register_pass(device_info.supports_immad);
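Moving the registration (with `force_tail_convert = false`) in front of ConvertPrecision lets RMS subgraphs fuse before precision markup, mirroring how the Softmax/MVN fusions above are handled. The callback keeps a subgraph decomposed when the fused kernel would need more work items than the device allows, on the reading that each work item covers `vec_size` elements of the innermost (gamma) axis. A worked sketch of that arithmetic, assuming the usual convention that a callback returning true disables the pass for that match (the device limit below is hypothetical):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        const std::vector<std::size_t> gamma_shape{4096};  // hypothetical gamma shape (last dim = hidden size)
        const int32_t vec_size = 8;                        // elements handled per work item, as in the callback
        const std::size_t max_work_group_size = 512;       // hypothetical device limit

        // Work items needed for the innermost axis of the fused RMS kernel.
        const auto needed = static_cast<int32_t>(gamma_shape.back() / vec_size);  // 4096 / 8 = 512

        // true -> callback disables the fusion and the subgraph stays decomposed
        const bool skip = needed > static_cast<int32_t>(max_work_group_size);     // 512 > 512 -> false
        std::cout << (skip ? "skip fusion (keep decomposed)" : "fuse into internal RMS") << '\n';
        return 0;
    }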