From 2694dd3025bcb8159dc1ca45f7b69a081928b724 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Tue, 29 Aug 2023 17:43:32 -0700 Subject: [PATCH] [Linalg] fix intermediate results cleaning in fuse-ext (#43) --- .../Dialect/Linalg/Transforms/Transforms.cpp | 41 ++++++++++++++----- .../test/Dialect/Linalg/fuse-attention.mlir | 10 +++-- .../Dialect/Linalg/transform-op-fuse.mlir | 35 +++++++++++++++- 3 files changed, 71 insertions(+), 15 deletions(-) diff --git a/compiler/lib/Dialect/Linalg/Transforms/Transforms.cpp b/compiler/lib/Dialect/Linalg/Transforms/Transforms.cpp index 92bdc2499..c22ee72fd 100644 --- a/compiler/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/compiler/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -354,15 +354,15 @@ LogicalResult mlir::scf::isValidTiling(Operation *tiled) { //===----------------------------------------------------------------------===// bool mlir::scf::isResultLoopInvariant(Operation *op, int64_t resultNumber, - bool hasOneOrZeroUse, bool allParallel) { + bool passUsesCheck, bool allParallel) { if (op == nullptr) return false; if (auto linalgExtOp = dyn_cast(op)) { - return linalgExtOp.isResultLoopInvariant(resultNumber, hasOneOrZeroUse, + return linalgExtOp.isResultLoopInvariant(resultNumber, passUsesCheck, allParallel); } else if (isa(op)) { - return hasOneOrZeroUse && allParallel; + return passUsesCheck && allParallel; } return false; } @@ -1281,6 +1281,13 @@ mlir::scf::tileConsumerAndFuseProducerUsingSCFForOpExt( // check getLoopIteratorTypes for each fusedOp // if parallel, corresponding getRegionIterArgs will be simplified unsigned resultOffset = 0; + + llvm::DenseSet unfusedOpsSet; + for (auto &p : fusedOps) { + Operation *unfusedOp = p.first; + unfusedOpsSet.insert(unfusedOp); + } + for (const auto &p : fusedOps) { auto unfusedOp = p.first; auto fusedOp = p.second; @@ -1300,18 +1307,19 @@ mlir::scf::tileConsumerAndFuseProducerUsingSCFForOpExt( auto result = unfusedOp->getResult(i); auto effectiveUseCnt = - llvm::count_if(result.getUses(), [](OpOperand &opOperand) { + llvm::count_if(result.getUses(), [&](OpOperand &opOperand) { + if (unfusedOpsSet.contains(opOperand.getOwner())) + return false; + if (auto dstOp = dyn_cast( opOperand.getOwner())) { return !dstOp.isDpsInit(&opOperand); } + return !isa(opOperand.getOwner()); }); - bool hasOneOrZeroUseGeneral = - unfusedOp == consumer ? effectiveUseCnt < 1 : effectiveUseCnt <= 1; - - bool hasOneOrZeroUseForExtract = effectiveUseCnt <= 1; + bool hasZeroOutsideUse = effectiveUseCnt == 0; auto confirmAllParallel = [&](size_t loopCnt) { bool allParallel = true; @@ -1343,7 +1351,7 @@ mlir::scf::tileConsumerAndFuseProducerUsingSCFForOpExt( auto iterArg = forOp.getRegionIterArg(resultOffset + i); auto iterOperand = forOp.getIterOperands()[resultOffset + i]; - if (isResultLoopInvariant(unfusedOp, i, hasOneOrZeroUseGeneral, + if (isResultLoopInvariant(unfusedOp, i, hasZeroOutsideUse, confirmedAllParallel)) { iterArg.replaceUsesWithIf(iterOperand, [&](OpOperand &use) { return (opCollection.contains(use.getOwner()) || @@ -1351,9 +1359,20 @@ mlir::scf::tileConsumerAndFuseProducerUsingSCFForOpExt( }); } + // The following replace is used to optimize the following IR: + // + // %0 = tensor.empty + // scf.for ... (%arg0 = %0, ...) + // %1 = tensor.extract %arg0 + // "use"(%1)... + // + // to + // + // scf.for ... + // %0 = tensor.empty + // "use"(%0) if (simplifyLoopIter && - isResultLoopInvariant(unfusedOp, i, hasOneOrZeroUseForExtract, - confirmedAllParallel)) { + isResultLoopInvariant(unfusedOp, i, true, confirmedAllParallel)) { iterArg.replaceUsesWithIf(iterOperand, [&](OpOperand &use) { return isa(use.getOwner()); }); diff --git a/compiler/test/Dialect/Linalg/fuse-attention.mlir b/compiler/test/Dialect/Linalg/fuse-attention.mlir index f4868b89f..6595bb838 100644 --- a/compiler/test/Dialect/Linalg/fuse-attention.mlir +++ b/compiler/test/Dialect/Linalg/fuse-attention.mlir @@ -9,12 +9,15 @@ func.func @fuse_dot_attention(%arg0: tensor<1024x32xf32>, %arg1: tensor<32x512xf // CHECK: scf.for // CHECK: linalg.fill // CHECK: linalg.matmul -// CHECK: linalg_ext.softmax +// CHECK: %[[V0:.*]]:4 = linalg_ext.softmax // CHECK: linalg_ext.diag // CHECK: linalg.fill // CHECK: linalg.matmul -// CHECK: linalg.matmul -// CHECK: scf.yield +// CHECK: %[[V1:.*]] = linalg.matmul +// CHECK: %[[INS0:.*]] = tensor.insert_slice %[[V1]] +// CHECK: %[[INS1:.*]] = tensor.insert_slice %[[V0]]#1 +// CHECK: %[[INS2:.*]] = tensor.insert_slice %[[V0]]#2 +// CHECK: scf.yield %[[INS0]], %[[INS1]], %[[INS2]] // CHECK: } // CHECK: scf.yield // CHECK: } @@ -46,6 +49,7 @@ transform.sequence failures(propagate) { %0 = transform.structured.match attributes{"__root__"} in %arg1 : (!pdl.operation) -> !pdl.operation %1, %loops:2 = transform.structured.fuse_ext %0 {tile_sizes = [4, 0, 8], tile_interchange = [2, 1, 0]} transform.structured.tile_loop_hint %1 : !pdl.operation + cleanup } // ----- diff --git a/compiler/test/Dialect/Linalg/transform-op-fuse.mlir b/compiler/test/Dialect/Linalg/transform-op-fuse.mlir index 47ad5e4ae..6147425ae 100644 --- a/compiler/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/compiler/test/Dialect/Linalg/transform-op-fuse.mlir @@ -1136,7 +1136,6 @@ func.func @elew_pad_elew(%arg0: tensor<12x12xf32>) -> tensor<14x14xf32> { %3 = linalg.elemwise_unary {__root__} ins(%padded : tensor<14x14xf32>) outs(%2 : tensor<14x14xf32>) -> tensor<14x14xf32> return %3 : tensor<14x14xf32> } - // CHECK-LABEL: func.func @elew_pad_elew // CHECK: scf.for // CHECK: scf.for @@ -1150,3 +1149,37 @@ func.func @elew_pad_elew(%arg0: tensor<12x12xf32>) -> tensor<14x14xf32> { // CHECK: scf.yield // CHECK: scf.yield +// ----- + +#map = affine_map<(d0) -> (d0)> + +func.func @multi_results_one_in_tile_fuse_path_one_in_terminator(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, %arg2: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { + // CHECK-LABEL: func.func @multi_results_one_in_tile_fuse_path_one_in_terminator + // CHECK: %[[E0:.*]] = tensor.empty() : tensor<1024xf32> + // CHECK: scf.for {{.*}} iter_args(%[[ARG0:.*]] = %[[E0]], %[[ARG1:.*]] = %[[E0]]) + // CHECK: %[[V0:.*]]:2 = linalg.generic{{.*}}__g0__ + // CHECK: %[[V1:.*]] = linalg.generic{{.*}}__g1__ + // CHECK-DAG: %[[INS0:.*]] = tensor.insert_slice %[[V1]] into %[[ARG0]] + // CHECK-DAG: %[[INS1:.*]] = tensor.insert_slice %[[V0]]#1 into %[[ARG1]] + // CHECK: scf.yield %[[INS0]], %[[INS1]] + %0 = tensor.empty() : tensor<1024xf32> + %1:2 = linalg.generic {__g0__, indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1024xf32>, tensor<1024xf32>) outs(%0, %0 : tensor<1024xf32>, tensor<1024xf32>) { + ^bb0(%in_0: f32, %in_1: f32, %out_0: f32, %out_1: f32): + %2 = arith.addf %in_0, %in_1 : f32 + %3 = arith.subf %in_0, %in_1 : f32 + linalg.yield %2, %3 : f32, f32 + } -> (tensor<1024xf32>, tensor<1024xf32>) + %2 = linalg.generic {__root__, __g1__, indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg2, %1#0 : tensor<1024xf32>, tensor<1024xf32>) outs(%0 : tensor<1024xf32>) { + ^bb0(%in_0: f32, %in_1: f32, %out : f32): + %2 = arith.addf %in_0, %in_1 : f32 + linalg.yield %2 : f32 + } -> tensor<1024xf32> + return %1#1, %2 : tensor<1024xf32>, tensor<1024xf32> +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + %0 = transform.structured.match attributes {__root__} in %arg0 : (!pdl.operation) -> !pdl.operation + %transformed, %loops = transform.structured.fuse_ext %0 {tile_interchange = [], tile_sizes = [1]} + cleanup +}