Simpler broadcast dims calculations, moved to common utils.
slyalin committed Jul 25, 2024
1 parent b8d50cf commit 13b7952
Showing 5 changed files with 51 additions and 54 deletions.
@@ -68,6 +68,15 @@ Value ConversionContext::get_dimension_value(const Dimension& d) {
     return dimension_map.at(symbol);
 }
 
+SmallVector<Value> ConversionContext::get_dynamic_dimension_values (const PartialShape& shape) {
+    SmallVector<Value> dims;
+    for (const auto& dim: shape) {
+        if (dim.is_dynamic()) {
+            dims.push_back(get_dimension_value(dim));
+        }
+    }
+    return dims;
+}
 
 
 const std::string& subgraph_mark() {
@@ -50,6 +50,8 @@ class ConversionContext {
     void convert(NodePtr node);
 
     Value get_dimension_value(const Dimension& d);
+
+    SmallVector<Value> get_dynamic_dimension_values (const PartialShape& shape);
 };
 
 
@@ -209,28 +209,46 @@ bool statically_broadcastable(const PartialShape& from, const PartialShape& to)
     return true;
 }
 
-std::vector<int64_t> broadcast_dimensions(const PartialShape& from, const PartialShape& to) {
-    assert(statically_broadcastable(from, to));
-
-    auto from_rank = from.rank().get_length();
-    auto to_rank = to.rank().get_length();
-
-    auto offset = to_rank - from_rank;
-    std::vector<int64_t> dimensions;
-
-    for(size_t i = 0; i < to_rank; ++i) {
-        if (i < offset) {
-            dimensions.push_back(i);
-        } else {
-            auto d_from = from[i - offset];
-            auto d_to = to[i];
-            if(has_broadcast(d_from, d_to)) {
-                dimensions.push_back(i);
-            }
-        }
-    }
-
-    return dimensions;
-}
+BroadcastDimensions broadcast_dimensions(const PartialShape& src, const PartialShape& dst) {
+    assert(statically_broadcastable(src, dst));
+
+    auto src_rank = src.rank().get_length();
+    auto dst_rank = dst.rank().get_length();
+    auto offset = dst_rank - src_rank;
+
+    BroadcastDimensions result;
+    auto& [collapse_groups, dimensions] = result;
+    ReassociationIndices group;
+    bool group_bonded = false;  // true if `group` has a non-broadcasted dimension
+
+    size_t dst_i = 0;  // dimension index in the `dst` shape
+    for(; dst_i < offset; ++dst_i) {
+        dimensions.push_back(dst_i);
+    }
+    for(; dst_i < dst_rank; ++dst_i) {
+        auto src_i = dst_i - offset;
+        auto src_d = src[src_i];
+        auto dst_d = dst[dst_i];
+        if(has_broadcast(src_d, dst_d)) {
+            dimensions.push_back(dst_i);
+        } else {
+            if(group_bonded) {
+                collapse_groups.emplace_back(group);
+                group = ReassociationIndices();
+            } else {
+                group_bonded = true;
+            }
+        }
+        group.push_back(src_i);
+    }
+
+    if(group_bonded && !group.empty()) {
+        collapse_groups.emplace_back(group);
+    }
+
+    assert(dst_rank - dimensions.size() == collapse_groups.size());
+
+    return result;
+}
 
 bool symbol_ancestor_less (SymbolPtr x, SymbolPtr y) {
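
Note: the rewritten broadcast_dimensions returns two things at once: the reassociation groups that collapse away the source dimensions being broadcast, and the destination dimensions that linalg.broadcast must then add back. Below is a minimal standalone sketch of the same grouping logic, assuming has_broadcast reduces to "source extent is 1, destination extent is not" for static dims, and using plain integer vectors in place of PartialShape; the names and types here are illustrative, not the repository's API.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

using Group = std::vector<int64_t>;

// Assumed semantics: a source dim is broadcast when it is 1 and the target is not.
static bool has_broadcast(int64_t src_d, int64_t dst_d) {
    return src_d == 1 && dst_d != 1;
}

static std::tuple<std::vector<Group>, std::vector<int64_t>>
broadcast_dimensions(const std::vector<int64_t>& src, const std::vector<int64_t>& dst) {
    // Precondition (statically_broadcastable): dst rank >= src rank.
    const size_t offset = dst.size() - src.size();
    std::vector<Group> collapse_groups;  // reassociation groups for tensor.collapse_shape
    std::vector<int64_t> dimensions;     // dst dims for linalg.broadcast to create
    Group group;
    bool group_bonded = false;           // `group` holds a non-broadcast dimension

    size_t dst_i = 0;
    for (; dst_i < offset; ++dst_i) {    // leading dims that src lacks entirely
        dimensions.push_back(dst_i);
    }
    for (; dst_i < dst.size(); ++dst_i) {
        const size_t src_i = dst_i - offset;
        if (has_broadcast(src[src_i], dst[dst_i])) {
            dimensions.push_back(dst_i);      // broadcast dim: fold into the open group
        } else if (group_bonded) {
            collapse_groups.push_back(group); // flush the finished group
            group.clear();
        } else {
            group_bonded = true;              // first kept dim anchors the group
        }
        group.push_back(src_i);
    }
    if (group_bonded && !group.empty()) {
        collapse_groups.push_back(group);
    }
    assert(dst.size() - dimensions.size() == collapse_groups.size());
    return {collapse_groups, dimensions};
}

int main() {
    // src [3,1,5] -> dst [2,3,4,5]: expect groups {{0,1},{2}} and dims [0,2].
    auto [groups, dims] = broadcast_dimensions({3, 1, 5}, {2, 3, 4, 5});
    for (const auto& g : groups) {
        std::cout << "group:";
        for (auto i : g) std::cout << ' ' << i;
        std::cout << '\n';
    }
    std::cout << "broadcast dims:";
    for (auto d : dims) std::cout << ' ' << d;
    std::cout << '\n';
}

Pushing every source index into the currently open group is what merges each broadcast dimension into the group of a neighbouring kept dimension, so the collapse removes it before linalg.broadcast reintroduces it at the destination position.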
@@ -9,6 +9,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 #include "typedefs.hpp"
 
@@ -60,7 +61,8 @@ bool has_broadcast(Dimension from, Dimension to);
 
 bool statically_broadcastable(const PartialShape& from, const PartialShape& to);
 
-std::vector<int64_t> broadcast_dimensions(const PartialShape& from, const PartialShape& to);
+using BroadcastDimensions = std::tuple<SmallVector<ReassociationIndices>, SmallVector<int64_t>>;
+BroadcastDimensions broadcast_dimensions(const PartialShape& from, const PartialShape& to);
 
 bool symbol_ancestor_less (SymbolPtr x, SymbolPtr y);
 
@@ -36,59 +36,25 @@ class ConvertBinaryEltwise {
         auto outType = importTensor(context.context, ov_output_shape, ov_output_element_type);
         const int output_rank = ov_output_shape.rank().get_length();
 
-        SmallVector<Value> dynamicSizes;
-        for (auto [idx, dim] : llvm::enumerate(ov_output_shape)) {
-            if (!dim.is_dynamic())
-                continue;
-            dynamicSizes.push_back(context.get_dimension_value(dim));
-        }
+        SmallVector<Value> dynamic_dimensions = context.get_dynamic_dimension_values(ov_output_shape);
 
         SmallVector<Value> broadcasted_inputs;
         for(size_t i = 0; i < inputs.size(); ++i) {
-            auto dimensions = broadcast_dimensions(node->get_input_partial_shape(i), ov_output_shape);
+            auto [collapse_groups, dimensions] = broadcast_dimensions(node->get_input_partial_shape(i), ov_output_shape);
             if(!dimensions.empty()) {
                 // FIXME: Find a way to avoid dimension squeezing before applying linalg.broadcast
-
                 // Step 1: Squeeze input shape to eliminate broadcasted dimensions
-                SmallVector<ReassociationIndices, 6> squeeze_map;
-                ReassociationIndices ri_cur;
-                size_t output_idx = 0; // index in ov_output_shape
-                bool group_open = true;
-                for(auto [_, dim]: llvm::enumerate(dimensions)) {
-                    for(; output_idx < dim; ++output_idx) {
-                        if(!ri_cur.empty() && !group_open) {
-                            squeeze_map.emplace_back(ri_cur);
-                            ri_cur = ReassociationIndices();
-                        }
-                        ri_cur.push_back(output_idx);
-                        group_open = false;
-                    }
-                    assert(dim == output_idx);
-                    ri_cur.push_back(dim);
-                    ++output_idx;
-                }
-                for(; output_idx < output_rank; ++output_idx) {
-                    if(group_open) {
-                        ri_cur.push_back(output_idx);
-                        squeeze_map.push_back(ri_cur);
-                        group_open = false;
-                    } else {
-                        squeeze_map.push_back({output_idx});
-                    }
-                }
-
-                auto squeezed = builder.create<tensor::CollapseShapeOp>(loc, inputs[i], squeeze_map);
-
+                auto squeezed = builder.create<tensor::CollapseShapeOp>(loc, inputs[i], collapse_groups);
                 // Step 2: Broadcast squeezed shape to the target shape
-                auto empty = builder.create<tensor::EmptyOp>(loc, outType, dynamicSizes);
+                auto empty = builder.create<tensor::EmptyOp>(loc, outType, dynamic_dimensions);
                 auto op = builder.create<linalg::BroadcastOp>(loc, squeezed, empty, dimensions);
                 broadcasted_inputs.push_back(op.getResult()[0]);
             } else {
                 broadcasted_inputs.push_back(inputs[i]);
            }
         }
 
-        auto empty = builder.create<tensor::EmptyOp>(loc, outType, dynamicSizes);
+        auto empty = builder.create<tensor::EmptyOp>(loc, outType, dynamic_dimensions);
         auto op = m_op_builder(builder, loc, ValueRange(broadcasted_inputs), ValueRange{empty});
         context.addOutputs(node, op);
     }
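
Note: for intuition, with an input of shape [3, 1, 5] broadcast to an output of [2, 3, 4, 5], broadcast_dimensions yields collapse groups {{0, 1}, {2}} and broadcast dims [0, 2], and the two steps above would emit IR roughly like the following (static shapes, f32 element type, and value names are illustrative; a dynamic output shape would pass the dynamic_dimensions values as operands to tensor.empty):

%collapsed = tensor.collapse_shape %input [[0, 1], [2]]
    : tensor<3x1x5xf32> into tensor<3x5xf32>
%init = tensor.empty() : tensor<2x3x4x5xf32>
%bcast = linalg.broadcast
    ins(%collapsed : tensor<3x5xf32>)
    outs(%init : tensor<2x3x4x5xf32>)
    dimensions = [0, 2]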
