[Linalg] fix intermediate results cleaning in fuse-ext (#43)
yaochengji authored Aug 30, 2023
1 parent da97be4 commit 2694dd3
Showing 3 changed files with 71 additions and 15 deletions.
compiler/lib/Dialect/Linalg/Transforms/Transforms.cpp (30 additions, 11 deletions)
@@ -354,15 +354,15 @@ LogicalResult mlir::scf::isValidTiling(Operation *tiled) {
 //===----------------------------------------------------------------------===//
 
 bool mlir::scf::isResultLoopInvariant(Operation *op, int64_t resultNumber,
-                                      bool hasOneOrZeroUse, bool allParallel) {
+                                      bool passUsesCheck, bool allParallel) {
   if (op == nullptr)
     return false;
 
   if (auto linalgExtOp = dyn_cast<linalg_ext::LinalgExtOp>(op)) {
-    return linalgExtOp.isResultLoopInvariant(resultNumber, hasOneOrZeroUse,
+    return linalgExtOp.isResultLoopInvariant(resultNumber, passUsesCheck,
                                              allParallel);
   } else if (isa<linalg::LinalgOp>(op)) {
-    return hasOneOrZeroUse && allParallel;
+    return passUsesCheck && allParallel;
   }
   return false;
 }
@@ -1281,6 +1281,13 @@ mlir::scf::tileConsumerAndFuseProducerUsingSCFForOpExt(
   // check getLoopIteratorTypes for each fusedOp
   // if parallel, corresponding getRegionIterArgs will be simplified
   unsigned resultOffset = 0;
+
+  llvm::DenseSet<Operation *> unfusedOpsSet;
+  for (auto &p : fusedOps) {
+    Operation *unfusedOp = p.first;
+    unfusedOpsSet.insert(unfusedOp);
+  }
+
   for (const auto &p : fusedOps) {
     auto unfusedOp = p.first;
     auto fusedOp = p.second;
@@ -1300,18 +1307,19 @@ mlir::scf::tileConsumerAndFuseProducerUsingSCFForOpExt(
       auto result = unfusedOp->getResult(i);
 
       auto effectiveUseCnt =
-          llvm::count_if(result.getUses(), [](OpOperand &opOperand) {
+          llvm::count_if(result.getUses(), [&](OpOperand &opOperand) {
+            if (unfusedOpsSet.contains(opOperand.getOwner()))
+              return false;
+
             if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
                     opOperand.getOwner())) {
               return !dstOp.isDpsInit(&opOperand);
             }
 
             return !isa<tensor::DimOp>(opOperand.getOwner());
           });
 
-      bool hasOneOrZeroUseGeneral =
-          unfusedOp == consumer ? effectiveUseCnt < 1 : effectiveUseCnt <= 1;
-
-      bool hasOneOrZeroUseForExtract = effectiveUseCnt <= 1;
+      bool hasZeroOutsideUse = effectiveUseCnt == 0;
 
       auto confirmAllParallel = [&](size_t loopCnt) {
         bool allParallel = true;
@@ -1343,17 +1351,28 @@ mlir::scf::tileConsumerAndFuseProducerUsingSCFForOpExt(
         auto iterArg = forOp.getRegionIterArg(resultOffset + i);
         auto iterOperand = forOp.getIterOperands()[resultOffset + i];
 
-        if (isResultLoopInvariant(unfusedOp, i, hasOneOrZeroUseGeneral,
+        if (isResultLoopInvariant(unfusedOp, i, hasZeroOutsideUse,
                                   confirmedAllParallel)) {
           iterArg.replaceUsesWithIf(iterOperand, [&](OpOperand &use) {
             return (opCollection.contains(use.getOwner()) ||
                     valCollection.contains(use.get()));
           });
         }
 
+        // The following replace is used to optimize the following IR:
+        //
+        // %0 = tensor.empty
+        // scf.for ... (%arg0 = %0, ...)
+        //   %1 = tensor.extract %arg0
+        //   "use"(%1)...
+        //
+        // to
+        //
+        // scf.for ...
+        //   %0 = tensor.empty
+        //   "use"(%0)
         if (simplifyLoopIter &&
-            isResultLoopInvariant(unfusedOp, i, hasOneOrZeroUseForExtract,
-                                  confirmedAllParallel)) {
+            isResultLoopInvariant(unfusedOp, i, true, confirmedAllParallel)) {
           iterArg.replaceUsesWithIf(iterOperand, [&](OpOperand &use) {
             return isa<tensor::ExtractSliceOp>(use.getOwner());
           });
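
For context: with this change, uses whose owner is inside the fused-op set no longer count toward effectiveUseCnt, the general iter-arg cleanup fires only when a result has zero remaining outside uses (hasZeroOutsideUse) rather than at most one, and the tensor.extract_slice path passes the uses check unconditionally. A minimal hand-written MLIR example (illustrative only, not from this commit) of why a result with an outside use must keep its iter_arg:

// %res escapes the loop through the terminator, so iter_arg %acc is not
// loop-invariant: reads of %acc may not be redirected to %empty, or the
// accumulation chain of insert_slice results is broken.
func.func @escaping_result(%arg0: tensor<1024xf32>) -> tensor<1024xf32> {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c1024 = arith.constant 1024 : index
  %empty = tensor.empty() : tensor<1024xf32>
  %res = scf.for %iv = %c0 to %c1024 step %c4
      iter_args(%acc = %empty) -> (tensor<1024xf32>) {
    %tile = tensor.extract_slice %arg0[%iv] [4] [1]
        : tensor<1024xf32> to tensor<4xf32>
    %ins = tensor.insert_slice %tile into %acc[%iv] [4] [1]
        : tensor<4xf32> into tensor<1024xf32>
    scf.yield %ins : tensor<1024xf32>
  }
  return %res : tensor<1024xf32>
}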
compiler/test/Dialect/Linalg/fuse-attention.mlir (7 additions, 3 deletions)
@@ -9,12 +9,15 @@ func.func @fuse_dot_attention(%arg0: tensor<1024x32xf32>, %arg1: tensor<32x512xf
 // CHECK: scf.for
 // CHECK: linalg.fill
 // CHECK: linalg.matmul
-// CHECK: linalg_ext.softmax
+// CHECK: %[[V0:.*]]:4 = linalg_ext.softmax
 // CHECK: linalg_ext.diag
 // CHECK: linalg.fill
 // CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: scf.yield
+// CHECK: %[[V1:.*]] = linalg.matmul
+// CHECK: %[[INS0:.*]] = tensor.insert_slice %[[V1]]
+// CHECK: %[[INS1:.*]] = tensor.insert_slice %[[V0]]#1
+// CHECK: %[[INS2:.*]] = tensor.insert_slice %[[V0]]#2
+// CHECK: scf.yield %[[INS0]], %[[INS1]], %[[INS2]]
 // CHECK: }
 // CHECK: scf.yield
 // CHECK: }
@@ -46,6 +49,7 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match attributes{"__root__"} in %arg1 : (!pdl.operation) -> !pdl.operation
   %1, %loops:2 = transform.structured.fuse_ext %0 {tile_sizes = [4, 0, 8], tile_interchange = [2, 1, 0]}
   transform.structured.tile_loop_hint %1 : !pdl.operation
+  cleanup
 }
 
 // -----
compiler/test/Dialect/Linalg/transform-op-fuse.mlir (34 additions, 1 deletion)
@@ -1136,7 +1136,6 @@ func.func @elew_pad_elew(%arg0: tensor<12x12xf32>) -> tensor<14x14xf32> {
   %3 = linalg.elemwise_unary {__root__} ins(%padded : tensor<14x14xf32>) outs(%2 : tensor<14x14xf32>) -> tensor<14x14xf32>
   return %3 : tensor<14x14xf32>
 }
-
 // CHECK-LABEL: func.func @elew_pad_elew
 // CHECK: scf.for
 // CHECK: scf.for
@@ -1150,3 +1149,37 @@ func.func @elew_pad_elew(%arg0: tensor<12x12xf32>) -> tensor<14x14xf32> {
 // CHECK: scf.yield
 // CHECK: scf.yield
 
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+func.func @multi_results_one_in_tile_fuse_path_one_in_terminator(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, %arg2: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
+  // CHECK-LABEL: func.func @multi_results_one_in_tile_fuse_path_one_in_terminator
+  // CHECK: %[[E0:.*]] = tensor.empty() : tensor<1024xf32>
+  // CHECK: scf.for {{.*}} iter_args(%[[ARG0:.*]] = %[[E0]], %[[ARG1:.*]] = %[[E0]])
+  // CHECK: %[[V0:.*]]:2 = linalg.generic{{.*}}__g0__
+  // CHECK: %[[V1:.*]] = linalg.generic{{.*}}__g1__
+  // CHECK-DAG: %[[INS0:.*]] = tensor.insert_slice %[[V1]] into %[[ARG0]]
+  // CHECK-DAG: %[[INS1:.*]] = tensor.insert_slice %[[V0]]#1 into %[[ARG1]]
+  // CHECK: scf.yield %[[INS0]], %[[INS1]]
+  %0 = tensor.empty() : tensor<1024xf32>
+  %1:2 = linalg.generic {__g0__, indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1024xf32>, tensor<1024xf32>) outs(%0, %0 : tensor<1024xf32>, tensor<1024xf32>) {
+  ^bb0(%in_0: f32, %in_1: f32, %out_0: f32, %out_1: f32):
+    %2 = arith.addf %in_0, %in_1 : f32
+    %3 = arith.subf %in_0, %in_1 : f32
+    linalg.yield %2, %3 : f32, f32
+  } -> (tensor<1024xf32>, tensor<1024xf32>)
+  %2 = linalg.generic {__root__, __g1__, indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg2, %1#0 : tensor<1024xf32>, tensor<1024xf32>) outs(%0 : tensor<1024xf32>) {
+  ^bb0(%in_0: f32, %in_1: f32, %out : f32):
+    %2 = arith.addf %in_0, %in_1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1024xf32>
+  return %1#1, %2 : tensor<1024xf32>, tensor<1024xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  %0 = transform.structured.match attributes {__root__} in %arg0 : (!pdl.operation) -> !pdl.operation
+  %transformed, %loops = transform.structured.fuse_ext %0 {tile_interchange = [], tile_sizes = [1]}
+  cleanup
+}
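
This test pins down the fix: %1#1 is produced along the tile-and-fuse path but consumed only by the terminator, so after fusion it must keep accumulating through its own iter_args slot and be yielded via tensor.insert_slice, as the CHECK-DAG lines require. Under the old hasOneOrZeroUse heuristic, that single outside use made the result look removable. A hand-written sketch (illustrative only, not fuse_ext output) of the broken loop shape that cleanup could then produce:

// WRONG shape: the second accumulator has been short-circuited to the
// loop-invariant %empty, so %r#1 retains only the final iteration's tile
// instead of all 1024 elements.
func.func @pre_fix_shape(%arg0: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1024 = arith.constant 1024 : index
  %empty = tensor.empty() : tensor<1024xf32>
  %r:2 = scf.for %iv = %c0 to %c1024 step %c1
      iter_args(%acc0 = %empty, %acc1 = %empty)
      -> (tensor<1024xf32>, tensor<1024xf32>) {
    %t = tensor.extract_slice %arg0[%iv] [1] [1]
        : tensor<1024xf32> to tensor<1xf32>
    %ins0 = tensor.insert_slice %t into %acc0[%iv] [1] [1]
        : tensor<1xf32> into tensor<1024xf32>
    // Bug: inserts into %empty instead of the iter_arg %acc1.
    %ins1 = tensor.insert_slice %t into %empty[%iv] [1] [1]
        : tensor<1xf32> into tensor<1024xf32>
    scf.yield %ins0, %ins1 : tensor<1024xf32>, tensor<1024xf32>
  }
  return %r#0, %r#1 : tensor<1024xf32>, tensor<1024xf32>
}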
