From e35dc645fc26b05ae4072ca2b2d9720cc0338d32 Mon Sep 17 00:00:00 2001 From: Sergey Lyalin Date: Thu, 25 Jul 2024 14:23:11 +0000 Subject: [PATCH] Element type configurable restriction for the new BinaryEltwisePattern. Forced f32 in the conversion pipeline. --- .../src/transformations/mlir/convert.cpp | 8 ++++---- .../src/transformations/mlir/op/binary_eltwise.cpp | 7 +++++-- .../src/transformations/mlir/op/binary_eltwise.hpp | 12 +++++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/common/transformations/src/transformations/mlir/convert.cpp b/src/common/transformations/src/transformations/mlir/convert.cpp index 8311596240b7a1..63bafdc0d5fa1a 100644 --- a/src/common/transformations/src/transformations/mlir/convert.cpp +++ b/src/common/transformations/src/transformations/mlir/convert.cpp @@ -277,10 +277,10 @@ void injectMLIR(std::shared_ptr model, MLIRContext* context) { using namespace ov::op; manager.set_per_pass_validation(false); manager.register_pass(); - manager.register_pass>(); - manager.register_pass>(); - manager.register_pass>(); - manager.register_pass>(); + manager.register_pass>(ov::element::f32); + manager.register_pass>(ov::element::f32); + manager.register_pass>(ov::element::f32); + manager.register_pass>(ov::element::f32); manager.register_pass(); manager.register_pass(); manager.register_pass(context); diff --git a/src/common/transformations/src/transformations/mlir/op/binary_eltwise.cpp b/src/common/transformations/src/transformations/mlir/op/binary_eltwise.cpp index 66e29b337f37dd..c840c720b4f497 100644 --- a/src/common/transformations/src/transformations/mlir/op/binary_eltwise.cpp +++ b/src/common/transformations/src/transformations/mlir/op/binary_eltwise.cpp @@ -67,11 +67,14 @@ namespace mlir { using namespace ov::pass::pattern; -BinaryEltwisePatternBase::BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder) +BinaryEltwisePatternBase::BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder, const std::set& element_types) : MarkPattern( std::make_shared( wrapped_type, - [](const Output& output) { + [element_types](const Output& output) { + if(!element_types.empty() && !element_types.count(output.get_element_type())) { + return false; + } auto node = output.get_node_shared_ptr(); for(const auto& input: node->inputs()) { if(!statically_broadcastable(input.get_partial_shape(), output.get_partial_shape())) { diff --git a/src/common/transformations/src/transformations/mlir/op/binary_eltwise.hpp b/src/common/transformations/src/transformations/mlir/op/binary_eltwise.hpp index 8bcec93e3da7f0..1c410608cbc4a5 100644 --- a/src/common/transformations/src/transformations/mlir/op/binary_eltwise.hpp +++ b/src/common/transformations/src/transformations/mlir/op/binary_eltwise.hpp @@ -14,19 +14,25 @@ class BinaryEltwisePatternBase : public MarkPattern { using Builder = std::function; OPENVINO_RTTI("BinaryEltwisePatternBase", "0"); - BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder); + BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder, const std::set& element_types = {}); }; template class BinaryEltwisePattern : public BinaryEltwisePatternBase { public: - BinaryEltwisePattern () : + // Allow conversion for given `element_types` only, except case when `element_types` is empty which means no restrictions on types, everything is allowed. + BinaryEltwisePattern (const std::set& element_types = {}) : BinaryEltwisePatternBase( OVOp::get_type_info_static(), [](OpBuilder& builder, ::mlir::Location loc, ValueRange ins, ValueRange outs) -> Operation* { return builder.create(loc, ins, outs); - }) + }, + element_types) + {} + + BinaryEltwisePattern (const element::Type& element_type) : + BinaryEltwisePattern(std::set{element_type}) {} };