Skip to content

Commit

Permalink
Broadcast support for elementwise ops (#148)
Browse files Browse the repository at this point in the history
* Broadcast support for element-wise ops and more economical way of dynamic dimensions handling based on symbols.

* Simpler broadcast dims cacluations, moved to common utils.

* Use common function to compute dynamic dimension values in MatMul and Relu.

* Element type configurable restriction for the new BinaryEltwisePattern. Forced f32 in the conversion pipeline.
  • Loading branch information
slyalin authored Jul 25, 2024
1 parent 1c34cbf commit f0f76b9
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,25 @@ void ConversionContext::set_convertor(NodePtr node, const Convertor& convertor)
node->get_rt_info()[rt_info_convertor()] = as_any;
}

Value ConversionContext::get_dimension_value(const Dimension& d) {
auto symbol = d.get_symbol();
assert(symbol);
symbol = ov::symbol::ancestor_of(symbol);
// Suppose all dimensions are known and the map is populated
// FIXME: Add dimensions on demand to avoid unnecessary operations in the produced MLIR
assert(dimension_map.count(symbol));
return dimension_map.at(symbol);
}

SmallVector<Value> ConversionContext::get_dynamic_dimension_values (const PartialShape& shape) {
SmallVector<Value> dims;
for (const auto& dim: shape) {
if (dim.is_dynamic()) {
dims.push_back(get_dimension_value(dim));
}
}
return dims;
}


const std::string& subgraph_mark() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/IR/Builders.h"

#include "typedefs.hpp"
#include "convert_common.hpp"

