From 2d35d0b6b87c2a925e456ef6960018e5c09301cb Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 29 Jul 2024 18:22:02 +0200 Subject: [PATCH] Match all element types --- .../transformations/src/transformations/mlir/convert.cpp | 8 ++++---- .../src/transformations/mlir/convert_common.cpp | 5 +---- .../src/transformations/mlir/convert_common.hpp | 7 +------ .../transformations/src/transformations/mlir/op/relu.cpp | 3 +-- 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/common/transformations/src/transformations/mlir/convert.cpp b/src/common/transformations/src/transformations/mlir/convert.cpp index b5a076749ee5fc..da777e30094da1 100644 --- a/src/common/transformations/src/transformations/mlir/convert.cpp +++ b/src/common/transformations/src/transformations/mlir/convert.cpp @@ -283,10 +283,10 @@ void injectMLIR(std::shared_ptr model, MLIRContext* context, bool tpp using namespace ov::op; manager.set_per_pass_validation(false); manager.register_pass(); - manager.register_pass>(ov::element::f32); - manager.register_pass>(ov::element::f32); - manager.register_pass>(ov::element::f32); - manager.register_pass>(ov::element::f32); + manager.register_pass>(); + manager.register_pass>(); + manager.register_pass>(); + manager.register_pass>(); manager.register_pass(); manager.register_pass(); manager.register_pass(context, tpp_mlir_enabled); diff --git a/src/common/transformations/src/transformations/mlir/convert_common.cpp b/src/common/transformations/src/transformations/mlir/convert_common.cpp index de8782c77c12bc..45ab7a904f0d90 100644 --- a/src/common/transformations/src/transformations/mlir/convert_common.cpp +++ b/src/common/transformations/src/transformations/mlir/convert_common.cpp @@ -134,10 +134,7 @@ Location createLocation(MLIRContext* ctx, NodePtr node) { return createLayerLocation(ctx, node->get_friendly_name(), node->get_type_name()); } -bool elementwise_no_broadcast_predicate_impl(const ov::Output& output, ov::element::Type type) { - if (output.get_element_type() != type) { - return false; - } +bool elementwise_no_broadcast_predicate(const ov::Output& output) { if (has_dynamic_rank(output.get_node_shared_ptr())) { return false; } diff --git a/src/common/transformations/src/transformations/mlir/convert_common.hpp b/src/common/transformations/src/transformations/mlir/convert_common.hpp index d6f794bb24bb12..5b6b83ba6d3d5f 100644 --- a/src/common/transformations/src/transformations/mlir/convert_common.hpp +++ b/src/common/transformations/src/transformations/mlir/convert_common.hpp @@ -38,12 +38,7 @@ RankedTensorType importTensor(MLIRContext* ctx, Location createLocation(MLIRContext* ctx, NodePtr node); -bool elementwise_no_broadcast_predicate_impl(const ov::Output& output, ov::element::Type type); - -template -bool elementwise_no_broadcast_predicate(const ov::Output& output) { - return elementwise_no_broadcast_predicate_impl(output, type); -} +bool elementwise_no_broadcast_predicate(const ov::Output& output); // Borrowed it from TPP-MLIR. FIXME: Do we have a better upstreamed alternative? template diff --git a/src/common/transformations/src/transformations/mlir/op/relu.cpp b/src/common/transformations/src/transformations/mlir/op/relu.cpp index 116fd24de7229a..6f7157f9bd4bd4 100644 --- a/src/common/transformations/src/transformations/mlir/op/relu.cpp +++ b/src/common/transformations/src/transformations/mlir/op/relu.cpp @@ -42,8 +42,7 @@ using namespace ov::pass::pattern; using namespace ov::op; ReluPattern::ReluPattern() - : MarkPattern(wrap_type({any_input()}, elementwise_no_broadcast_predicate), - ConvertRelu()) {} + : MarkPattern(wrap_type({any_input()}, elementwise_no_broadcast_predicate), ConvertRelu()) {} } // namespace mlir } // namespace ov