From 86fd8c4d1061c0012fc79bb681d8d56930a454fa Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 26 Jul 2024 12:53:51 +0200 Subject: [PATCH 1/2] More matmul variants --- .../src/transformations/mlir/op/matmul.cpp | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/common/transformations/src/transformations/mlir/op/matmul.cpp b/src/common/transformations/src/transformations/mlir/op/matmul.cpp index 635779c0bf52fe..b447a8a40e70d5 100644 --- a/src/common/transformations/src/transformations/mlir/op/matmul.cpp +++ b/src/common/transformations/src/transformations/mlir/op/matmul.cpp @@ -29,8 +29,24 @@ struct ConvertMatMul { auto empty = builder.create(loc, outType, dynamic_dimensions); auto zero = getConstant(builder, ov_output_element_type, 0); auto fill = builder.create(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty}); - // TODO: Add other variants of transpose_a/transpose_b - auto matmul = builder.create(loc, mlir::ValueRange{inputs[0], inputs[1]}, mlir::ValueRange{fill.getResult(0)}); + + mlir::ValueRange ins{inputs[0], inputs[1]}; + mlir::ValueRange outs{fill.getResult(0)}; + + auto matmul_node = std::dynamic_pointer_cast(node); + assert(matmul_node); + bool isTransposedA = matmul_node->get_transpose_a(); + bool isTransposedB = matmul_node->get_transpose_b(); + assert(!(isTransposedA && isTransposedB)); + Operation* matmul; + if (isTransposedA) { + matmul = builder.create(loc, ins, outs); + } else if (isTransposedB) { + matmul = builder.create(loc, ins, outs); + } else { + matmul = builder.create(loc, ins, outs); + } + context.addOutputs(node, matmul); } }; @@ -48,11 +64,9 @@ MatMulPattern::MatMulPattern() : MarkPattern( auto node = std::dynamic_pointer_cast(output.get_node_shared_ptr()); assert(node); // FIXME: current code limitation - return - !has_dynamic_rank(node) && - !node->get_transpose_a() && node->get_transpose_b() && - node->get_input_partial_shape(0).rank().get_length() == 2 && - node->get_input_partial_shape(1).rank().get_length() == 2; + return !has_dynamic_rank(node) && !(node->get_transpose_a() && node->get_transpose_b()) && + node->get_input_partial_shape(0).rank().get_length() == 2 && + node->get_input_partial_shape(1).rank().get_length() == 2; }), ConvertMatMul()) { } From fda0ef6ba1527b4a29928521d750ce57f2a91965 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 26 Jul 2024 13:27:16 +0200 Subject: [PATCH 2/2] Fixes --- .../transformations/src/transformations/mlir/op/matmul.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/common/transformations/src/transformations/mlir/op/matmul.cpp b/src/common/transformations/src/transformations/mlir/op/matmul.cpp index b447a8a40e70d5..b8c50db1ed4f31 100644 --- a/src/common/transformations/src/transformations/mlir/op/matmul.cpp +++ b/src/common/transformations/src/transformations/mlir/op/matmul.cpp @@ -30,14 +30,15 @@ struct ConvertMatMul { auto zero = getConstant(builder, ov_output_element_type, 0); auto fill = builder.create(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty}); - mlir::ValueRange ins{inputs[0], inputs[1]}; - mlir::ValueRange outs{fill.getResult(0)}; + mlir::SmallVector ins{inputs[0], inputs[1]}; + mlir::SmallVector outs{fill.getResult(0)}; auto matmul_node = std::dynamic_pointer_cast(node); assert(matmul_node); bool isTransposedA = matmul_node->get_transpose_a(); bool isTransposedB = matmul_node->get_transpose_b(); assert(!(isTransposedA && isTransposedB)); + Operation* matmul; if (isTransposedA) { matmul = builder.create(loc, ins, outs);