namespace ov {
namespace mlir {
Expand All @@ -20,6 +21,7 @@ using ::mlir::MLIRContext;
using ::mlir::OpBuilder;
using ::mlir::Operation;
using ::mlir::SmallVector;
using ::mlir::ValueRange;

class ConversionContext {
static std::string rt_info_convertor ();
Expand All @@ -32,6 +34,7 @@ class ConversionContext {
mlir::MLIRContext* context;
mlir::OpBuilder* block_builder;
NodeOutputMap nodeOutputMap;
std::map<SymbolPtr, Value> dimension_map;

ConversionContext(mlir::MLIRContext* context, mlir::OpBuilder* block_builder);

Expand All @@ -45,6 +48,10 @@ class ConversionContext {
static void set_convertor(NodePtr node, const Convertor& convertor);

void convert(NodePtr node);

Value get_dimension_value(const Dimension& d);

SmallVector<Value> get_dynamic_dimension_values (const PartialShape& shape);
};


Expand Down
56 changes: 23 additions & 33 deletions src/common/transformations/src/transformations/mlir/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
#include "mlir_op.hpp"
#include "op/matmul.hpp"
#include "op/relu.hpp"
#include "op/binary_eltwise.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/symbol.hpp"
Expand Down Expand Up @@ -107,30 +108,6 @@ SmallVector<mlir::Type> get_types_for_values(mlir::MLIRContext* context, const o
return types;
}

template <typename TargetOp>
struct ConvertBinary {
void operator()(ConversionContext& context, NodePtr node) {
auto loc = createLocation(context.context, node);
auto& builder = context.builder();
// TODO: Support broadcasts
const auto inputs = context.getInputs(node);
auto outType = cast<mlir::ShapedType>(inputs[0].getType());
// Named binary ops directly overwrite data in `outs` buffer so, there is no need to provide non-empty
// destination at the tensor-level.
// Use `tensor.empty` to avoid temporary buffer allocation and memcpy after bufferization.
llvm::SmallVector<Value> dynamicSizes;
for (auto [idx, dim] : llvm::enumerate(outType.getShape())) {
if (!mlir::ShapedType::isDynamic(dim))
continue;
auto dimSize = builder.create<tensor::DimOp>(loc, inputs[0], idx);
dynamicSizes.push_back(dimSize);
}
auto empty = builder.create<tensor::EmptyOp>(loc, outType, dynamicSizes);
auto op = builder.create<TargetOp>(loc, mlir::ValueRange{inputs[0], inputs[1]}, mlir::ValueRange{empty});
context.addOutputs(node, op);
}
};


mlir::OwningOpRef<mlir::ModuleOp> ngraph_to_mlir(MLIRContext* context,
const ov::OutputVector& inputs,
Expand Down Expand Up @@ -159,6 +136,24 @@ mlir::OwningOpRef<mlir::ModuleOp> ngraph_to_mlir(MLIRContext* context,
auto loc = createLocation(context, inputs[i].get_node_shared_ptr());
auto tensor = block_builder.create<bufferization::ToTensorOp>(loc, funcInputVal, /*restrict = */ true);
conversion_context.nodeOutputMap.emplace(inputs[i], tensor);

// FIXME: Avoid pre-population of dimension_map, take dimension values only if needed
auto input_shape = inputs[i].get_partial_shape();
auto input_rank = input_shape.rank();
if(input_rank.is_static()) {
for(size_t j = 0; j < input_rank.get_length(); ++j) {
auto dim = input_shape[j];
if(dim.is_dynamic()) {
auto symbol = dim.get_symbol();
assert(symbol);
symbol = ov::symbol::ancestor_of(symbol);
if(dim.is_dynamic() && !conversion_context.dimension_map.count(symbol)) {
auto dimSize = block_builder.create<tensor::DimOp>(loc, tensor, j);
conversion_context.dimension_map[symbol] = dimSize;
}
}
}
}
}

for (size_t i = 0; i < nodes.size(); ++i) {
Expand Down Expand Up @@ -276,21 +271,16 @@ class Partitioner : public ov::pass::ModelPass {
}
};

template <typename Op>
NodePtr elementwise_f32_binary_no_broadcast() {
using namespace ov::pass::pattern;
return wrap_type<Op>({any_input(), any_input()}, elementwise_no_broadcast_predicate<ov::element::f32>);
}

void injectMLIR(std::shared_ptr<ov::Model> model, MLIRContext* context) {
ov::pass::Manager manager;
using namespace ov::op;
manager.set_per_pass_validation(false);
manager.register_pass<ov::pass::SymbolicPropagation>();
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Add>(), ConvertBinary<linalg::AddOp>());
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Subtract>(), ConvertBinary<linalg::SubOp>());
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Multiply>(), ConvertBinary<linalg::MulOp>());
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Divide>(), ConvertBinary<linalg::DivOp>());
manager.register_pass<BinaryEltwisePattern<v1::Add, linalg::AddOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Subtract, linalg::SubOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Multiply, linalg::MulOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Divide, linalg::DivOp>>(ov::element::f32);
manager.register_pass<ReluPattern>();
manager.register_pass<MatMulPattern>();
manager.register_pass<Partitioner>(context);
Expand Down
120 changes: 105 additions & 15 deletions src/common/transformations/src/transformations/mlir/convert_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,38 +132,128 @@ bool elementwise_no_broadcast_predicate_impl(const ov::Output<ov::Node>& output,
if (output.get_element_type() != type) {
return false;
}
if (has_dynamic_rank(output.get_node_shared_ptr())) {
return false;
}
// Check if implicit broadcast is possible, reject in this case
// Relies on symbolic information -- register SymbolicPropagation before applying this pattern
auto inputs = output.get_node_shared_ptr()->inputs();
auto output_shape = output.get_partial_shape();
if (output_shape.rank().is_dynamic()) {
return false;
}

if (std::any_of(inputs.begin(), inputs.end(), [&](const ov::Input<ov::Node>& input) {
auto input_shape = input.get_partial_shape();
return input_shape.rank().is_dynamic() ||
output_shape.rank().get_length() != input_shape.rank().get_length();
if(output_shape.rank().get_length() != input_shape.rank().get_length()) {
return true;
}
for (size_t i = 0; i < output_shape.size(); ++i) {
if(!are_equal_dimensions(input_shape[i], output_shape[i]))
return true;
}
return false;
})) {
return false;
}

return true;
}

bool has_dynamic_rank(NodePtr node) {
auto inputs = node->inputs();
auto outputs = node->outputs();
if (std::any_of(inputs.begin(), inputs.end(), [&](const ov::Input<ov::Node>& input) {
for (size_t i = 0; i < output_shape.size(); ++i) {
auto input_shape = input.get_partial_shape();
if (output_shape[i] != input_shape[i])
return true;
if (output_shape[i].is_static() && input_shape[i].is_static())
continue;
if (!ov::symbol::are_equal(output_shape[i].get_symbol(), input_shape[i].get_symbol()))
return true;
}
return false;
return input.get_partial_shape().rank().is_dynamic();
})) {
return true;
}
if (std::any_of(outputs.begin(), outputs.end(), [&](const ov::Output<ov::Node>& output) {
return output.get_partial_shape().rank().is_dynamic();
})) {
return true;
}
return false;
}

bool are_equal_dimensions(Dimension d1, Dimension d2) {
return
d1.is_static() && d2.is_static() && d1 == d2
||
ov::symbol::are_equal(d1.get_symbol(), d2.get_symbol());
}

bool has_broadcast(Dimension from, Dimension to) {
return from.is_static() && from.get_length() == 1 && !are_equal_dimensions(from, to);
}

bool statically_broadcastable(const PartialShape& from, const PartialShape& to) {
if(from.rank().is_dynamic() || to.rank().is_dynamic()) { // FIXME: `from` can has dynamic rank
return false;
}

auto from_rank = from.rank().get_length();
auto to_rank = to.rank().get_length();

if(from_rank > to_rank) { // such cases shouldn't be allowed to this function, but kept to make the function generic
return false;
}

auto offset = to_rank - from_rank;
for(size_t i = 0; i < from_rank; ++i) {
auto d_from = from[i];
auto d_to = to[offset + i];
if(!are_equal_dimensions(d_from, d_to) && !has_broadcast(d_from, d_to)) {
// cannot deduce neither dimensions broadcast nor dimensions equality
return false;
}
}

return true;
}

BroadcastDimensions broadcast_dimensions(const PartialShape& src, const PartialShape& dst) {
assert(statically_broadcastable(src, dst));

auto src_rank = src.rank().get_length();
auto dst_rank = dst.rank().get_length();
auto offset = dst_rank - src_rank;

BroadcastDimensions result;
auto& [collapse_groups, dimensions] = result;
ReassociationIndices group;
bool group_bonded = false; // true if `group` has a non-brodcasted dimension

size_t dst_i = 0; // dimension index in the `dst` shape
for(; dst_i < offset; ++dst_i) {
dimensions.push_back(dst_i);
}
for(; dst_i < dst_rank; ++dst_i) {
auto src_i = dst_i - offset;
auto src_d = src[src_i];
auto dst_d = dst[dst_i];
if(has_broadcast(src_d, dst_d)) {
dimensions.push_back(dst_i);
} else {
if(group_bonded) {
collapse_groups.emplace_back(group);
group = ReassociationIndices();
} else {
group_bonded = true;
}
}
group.push_back(src_i);
}

if(group_bonded && !group.empty()) {
collapse_groups.emplace_back(group);
}

assert(dst_rank - dimensions.size() == collapse_groups.size());

return result;
}

bool symbol_ancestor_less (SymbolPtr x, SymbolPtr y) {
return ov::symbol::ancestor_of(x) < ov::symbol::ancestor_of(y);
}

} // namespace mlir
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Location.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "typedefs.hpp"

Expand Down Expand Up @@ -52,5 +53,18 @@ mlir::arith::ConstantOp getConstant(OpBuilder &builder, const ov::element::Type&
return builder.create<arith::ConstantOp>(unkLoc, type, attr);
}

bool has_dynamic_rank(NodePtr node);

bool are_equal_dimensions(Dimension d1, Dimension d2);

bool has_broadcast(Dimension from, Dimension to);

bool statically_broadcastable(const PartialShape& from, const PartialShape& to);

using BroadcastDimensions = std::tuple<SmallVector<ReassociationIndices>, SmallVector<int64_t>>;
BroadcastDimensions broadcast_dimensions(const PartialShape& from, const PartialShape& to);

bool symbol_ancestor_less (SymbolPtr x, SymbolPtr y);

} // namespace mlir
} // namespace ov
Loading

0 comments on commit f0f76b9

Please sign in to comment.