Skip to content

Commit

Permalink
[mlir] Extend SCF loopUnrollByFactor to return the result loops (#114573
Browse files Browse the repository at this point in the history
)

There is a need to access the resulting epilogue loop from the SCF loop
unroller. It'd be clean and convenient to get that directly from the loop
unroller instead of rescanning the whole function, as discussed in
triton-lang/triton#5027 . I'm changing the
result type of `loopUnrollByFactor` for that.
  • Loading branch information
htyu authored Nov 4, 2024
1 parent 6127724 commit fa57c7a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
17 changes: 12 additions & 5 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,18 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);

/// Unrolls this for operation by the specified unroll factor. Returns failure
/// if the loop cannot be unrolled either due to restrictions or due to invalid
/// unroll factors. Requires positive loop bounds and step. If specified,
/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
LogicalResult loopUnrollByFactor(
/// Loops produced by `loopUnrollByFactor`. A field is left empty
/// (`std::nullopt`) when the corresponding loop is not available to the caller,
/// e.g. because no epilogue loop was generated or because the loop was promoted
/// as a single-iteration loop.
struct UnrolledLoopInfo {
/// The main loop after unrolling. Also set to the original loop when
/// `loopUnrollByFactor` returns early because the loop body is empty.
std::optional<scf::ForOp> mainLoopOp = std::nullopt;
/// The epilogue (clean-up) loop, set only when one was generated and it was
/// not promoted as a single-iteration loop.
std::optional<scf::ForOp> epilogueLoopOp = std::nullopt;
};

/// Unrolls this for operation by the specified unroll factor. Returns the
/// unrolled main loop and the epilogue loop, if the loop is unrolled. Otherwise
/// returns failure if the loop cannot be unrolled either due to restrictions or
/// due to invalid unroll factors. Requires positive loop bounds and step. If
/// specified, annotates the Ops in each unrolled iteration by applying
/// `annotateFn`.
FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);

Expand Down
19 changes: 12 additions & 7 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,16 @@ static void generateUnrolledLoop(
loopBodyBlock->getTerminator()->setOperands(lastYielded);
}

/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
LogicalResult mlir::loopUnrollByFactor(
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
/// epilogue loop, if the loop is unrolled.
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");

// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
return success();
return UnrolledLoopInfo{forOp, std::nullopt};

// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
Expand All @@ -402,7 +403,7 @@ LogicalResult mlir::loopUnrollByFactor(
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
return UnrolledLoopInfo{forOp, std::nullopt};
}

int64_t tripCountEvenMultiple =
Expand Down Expand Up @@ -450,6 +451,8 @@ LogicalResult mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}

UnrolledLoopInfo resultLoops;

// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
OpBuilder epilogueBuilder(forOp->getContext());
Expand All @@ -467,7 +470,8 @@ LogicalResult mlir::loopUnrollByFactor(
}
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
resultLoops.epilogueLoopOp = epilogueForOp;
}

// Create unrolled loop.
Expand All @@ -489,8 +493,9 @@ LogicalResult mlir::loopUnrollByFactor(
},
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
return success();
if (forOp.promoteIfSingleIteration(rewriter).failed())
resultLoops.mainLoopOp = forOp;
return resultLoops;
}

/// Check if bounds of all inner loops are defined outside of `forOp`
Expand Down

0 comments on commit fa57c7a

Please sign in to comment.