Skip to content

Commit

Permalink
Match all element types
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Jul 29, 2024
1 parent c793696 commit 2d35d0b
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,10 @@ void injectMLIR(std::shared_ptr<ov::Model> model, MLIRContext* context, bool tpp
using namespace ov::op;
manager.set_per_pass_validation(false);
manager.register_pass<ov::pass::SymbolicPropagation>();
manager.register_pass<BinaryEltwisePattern<v1::Add, linalg::AddOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Subtract, linalg::SubOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Multiply, linalg::MulOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Divide, linalg::DivOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Add, linalg::AddOp>>();
manager.register_pass<BinaryEltwisePattern<v1::Subtract, linalg::SubOp>>();
manager.register_pass<BinaryEltwisePattern<v1::Multiply, linalg::MulOp>>();
manager.register_pass<BinaryEltwisePattern<v1::Divide, linalg::DivOp>>();
manager.register_pass<ReluPattern>();
manager.register_pass<MatMulPattern>();
manager.register_pass<Partitioner>(context, tpp_mlir_enabled);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,7 @@ Location createLocation(MLIRContext* ctx, NodePtr node) {
return createLayerLocation(ctx, node->get_friendly_name(), node->get_type_name());
}

bool elementwise_no_broadcast_predicate_impl(const ov::Output<ov::Node>& output, ov::element::Type type) {
if (output.get_element_type() != type) {
return false;
}
bool elementwise_no_broadcast_predicate(const ov::Output<ov::Node>& output) {
if (has_dynamic_rank(output.get_node_shared_ptr())) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,7 @@ RankedTensorType importTensor(MLIRContext* ctx,

Location createLocation(MLIRContext* ctx, NodePtr node);

bool elementwise_no_broadcast_predicate_impl(const ov::Output<ov::Node>& output, ov::element::Type type);

template <ov::element::Type_t type>
bool elementwise_no_broadcast_predicate(const ov::Output<ov::Node>& output) {
return elementwise_no_broadcast_predicate_impl(output, type);
}
bool elementwise_no_broadcast_predicate(const ov::Output<ov::Node>& output);

// Borrowed it from TPP-MLIR. FIXME: Do we have a better upstreamed alternative?
template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ using namespace ov::pass::pattern;
using namespace ov::op;

ReluPattern::ReluPattern()
: MarkPattern(wrap_type<v0::Relu>({any_input()}, elementwise_no_broadcast_predicate<ov::element::f32>),
ConvertRelu()) {}
: MarkPattern(wrap_type<v0::Relu>({any_input()}, elementwise_no_broadcast_predicate), ConvertRelu()) {}

} // namespace mlir
} // namespace ov

0 comments on commit 2d35d0b

Please sign in to comment.