diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 4001ba3fc84c9d..02ffa0da7a8b86 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -111,11 +111,18 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op); void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef> combinedDimensions); -/// Unrolls this for operation by the specified unroll factor. Returns failure -/// if the loop cannot be unrolled either due to restrictions or due to invalid -/// unroll factors. Requires positive loop bounds and step. If specified, -/// annotates the Ops in each unrolled iteration by applying `annotateFn`. -LogicalResult loopUnrollByFactor( +struct UnrolledLoopInfo { + std::optional mainLoopOp = std::nullopt; + std::optional epilogueLoopOp = std::nullopt; +}; + +/// Unrolls this for operation by the specified unroll factor. Returns the +/// unrolled main loop and the epilogue loop, if the loop is unrolled. Otherwise +/// returns failure if the loop cannot be unrolled either due to restrictions or +/// due to invalid unroll factors. Requires positive loop bounds and step. If +/// specified, annotates the Ops in each unrolled iteration by applying +/// `annotateFn`. +FailureOr loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn = nullptr); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 43fcc595af0f7e..247311d66ff949 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -372,15 +372,16 @@ static void generateUnrolledLoop( loopBodyBlock->getTerminator()->setOperands(lastYielded); } -/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled. 
-LogicalResult mlir::loopUnrollByFactor( +/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the +/// epilogue loop, if the loop is unrolled. +FailureOr mlir::loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn) { assert(unrollFactor > 0 && "expected positive unroll factor"); // Return if the loop body is empty. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) - return success(); + return UnrolledLoopInfo{forOp, std::nullopt}; // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases. @@ -402,7 +403,7 @@ LogicalResult mlir::loopUnrollByFactor( if (*constTripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter))) return failure(); - return success(); + return UnrolledLoopInfo{forOp, std::nullopt}; } int64_t tripCountEvenMultiple = @@ -450,6 +451,8 @@ LogicalResult mlir::loopUnrollByFactor( boundsBuilder.create(loc, step, unrollFactorCst); } + UnrolledLoopInfo resultLoops; + // Create epilogue clean up loop starting at 'upperBoundUnrolled'. if (generateEpilogueLoop) { OpBuilder epilogueBuilder(forOp->getContext()); @@ -467,7 +470,8 @@ LogicalResult mlir::loopUnrollByFactor( } epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(), epilogueForOp.getInitArgs().size(), results); - (void)epilogueForOp.promoteIfSingleIteration(rewriter); + if (epilogueForOp.promoteIfSingleIteration(rewriter).failed()) + resultLoops.epilogueLoopOp = epilogueForOp; } // Create unrolled loop. @@ -489,8 +493,9 @@ LogicalResult mlir::loopUnrollByFactor( }, annotateFn, iterArgs, yieldedValues); // Promote the loop body up if this has turned into a single iteration loop. - (void)forOp.promoteIfSingleIteration(rewriter); - return success(); + if (forOp.promoteIfSingleIteration(rewriter).failed()) + resultLoops.mainLoopOp = forOp; + return resultLoops; } /// Check if bounds of all inner loops are defined outside of `forOp`