Skip to content

Commit

Permalink
Element type configurable restriction for the new BinaryEltwisePatter…
Browse files Browse the repository at this point in the history
…n. Forced f32 in the conversion pipeline.
  • Loading branch information
slyalin committed Jul 25, 2024
1 parent e7555c8 commit e35dc64
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ void injectMLIR(std::shared_ptr<ov::Model> model, MLIRContext* context) {
using namespace ov::op;
manager.set_per_pass_validation(false);
manager.register_pass<ov::pass::SymbolicPropagation>();
manager.register_pass<BinaryEltwisePattern<v1::Add, linalg::AddOp>>();
manager.register_pass<BinaryEltwisePattern<v1::Subtract, linalg::SubOp>>();
manager.register_pass<BinaryEltwisePattern<v1::Multiply, linalg::MulOp>>();
manager.register_pass<BinaryEltwisePattern<v1::Divide, linalg::DivOp>>();
manager.register_pass<BinaryEltwisePattern<v1::Add, linalg::AddOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Subtract, linalg::SubOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Multiply, linalg::MulOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Divide, linalg::DivOp>>(ov::element::f32);
manager.register_pass<ReluPattern>();
manager.register_pass<MatMulPattern>();
manager.register_pass<Partitioner>(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ namespace mlir {

using namespace ov::pass::pattern;

BinaryEltwisePatternBase::BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder)
BinaryEltwisePatternBase::BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder, const std::set<element::Type>& element_types)
: MarkPattern(
std::make_shared<pass::pattern::op::WrapType>(
wrapped_type,
[](const Output<Node>& output) {
[element_types](const Output<Node>& output) {
if(!element_types.empty() && !element_types.count(output.get_element_type())) {
return false;
}
auto node = output.get_node_shared_ptr();
for(const auto& input: node->inputs()) {
if(!statically_broadcastable(input.get_partial_shape(), output.get_partial_shape())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,25 @@ class BinaryEltwisePatternBase : public MarkPattern {
using Builder = std::function<Operation*(OpBuilder&, ::mlir::Location, ValueRange, ValueRange)>;

OPENVINO_RTTI("BinaryEltwisePatternBase", "0");
BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder);
BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder, const std::set<element::Type>& element_types = {});
};


template <typename OVOp, typename LinalgOp>
class BinaryEltwisePattern : public BinaryEltwisePatternBase {
public:
BinaryEltwisePattern () :
// Allow conversion for given `element_types` only, except case when `element_types` is empty which means no restrictions on types, everything is allowed.
BinaryEltwisePattern (const std::set<element::Type>& element_types = {}) :
BinaryEltwisePatternBase(
OVOp::get_type_info_static(),
[](OpBuilder& builder, ::mlir::Location loc, ValueRange ins, ValueRange outs) -> Operation* {
return builder.create<LinalgOp>(loc, ins, outs);
})
},
element_types)
{}

BinaryEltwisePattern (const element::Type& element_type) :
BinaryEltwisePattern(std::set<element::Type>{element_type})
{}
};

Expand Down

0 comments on commit e35dc64

Please sign in to comment.