diff --git a/src/common/transformations/src/transformations/mlir/conversion_context.cpp b/src/common/transformations/src/transformations/mlir/conversion_context.cpp
index 5d1dfbf619a9b0..89692ae2bc80d7 100644
--- a/src/common/transformations/src/transformations/mlir/conversion_context.cpp
+++ b/src/common/transformations/src/transformations/mlir/conversion_context.cpp
@@ -58,6 +58,25 @@ void ConversionContext::set_convertor(NodePtr node, const Convertor& convertor)
     node->get_rt_info()[rt_info_convertor()] = as_any;
 }
 
+Value ConversionContext::get_dimension_value(const Dimension& d) {
+    auto symbol = d.get_symbol();
+    assert(symbol);
+    symbol = ov::symbol::ancestor_of(symbol);
+    // Suppose all dimensions are known and the map is populated
+    // FIXME: Add dimensions on demand to avoid unnecessary operations in the produced MLIR
+    assert(dimension_map.count(symbol));
+    return dimension_map.at(symbol);
+}
+
+SmallVector<Value> ConversionContext::get_dynamic_dimension_values(const PartialShape& shape) {
+    SmallVector<Value> dims;
+    for (const auto& dim : shape) {
+        if (dim.is_dynamic()) {
+            dims.push_back(get_dimension_value(dim));
+        }
+    }
+    return dims;
+}
 
 
 const std::string& subgraph_mark() {
diff --git a/src/common/transformations/src/transformations/mlir/conversion_context.hpp b/src/common/transformations/src/transformations/mlir/conversion_context.hpp
index 3f2c0ee9b34619..314b0529642453 100644
--- a/src/common/transformations/src/transformations/mlir/conversion_context.hpp
+++ b/src/common/transformations/src/transformations/mlir/conversion_context.hpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/Builders.h"
 
 #include "typedefs.hpp"
+#include "convert_common.hpp"
 
 namespace ov {
 namespace mlir {
@@ -20,6 +21,7 @@ using ::mlir::MLIRContext;
 using ::mlir::OpBuilder;
 using ::mlir::Operation;
 using ::mlir::SmallVector;
+using ::mlir::ValueRange;
 
 class ConversionContext {
     static std::string rt_info_convertor ();
@@ -32,6 +34,7 @@ class ConversionContext {
     mlir::MLIRContext* context;
     mlir::OpBuilder* block_builder;
     NodeOutputMap nodeOutputMap;
+    std::map<SymbolPtr, Value> dimension_map;
 
     ConversionContext(mlir::MLIRContext* context, mlir::OpBuilder* block_builder);
 
@@ -45,6 +48,10 @@ class ConversionContext {
     static void set_convertor(NodePtr node, const Convertor& convertor);
 
     void convert(NodePtr node);
+
+    Value get_dimension_value(const Dimension& d);
+
+    SmallVector<Value> get_dynamic_dimension_values(const PartialShape& shape);
 };
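Note on the two helpers above: converters no longer probe tensor.dim on their own inputs; each dynamic dimension is resolved through its ancestor symbol via `dimension_map`. A minimal sketch of the intended call pattern in a converter — `convert_some_unary` is a hypothetical name, the real uses are in op/matmul.cpp and op/relu.cpp below:

    void convert_some_unary(ConversionContext& context, NodePtr node) {
        auto loc = createLocation(context.context, node);
        auto& builder = context.builder();
        const auto out_shape = node->get_output_partial_shape(0);
        auto outType = importTensor(context.context, out_shape, node->get_output_element_type(0));
        // One mlir::Value per dynamic extent, looked up by the dimension's symbol ancestor.
        auto dynamic_dimensions = context.get_dynamic_dimension_values(out_shape);
        // tensor.empty consumes exactly these extents to materialize the destination tensor.
        auto empty = builder.create<mlir::tensor::EmptyOp>(loc, outType, dynamic_dimensions);
        // ... emit the named linalg op with `empty` as its destination ...
    }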
diff --git a/src/common/transformations/src/transformations/mlir/convert.cpp b/src/common/transformations/src/transformations/mlir/convert.cpp
index ee40c112b40e9a..63bafdc0d5fa1a 100644
--- a/src/common/transformations/src/transformations/mlir/convert.cpp
+++ b/src/common/transformations/src/transformations/mlir/convert.cpp
@@ -68,6 +68,7 @@
 #include "mlir_op.hpp"
 #include "op/matmul.hpp"
 #include "op/relu.hpp"
+#include "op/binary_eltwise.hpp"
 #include "openvino/core/dimension.hpp"
 #include "openvino/core/rt_info.hpp"
 #include "openvino/core/symbol.hpp"
@@ -107,30 +108,6 @@ SmallVector<mlir::Type> get_types_for_values(mlir::MLIRContext* context, const ov::OutputVector& values) {
     return types;
 }
 
-template <typename TargetOp>
-struct ConvertBinary {
-    void operator()(ConversionContext& context, NodePtr node) {
-        auto loc = createLocation(context.context, node);
-        auto& builder = context.builder();
-        // TODO: Support broadcasts
-        const auto inputs = context.getInputs(node);
-        auto outType = cast<mlir::RankedTensorType>(inputs[0].getType());
-        // Named binary ops directly overwrite data in `outs` buffer so, there is no need to provide non-empty
-        // destination at the tensor-level.
-        // Use `tensor.empty` to avoid temporary buffer allocation and memcpy after bufferization.
-        llvm::SmallVector<Value> dynamicSizes;
-        for (auto [idx, dim] : llvm::enumerate(outType.getShape())) {
-            if (!mlir::ShapedType::isDynamic(dim))
-                continue;
-            auto dimSize = builder.create<mlir::tensor::DimOp>(loc, inputs[0], idx);
-            dynamicSizes.push_back(dimSize);
-        }
-        auto empty = builder.create<mlir::tensor::EmptyOp>(loc, outType, dynamicSizes);
-        auto op = builder.create<TargetOp>(loc, mlir::ValueRange{inputs[0], inputs[1]}, mlir::ValueRange{empty});
-        context.addOutputs(node, op);
-    }
-};
-
 mlir::OwningOpRef<mlir::ModuleOp> ngraph_to_mlir(MLIRContext* context,
                                                  const ov::OutputVector& inputs,
@@ -159,6 +136,24 @@ mlir::OwningOpRef<mlir::ModuleOp> ngraph_to_mlir(MLIRContext* context,
         auto loc = createLocation(context, inputs[i].get_node_shared_ptr());
         auto tensor = block_builder.create<mlir::bufferization::ToTensorOp>(loc, funcInputVal, /*restrict = */ true);
         conversion_context.nodeOutputMap.emplace(inputs[i], tensor);
+
+        // FIXME: Avoid pre-population of dimension_map, take dimension values only if needed
+        auto input_shape = inputs[i].get_partial_shape();
+        auto input_rank = input_shape.rank();
+        if (input_rank.is_static()) {
+            for (size_t j = 0; j < input_rank.get_length(); ++j) {
+                auto dim = input_shape[j];
+                if (dim.is_dynamic()) {
+                    auto symbol = dim.get_symbol();
+                    assert(symbol);
+                    symbol = ov::symbol::ancestor_of(symbol);
+                    if (!conversion_context.dimension_map.count(symbol)) {
+                        auto dimSize = block_builder.create<mlir::tensor::DimOp>(loc, tensor, j);
+                        conversion_context.dimension_map[symbol] = dimSize;
+                    }
+                }
+            }
+        }
     }
 
     for (size_t i = 0; i < nodes.size(); ++i) {
@@ -276,21 +271,16 @@ class Partitioner : public ov::pass::ModelPass {
     }
 };
 
-template <typename OVOp>
-NodePtr elementwise_f32_binary_no_broadcast() {
-    using namespace ov::pass::pattern;
-    return wrap_type<OVOp>({any_input(), any_input()}, elementwise_no_broadcast_predicate<ov::element::Type_t::f32>);
-}
-
 void injectMLIR(std::shared_ptr<ov::Model> model, MLIRContext* context) {
     ov::pass::Manager manager;
     using namespace ov::op;
     manager.set_per_pass_validation(false);
     manager.register_pass<ov::pass::SymbolicPropagation>();
-    manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Add>(), ConvertBinary<linalg::AddOp>());
-    manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Subtract>(), ConvertBinary<linalg::SubOp>());
-    manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Multiply>(), ConvertBinary<linalg::MulOp>());
-    manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Divide>(), ConvertBinary<linalg::DivOp>());
+    manager.register_pass<BinaryEltwisePattern<v1::Add, linalg::AddOp>>(ov::element::f32);
+    manager.register_pass<BinaryEltwisePattern<v1::Subtract, linalg::SubOp>>(ov::element::f32);
+    manager.register_pass<BinaryEltwisePattern<v1::Multiply, linalg::MulOp>>(ov::element::f32);
+    manager.register_pass<BinaryEltwisePattern<v1::Divide, linalg::DivOp>>(ov::element::f32);
     manager.register_pass<ReluPattern>();
     manager.register_pass<MatMulPattern>();
    manager.register_pass<Partitioner>(context);
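With `BinaryEltwisePattern` parameterized by both the OpenVINO op and the linalg named op, covering another eltwise becomes a one-line registration. A hedged sketch (the `v1::Maximum`/`linalg::MaxOp` pairing is an assumption — it requires that named op in the MLIR build in use, and is not part of this change):

    // Hypothetical additional registrations in injectMLIR:
    manager.register_pass<BinaryEltwisePattern<v1::Maximum, linalg::MaxOp>>(ov::element::f32);
    // An empty element-type set lifts the f32 restriction entirely:
    manager.register_pass<BinaryEltwisePattern<v1::Add, linalg::AddOp>>();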
diff --git a/src/common/transformations/src/transformations/mlir/convert_common.cpp b/src/common/transformations/src/transformations/mlir/convert_common.cpp
index 6bca04c759a356..1499acb0ba844f 100644
--- a/src/common/transformations/src/transformations/mlir/convert_common.cpp
+++ b/src/common/transformations/src/transformations/mlir/convert_common.cpp
@@ -132,38 +132,128 @@ bool elementwise_no_broadcast_predicate_impl(const ov::Output<ov::Node>& output,
     if (output.get_element_type() != type) {
         return false;
     }
+    if (has_dynamic_rank(output.get_node_shared_ptr())) {
+        return false;
+    }
     // Check if implicit broadcast is possible, reject in this case
     // Relies on symbolic information -- register SymbolicPropagation before applying this pattern
     auto inputs = output.get_node_shared_ptr()->inputs();
     auto output_shape = output.get_partial_shape();
-    if (output_shape.rank().is_dynamic()) {
-        return false;
-    }
+
     if (std::any_of(inputs.begin(), inputs.end(), [&](const ov::Input<ov::Node>& input) {
             auto input_shape = input.get_partial_shape();
-            return input_shape.rank().is_dynamic() ||
-                   output_shape.rank().get_length() != input_shape.rank().get_length();
+            if (output_shape.rank().get_length() != input_shape.rank().get_length()) {
+                return true;
+            }
+            for (size_t i = 0; i < output_shape.size(); ++i) {
+                if (!are_equal_dimensions(input_shape[i], output_shape[i]))
+                    return true;
+            }
+            return false;
         })) {
         return false;
     }
+    return true;
+}
+
+bool has_dynamic_rank(NodePtr node) {
+    auto inputs = node->inputs();
+    auto outputs = node->outputs();
     if (std::any_of(inputs.begin(), inputs.end(), [&](const ov::Input<ov::Node>& input) {
-            for (size_t i = 0; i < output_shape.size(); ++i) {
-                auto input_shape = input.get_partial_shape();
-                if (output_shape[i] != input_shape[i])
-                    return true;
-                if (output_shape[i].is_static() && input_shape[i].is_static())
-                    continue;
-                if (!ov::symbol::are_equal(output_shape[i].get_symbol(), input_shape[i].get_symbol()))
-                    return true;
-            }
-            return false;
+            return input.get_partial_shape().rank().is_dynamic();
         })) {
+        return true;
+    }
+    if (std::any_of(outputs.begin(), outputs.end(), [&](const ov::Output<ov::Node>& output) {
+            return output.get_partial_shape().rank().is_dynamic();
+        })) {
+        return true;
+    }
+    return false;
+}
+
+bool are_equal_dimensions(Dimension d1, Dimension d2) {
+    return (d1.is_static() && d2.is_static() && d1 == d2) ||
+           ov::symbol::are_equal(d1.get_symbol(), d2.get_symbol());
+}
+
+bool has_broadcast(Dimension from, Dimension to) {
+    return from.is_static() && from.get_length() == 1 && !are_equal_dimensions(from, to);
+}
+
+bool statically_broadcastable(const PartialShape& from, const PartialShape& to) {
+    if (from.rank().is_dynamic() || to.rank().is_dynamic()) {  // FIXME: `from` alone may have dynamic rank
+        return false;
+    }
+
+    auto from_rank = from.rank().get_length();
+    auto to_rank = to.rank().get_length();
+
+    if (from_rank > to_rank) {  // such cases shouldn't reach this function, but handle them to keep it generic
         return false;
     }
+
+    auto offset = to_rank - from_rank;
+    for (size_t i = 0; i < from_rank; ++i) {
+        auto d_from = from[i];
+        auto d_to = to[offset + i];
+        if (!are_equal_dimensions(d_from, d_to) && !has_broadcast(d_from, d_to)) {
+            // can deduce neither dimension equality nor a broadcast
+            return false;
+        }
+    }
     return true;
 }
 
+BroadcastDimensions broadcast_dimensions(const PartialShape& src, const PartialShape& dst) {
+    assert(statically_broadcastable(src, dst));
+
+    auto src_rank = src.rank().get_length();
+    auto dst_rank = dst.rank().get_length();
+    auto offset = dst_rank - src_rank;
+
+    BroadcastDimensions result;
+    auto& [collapse_groups, dimensions] = result;
+    ReassociationIndices group;
+    bool group_bonded = false;  // true if `group` has a non-broadcasted dimension
+
+    size_t dst_i = 0;  // dimension index in the `dst` shape
+    for (; dst_i < offset; ++dst_i) {
+        dimensions.push_back(dst_i);
+    }
+    for (; dst_i < dst_rank; ++dst_i) {
+        auto src_i = dst_i - offset;
+        auto src_d = src[src_i];
+        auto dst_d = dst[dst_i];
+        if (has_broadcast(src_d, dst_d)) {
+            dimensions.push_back(dst_i);
+        } else {
+            if (group_bonded) {
+                collapse_groups.emplace_back(group);
+                group = ReassociationIndices();
+            } else {
+                group_bonded = true;
+            }
+        }
+        group.push_back(src_i);
+    }
+
+    if (group_bonded && !group.empty()) {
+        collapse_groups.emplace_back(group);
+    }
+
+    assert(dst_rank - dimensions.size() == collapse_groups.size());
+
+    return result;
+}
+
+bool symbol_ancestor_less(SymbolPtr x, SymbolPtr y) {
+    return ov::symbol::ancestor_of(x) < ov::symbol::ancestor_of(y);
+}
+
 } // namespace mlir
 } // namespace ov
\ No newline at end of file
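A worked example clarifies what `broadcast_dimensions` returns, since it packs a collapse plan and a broadcast plan into one pass. Take src = [C, 1] broadcast into dst = [N, C, H], with C symbolically equal on both sides (illustrative trace, not from the source):

    // offset = 3 - 2 = 1, so leading dst dim 0 is a pure broadcast: dimensions = {0}.
    // dst_i = 1: C matches C -> opens a collapse group: group = {0}.
    // dst_i = 2: the size-1 src dim broadcasts into H -> dimensions = {0, 2};
    //            the src dim is folded into the open group: group = {0, 1}.
    // Result: collapse_groups = {{0, 1}}, dimensions = {0, 2}, i.e.
    // tensor.collapse_shape squeezes tensor<Cx1xf32> to tensor<Cxf32>, then
    // linalg.broadcast expands it to tensor<NxCxHxf32>.
    // The final assert holds: dst_rank (3) - |dimensions| (2) == 1 collapse group.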
diff --git a/src/common/transformations/src/transformations/mlir/convert_common.hpp b/src/common/transformations/src/transformations/mlir/convert_common.hpp
index a33c99e6bedc57..6622ea5ed70c0c 100644
--- a/src/common/transformations/src/transformations/mlir/convert_common.hpp
+++ b/src/common/transformations/src/transformations/mlir/convert_common.hpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 #include "typedefs.hpp"
 
@@ -52,5 +53,18 @@ mlir::arith::ConstantOp getConstant(OpBuilder &builder, const ov::element::Type&
     return builder.create<mlir::arith::ConstantOp>(unkLoc, type, attr);
 }
 
+bool has_dynamic_rank(NodePtr node);
+
+bool are_equal_dimensions(Dimension d1, Dimension d2);
+
+bool has_broadcast(Dimension from, Dimension to);
+
+bool statically_broadcastable(const PartialShape& from, const PartialShape& to);
+
+using BroadcastDimensions = std::tuple<SmallVector<ReassociationIndices>, SmallVector<int64_t>>;
+BroadcastDimensions broadcast_dimensions(const PartialShape& from, const PartialShape& to);
+
+bool symbol_ancestor_less(SymbolPtr x, SymbolPtr y);
+
 } // namespace mlir
 } // namespace ov
\ No newline at end of file
diff --git a/src/common/transformations/src/transformations/mlir/op/binary_eltwise.cpp b/src/common/transformations/src/transformations/mlir/op/binary_eltwise.cpp
new file mode 100644
index 00000000000000..c840c720b4f497
--- /dev/null
+++ b/src/common/transformations/src/transformations/mlir/op/binary_eltwise.cpp
@@ -0,0 +1,91 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+
+#include "binary_eltwise.hpp"
+
+namespace {
+
+using namespace ov;
+using namespace ov::mlir;
+using ::mlir::ValueRange;
+
+class ConvertBinaryEltwise {
+
+    BinaryEltwisePatternBase::Builder m_op_builder;
+
+public:
+
+    ConvertBinaryEltwise(BinaryEltwisePatternBase::Builder op_builder) : m_op_builder(op_builder) {}
+
+    void operator()(ConversionContext& context, NodePtr node) {
+        auto loc = createLocation(context.context, node);
+        auto& builder = context.builder();
+        const auto inputs = context.getInputs(node);
+        const auto ov_output_element_type = node->get_output_element_type(0);
+        const auto ov_output_shape = node->get_output_partial_shape(0);
+        auto outType = importTensor(context.context, ov_output_shape, ov_output_element_type);
+        const int output_rank = ov_output_shape.rank().get_length();
+
+        SmallVector<Value> dynamic_dimensions = context.get_dynamic_dimension_values(ov_output_shape);
+
+        SmallVector<Value> broadcasted_inputs;
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            auto [collapse_groups, dimensions] = broadcast_dimensions(node->get_input_partial_shape(i), ov_output_shape);
+            if (!dimensions.empty()) {
+                // FIXME: Find a way to avoid dimension squeezing before applying linalg.broadcast
+                // Step 1: Squeeze the input shape to eliminate broadcasted dimensions
+                auto squeezed = builder.create<::mlir::tensor::CollapseShapeOp>(loc, inputs[i], collapse_groups);
+                // Step 2: Broadcast the squeezed shape to the target shape
+                auto empty = builder.create<::mlir::tensor::EmptyOp>(loc, outType, dynamic_dimensions);
+                auto op = builder.create<::mlir::linalg::BroadcastOp>(loc, squeezed, empty, dimensions);
+                broadcasted_inputs.push_back(op.getResult()[0]);
+            } else {
+                broadcasted_inputs.push_back(inputs[i]);
+            }
+        }
+
+        auto empty = builder.create<::mlir::tensor::EmptyOp>(loc, outType, dynamic_dimensions);
+        auto op = m_op_builder(builder, loc, ValueRange(broadcasted_inputs), ValueRange{empty});
+        context.addOutputs(node, op);
+    }
+};
+
+} // namespace
+
+namespace ov {
+namespace mlir {
+
+using namespace ov::pass::pattern;
+
+BinaryEltwisePatternBase::BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder, const std::set<element::Type>& element_types)
+    : MarkPattern(
+          std::make_shared<ov::pass::pattern::op::WrapType>(
+              wrapped_type,
+              [element_types](const Output<Node>& output) {
+                  if (!element_types.empty() && !element_types.count(output.get_element_type())) {
+                      return false;
+                  }
+                  auto node = output.get_node_shared_ptr();
+                  for (const auto& input : node->inputs()) {
+                      if (!statically_broadcastable(input.get_partial_shape(), output.get_partial_shape())) {
+                          return false;
+                      }
+                  }
+                  return true;
+              },
+              OutputVector{any_input(), any_input()}),
+          ConvertBinaryEltwise(op_builder))
+    {}
+
+} // namespace mlir
+} // namespace ov
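The `Builder` indirection above also admits lowerings with no single linalg counterpart. A sketch under assumptions (a `v0::SquaredDifference` lowering is hypothetical and not part of this change; `linalg::SubOp`/`linalg::MulOp` are the same named-op builders this patch already registers):

    BinaryEltwisePatternBase::Builder squared_diff_builder =
        [](OpBuilder& builder, ::mlir::Location loc, ValueRange ins, ValueRange outs) -> Operation* {
            // a - b is written into the shared destination, then squared in place
            auto diff = builder.create<::mlir::linalg::SubOp>(loc, ins, outs);
            ::mlir::Value d = diff->getResult(0);
            return builder.create<::mlir::linalg::MulOp>(loc, ValueRange{d, d}, outs);
        };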
diff --git a/src/common/transformations/src/transformations/mlir/op/binary_eltwise.hpp b/src/common/transformations/src/transformations/mlir/op/binary_eltwise.hpp
new file mode 100644
index 00000000000000..1c410608cbc4a5
--- /dev/null
+++ b/src/common/transformations/src/transformations/mlir/op/binary_eltwise.hpp
@@ -0,0 +1,42 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "../conversion_context.hpp"
+
+namespace ov {
+namespace mlir {
+
+class BinaryEltwisePatternBase : public MarkPattern {
+public:
+    using Builder = std::function<Operation*(OpBuilder&, ::mlir::Location, ValueRange, ValueRange)>;
+
+    OPENVINO_RTTI("BinaryEltwisePatternBase", "0");
+    BinaryEltwisePatternBase(NodeTypeInfo wrapped_type, Builder op_builder, const std::set<element::Type>& element_types = {});
+};
+
+
+template <typename OVOp, typename LinalgOp>
+class BinaryEltwisePattern : public BinaryEltwisePatternBase {
+public:
+    // Allow conversion only for the given `element_types`; an empty set means
+    // no restriction on element types, everything is allowed.
+    BinaryEltwisePattern(const std::set<element::Type>& element_types = {}) :
+        BinaryEltwisePatternBase(
+            OVOp::get_type_info_static(),
+            [](OpBuilder& builder, ::mlir::Location loc, ValueRange ins, ValueRange outs) -> Operation* {
+                return builder.create<LinalgOp>(loc, ins, outs);
+            },
+            element_types)
+    {}
+
+    BinaryEltwisePattern(const element::Type& element_type) :
+        BinaryEltwisePattern(std::set<element::Type>{element_type})
+    {}
+};
+
+
+} // namespace mlir
+} // namespace ov
+
diff --git a/src/common/transformations/src/transformations/mlir/op/matmul.cpp b/src/common/transformations/src/transformations/mlir/op/matmul.cpp
index f125e40ed409b6..229b4734358cfd 100644
--- a/src/common/transformations/src/transformations/mlir/op/matmul.cpp
+++ b/src/common/transformations/src/transformations/mlir/op/matmul.cpp
@@ -20,6 +20,7 @@ struct ConvertMatMul {
     void operator()(ConversionContext& context, NodePtr node) {
         auto matmul_node = std::dynamic_pointer_cast<v0::MatMul>(node);
         assert(matmul_node);
+        // FIXME: current code limitation
         assert(!matmul_node->get_transpose_a() && matmul_node->get_transpose_b());
 
@@ -29,20 +30,9 @@ struct ConvertMatMul {
         const auto inputs = context.getInputs(node);
         const auto ov_output_element_type = node->get_output_element_type(0);
         const auto ov_output_shape = node->get_output_partial_shape(0);
-        auto outType = importTensor(context.context, ov_output_shape, ov_output_element_type);  // Instead of this (WRONG): cast<mlir::RankedTensorType>(inputs[0].getType());
-
-        llvm::SmallVector<Value> dynamicSizes;
-        for (auto [idx, dim] : llvm::enumerate(outType.getShape())) {
-            if (!mlir::ShapedType::isDynamic(dim))
-                continue;
-            // FIXME: correct in case if (!transpose_a && transpose_b)
-            auto dimSize =
-                builder.create<mlir::tensor::DimOp>(loc,
-                                                    idx == 0 ? inputs[0] : inputs[1],
-                                                    0);  // TODO: Use symbols instead of taking dims directly from inputs
-            dynamicSizes.push_back(dimSize);
-        }
-        auto empty = builder.create<mlir::tensor::EmptyOp>(loc, outType, dynamicSizes);
+        auto outType = importTensor(context.context, ov_output_shape, ov_output_element_type);
+        auto dynamic_dimensions = context.get_dynamic_dimension_values(ov_output_shape);
+        auto empty = builder.create<mlir::tensor::EmptyOp>(loc, outType, dynamic_dimensions);
         auto zero = getConstant(builder, ov_output_element_type, 0);
         auto fill = builder.create<mlir::linalg::FillOp>(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty});
         // TODO: Add other variants of transpose_a/transpose_b
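The MatMul change above is where symbol-based lookup pays off: with transpose_b = true the output [M, N] draws M from input 0 and N from input 1, an association the removed loop hard-coded by index. An illustrative expansion of what the helper resolves in that case (names follow the surrounding converter):

    // MatMul(A[M, K], B[N, K], transpose_b = true) -> C[M, N]
    // Dynamic M and N are found via their symbol ancestors, not via a fixed input index:
    auto m = context.get_dimension_value(ov_output_shape[0]);  // shares A's dim-0 symbol
    auto n = context.get_dimension_value(ov_output_shape[1]);  // shares B's dim-0 symbol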
diff --git a/src/common/transformations/src/transformations/mlir/op/relu.cpp b/src/common/transformations/src/transformations/mlir/op/relu.cpp
index a25f571f61cddf..116fd24de7229a 100644
--- a/src/common/transformations/src/transformations/mlir/op/relu.cpp
+++ b/src/common/transformations/src/transformations/mlir/op/relu.cpp
@@ -19,22 +19,12 @@ struct ConvertRelu {
     void operator()(ConversionContext& context, NodePtr node) {
         auto loc = createLocation(context.context, node);
         auto& builder = context.builder();
-        // TODO: Support broadcasts
         const auto input = context.getInputs(node)[0];
         const auto ov_output_element_type = node->get_output_element_type(0);
         const auto ov_output_shape = node->get_output_partial_shape(0);
         auto outType = importTensor(context.context, ov_output_shape, ov_output_element_type);
-        // Named unary ops directly overwrite data in `outs` buffer so, there is no need to provide non-empty
-        // destination at the tensor-level.
-        // Use `tensor.empty` to avoid temporary buffer allocation and memcpy after bufferization.
-        llvm::SmallVector<Value> dynamicSizes;
-        for (auto [idx, dim] : llvm::enumerate(outType.getShape())) {
-            if (!mlir::ShapedType::isDynamic(dim))
-                continue;
-            auto dimSize = builder.create<mlir::tensor::DimOp>(loc, input, idx);
-            dynamicSizes.push_back(dimSize);
-        }
-        auto empty = builder.create<mlir::tensor::EmptyOp>(loc, outType, dynamicSizes);
+        auto dynamic_dimensions = context.get_dynamic_dimension_values(ov_output_shape);
+        auto empty = builder.create<mlir::tensor::EmptyOp>(loc, outType, dynamic_dimensions);
         auto zero = getConstant(builder, ov_output_element_type, 0);
         auto fill = builder.create<mlir::linalg::FillOp>(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty});
         auto relu =