Commit
move RMSFusion pass ahead of ConvertPrecision
ceciliapeng2011 committed Sep 20, 2024
1 parent 42b7322 commit 2b84653
Showing 4 changed files with 38 additions and 28 deletions.
4 changes: 1 addition & 3 deletions src/common/transformations/include/ov_ops/rms.hpp
@@ -26,8 +26,7 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
     /// \param output_type Output element type
     RMS(const Output<Node>& data,
         const Output<Node>& gamma,
-        double epsilson,
-        const ov::element::Type output_type = ov::element::undefined);
+        double epsilson);
 
     bool visit_attributes(ov::AttributeVisitor& visitor) override;
 
@@ -45,7 +44,6 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
 
 private:
     double m_epsilon{0};
-    ov::element::Type m_output_type;
 };
 
 } // namespace internal
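
With output_type and m_output_type removed, the op's output element type is always derived from its first input. A minimal construction sketch under the new 3-argument signature (the f16 precision, 4096 hidden size, and epsilon value are illustrative assumptions, not values from this commit):

#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "ov_ops/rms.hpp"

// Hypothetical helper: builds an internal RMS node with the new constructor.
std::shared_ptr<ov::op::internal::RMS> make_rms_sketch() {
    auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, -1, 4096});
    auto gamma = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{4096}, std::vector<float>(4096, 1.0f));
    // No output_type argument anymore; the node infers f16 from `data`.
    return std::make_shared<ov::op::internal::RMS>(data, gamma, 1e-5);
}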
11 changes: 4 additions & 7 deletions src/common/transformations/src/ov_ops/rms.cpp
@@ -8,27 +8,24 @@ namespace ov {
 namespace op {
 namespace internal {
 
-RMS::RMS(const Output<Node>& data, const Output<Node>& gamma, double epsilson, const ov::element::Type output_type)
+RMS::RMS(const Output<Node>& data, const Output<Node>& gamma, double epsilson)
     : Op({data, gamma}),
-      m_epsilon(epsilson),
-      m_output_type(output_type) {
+      m_epsilon(epsilson) {
     validate_and_infer_types();
 }
 
 bool RMS::visit_attributes(ov::AttributeVisitor& visitor) {
     visitor.on_attribute("epsilon", m_epsilon);
-    visitor.on_attribute("output_type", m_output_type);
     return true;
 }
 
 void RMS::validate_and_infer_types() {
-    auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;
-    set_output_type(0, output_type, get_input_partial_shape(0));
+    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
 }
 
 std::shared_ptr<Node> RMS::clone_with_new_inputs(const ov::OutputVector& new_args) const {
     check_new_args_count(this, new_args);
-    return std::make_shared<RMS>(new_args.at(0), new_args.at(1), m_epsilon, m_output_type);
+    return std::make_shared<RMS>(new_args.at(0), new_args.at(1), m_epsilon);
 }
 
 } // namespace internal
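
Since validate_and_infer_types() now mirrors the input precision unconditionally, any consumer that previously relied on the op carrying a different output_type must insert an explicit Convert. A hedged check of the new behaviour, reusing the make_rms_sketch() helper assumed above:

#include "openvino/core/except.hpp"
#include "openvino/op/convert.hpp"

void check_rms_output_type_sketch() {
    auto rms = make_rms_sketch();  // f16 input, see the sketch above
    // The output precision now always follows the input precision.
    OPENVINO_ASSERT(rms->get_output_element_type(0) == ov::element::f16);
    // A different output precision must be requested with an explicit Convert.
    auto to_f32 = std::make_shared<ov::op::v0::Convert>(rms, ov::element::f32);
    OPENVINO_ASSERT(to_f32->get_output_element_type(0) == ov::element::f32);
}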
28 changes: 20 additions & 8 deletions
@@ -16,6 +16,8 @@
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "ov_ops/rms.hpp"
 #include "transformations/utils/utils.hpp"
+#include "openvino/pass/pattern/op/optional.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
 
 namespace ov {
 namespace pass {
@@ -43,28 +45,38 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
 
     // x^2
     auto const_power = wrap_type<ov::op::v0::Constant>(constant_value(2));
-    auto power = wrap_type<ov::op::v1::Power>({x, const_power});
+    const auto optional_convert_power = pattern::optional<ov::op::v0::Convert>(const_power);
+    auto power = wrap_type<ov::op::v1::Power>({x, optional_convert_power});
 
     // ReduceMean(x^2,axes)
     auto mean_axes = wrap_type<ov::op::v0::Constant>(constant_value(-1));
     auto mean = wrap_type<ov::op::v1::ReduceMean>({power, mean_axes});
 
     // ReduceMean(x^2,axes)+eps
     auto eps = wrap_type<ov::op::v0::Constant>();
-    auto add_eps = wrap_type<ov::op::v1::Add>({mean, eps});
+    const auto optional_convert_eps = pattern::optional<ov::op::v0::Convert>(eps);
+    auto add_eps = wrap_type<ov::op::v1::Add>({mean, optional_convert_eps});
 
     // Sqrt(ReduceMean(x^2,axes)+eps)
     auto sqrt = wrap_type<ov::op::v0::Sqrt>({add_eps});
 
     // 1/Sqrt(ReduceMean(x^2,axes)+eps)
-    auto const_div = wrap_type<ov::op::v0::Constant>(constant_value(-1));
-    auto div = wrap_type<ov::op::v1::Power>({sqrt, const_div});
+    auto const_div1 = wrap_type<ov::op::v0::Constant>(constant_value(-1));
+    auto div1 = wrap_type<ov::op::v1::Power>({sqrt, const_div1});
+
+    auto const_div2 = wrap_type<ov::op::v0::Constant>(constant_value(1));
+    const auto optional_convert_div2 = pattern::optional<ov::op::v0::Convert>(const_div2);
+    auto div2 = wrap_type<ov::op::v1::Divide>({optional_convert_div2, sqrt});
+    auto div = std::make_shared<pattern::op::Or>(OutputVector{div1, div2});
 
     // x * 1/Sqrt(ReduceMean(x^2,axes)+eps)
     auto mul1 = wrap_type<ov::op::v1::Multiply>({x, div});
 
     // x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
-    auto gamma = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
+    auto gamma1 = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
+    auto gamma2 = wrap_type<ov::op::v0::Constant>(type_matches(element::f16));
+    const auto optional_convert_gamma = pattern::optional<ov::op::v0::Convert>(gamma2);
+    auto gamma = std::make_shared<pattern::op::Or>(OutputVector{gamma1, optional_convert_gamma});
     auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma, mul1});
 
     std::shared_ptr<ov::Node> comp = mul2;
@@ -88,7 +100,8 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
             return false;
         }
 
-        const auto& gamma_node = pattern_map.at(gamma).get_node_shared_ptr();
+        const auto& mul2_node = pattern_map.at(mul2).get_node_shared_ptr();
+        const auto& gamma_node = mul2_node->input_values()[0].get_node_shared_ptr();
 
         const auto& mean_node = pattern_map.at(mean).get_node_shared_ptr();
         const auto& axes = pattern_map.at(mean_axes).get_node_shared_ptr();
@@ -100,8 +113,7 @@ RMSFusion::RMSFusion(bool force_tail_convert) {
             return false;
         }
 
-        auto output_type = m.get_match_root()->get_output_element_type(0);
-        auto rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value, output_type);
+        auto rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value);
         rms->set_friendly_name(m.get_match_root()->get_friendly_name());
         ov::copy_runtime_info(m.get_matched_nodes(), rms);
         ov::replace_node(m.get_match_root(), rms);
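
The widened pattern now matches both reciprocal spellings — Power(sqrt, -1) and Divide(1, sqrt) — with optional Convert nodes, so mixed-precision graphs whose constants are stored in f16 and up-converted still fuse. A sketch of the Divide-based subgraph that previously escaped fusion (`mean` stands for the ReduceMean(x^2, axes) output; all names and the epsilon value are illustrative, not from this commit):

#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/sqrt.hpp"

ov::Output<ov::Node> build_reciprocal_sketch(const ov::Output<ov::Node>& mean) {
    // eps kept in f16 and up-converted, as ConvertPrecision typically leaves it
    auto eps = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{}, {1e-5f});
    auto eps_f32 = std::make_shared<ov::op::v0::Convert>(eps, ov::element::f32);
    auto add_eps = std::make_shared<ov::op::v1::Add>(mean, eps_f32);
    auto sqrt = std::make_shared<ov::op::v0::Sqrt>(add_eps);
    auto one = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{}, {1.0f});
    auto one_f32 = std::make_shared<ov::op::v0::Convert>(one, ov::element::f32);
    // Divide(1, sqrt) instead of Power(sqrt, -1): both forms now match.
    return std::make_shared<ov::op::v1::Divide>(one_f32, sqrt);
}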
23 changes: 13 additions & 10 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -290,6 +290,19 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         // fuse softmax, MVN patterns, so that they will not be marked as precision sensitive in ConvertPrecision
         manager.register_pass<ov::pass::SoftmaxFusion>();
         manager.register_pass<ov::pass::MVNFusion>();
+
+        // fuse RMS patterns, so that they will not be marked as precision sensitive in ConvertPrecision
+        manager.register_pass<ov::pass::RMSFusion>(false);
+        pass_config->set_callback<ov::pass::RMSFusion>([=](const_node_ptr& root) -> bool {
+            if (!root->get_input_node_ptr(0)->get_input_partial_shape(0).is_static()) {
+                return false;
+            }
+            const auto& gamma_shape = root->get_input_node_ptr(0)->get_input_partial_shape(0).to_shape();
+            const int32_t vec_size = 8;
+            auto ret = static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
+            return ret;
+        });
+
         // decompose MVNs that are not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
         manager.register_pass<ov::pass::MVN6Decomposition>();
         // Run these broadcast optimizations earlier to ensure that those are executed before NopElimination/ConstantFolding
@@ -831,16 +844,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
 
         manager.register_pass<ov::pass::ConvertGatherToGatherCompressed>();
         auto pass_config = manager.get_pass_config();
-        pass_config->set_callback<ov::pass::RMSFusion>([=](const_node_ptr& root) -> bool {
-            if (!root->get_input_node_ptr(0)->get_input_partial_shape(0).is_static()) {
-                return false;
-            }
-            const auto& gamma_shape = root->get_input_node_ptr(0)->get_input_partial_shape(0).to_shape();
-            const int32_t vec_size = 8;
-            return static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
-        });
-
-        manager.register_pass<ov::pass::RMSFusion>();
         manager.register_pass<ov::intel_gpu::KVCacheFusion>();
         manager.register_pass<ov::intel_gpu::FullyConnectedConvertFusion>();
         manager.register_pass<ov::intel_gpu::TransposeFusion>(device_info.supports_immad);
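
The relocated callback disables the fusion when one row of the normalized axis, processed vec_size elements per work-item, would not fit in a single work-group. A standalone restatement of that arithmetic with example numbers (the 4096 hidden size and 1024 max work-group size are assumptions, not values from this commit):

#include <cstdint>
#include <cstdio>

int main() {
    const int32_t vec_size = 8;
    const int64_t hidden = 4096;               // stands in for gamma_shape.back()
    const int64_t max_work_group_size = 1024;  // stands in for device_info.max_work_group_size
    // true  -> callback disables RMSFusion for this node (row too wide for one work-group)
    // false -> fusion proceeds: 4096 / 8 = 512 <= 1024
    bool skip = static_cast<int32_t>(hidden / vec_size) > static_cast<int32_t>(max_work_group_size);
    std::printf("skip RMSFusion: %s\n", skip ? "true" : "false");
    return 0;
}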
