diff --git a/compiler/include/byteir/Dialect/MemRef/Transforms/MultiBufferExt.h b/compiler/include/byteir/Dialect/MemRef/Transforms/MultiBufferExt.h deleted file mode 100644 index 532dcee5b..000000000 --- a/compiler/include/byteir/Dialect/MemRef/Transforms/MultiBufferExt.h +++ /dev/null @@ -1,81 +0,0 @@ -//===- RemoveCopy.h -------------------------------------------*--- C++ -*-===// -// -// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef BYTEIR_DIALECT_MEMREF_TRANSFORMS_MULTIBUFFEREXT_H -#define BYTEIR_DIALECT_MEMREF_TRANSFORMS_MULTIBUFFEREXT_H - -#include "mlir/Support/LogicalResult.h" -#include "llvm/ADT/STLFunctionalExtras.h" - -namespace mlir { -class OpBuilder; -class RewritePatternSet; -class RewriterBase; -class Value; -class ValueRange; - -namespace arith { -class WideIntEmulationConverter; -class NarrowTypeEmulationConverter; -} // namespace arith - -namespace memref { -class AllocOp; -class AllocaOp; -class DeallocOp; - -/// Transformation to do multi-buffering/array expansion to remove dependencies -/// on the temporary allocation between consecutive loop iterations. -/// It returns the new allocation if the original allocation was multi-buffered -/// and returns failure() otherwise. -/// When `skipOverrideAnalysis`, the pass will apply the transformation -/// without checking thwt the buffer is overrided at the beginning of each -/// iteration. This implies that user knows that there is no data carried across -/// loop iterations. Example: -/// ``` -/// %0 = memref.alloc() : memref<4x128xf32> -/// scf.for %iv = %c1 to %c1024 step %c3 { -/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32> -/// "some_use"(%0) : (memref<4x128xf32>) -> () -/// } -/// ``` -/// into: -/// ``` -/// %0 = memref.alloc() : memref<5x4x128xf32> -/// scf.for %iv = %c1 to %c1024 step %c3 { -/// %s = arith.subi %iv, %c1 : index -/// %d = arith.divsi %s, %c3 : index -/// %i = arith.remsi %d, %c5 : index -/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] : -/// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>> -/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>> -/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> () -/// } -/// ``` -template -FailureOr multiBufferExt(RewriterBase &rewriter, - AllocOpType allocOp, unsigned multiplier, - bool skipOverrideAnalysis = false); -/// Call into `multiBuffer` with locally constructed IRRewriter. -template -FailureOr multiBufferExt(AllocOpType allocOp, unsigned multiplier, - bool skipOverrideAnalysis = false); - -} // namespace memref -} // namespace mlir - -#endif // BYTEIR_DIALECT_MEMREF_TRANSFORMS_MULTIBUFFEREXT_H \ No newline at end of file diff --git a/compiler/lib/Dialect/MemRef/CMakeLists.txt b/compiler/lib/Dialect/MemRef/CMakeLists.txt index c76cf1281..9304445ce 100644 --- a/compiler/lib/Dialect/MemRef/CMakeLists.txt +++ b/compiler/lib/Dialect/MemRef/CMakeLists.txt @@ -1,7 +1,6 @@ add_mlir_dialect_library(ByteIRMemRefPasses Transforms/ApplyMemRefAffineLayout.cpp Transforms/ExtractAddressComputation.cpp - Transforms/MultiBufferExt.cpp Transforms/RemoveCopy.cpp Transforms/SimplifyLinearizedIndex.cpp Transforms/SimplifyView.cpp diff --git a/compiler/lib/Dialect/MemRef/Transforms/MultiBufferExt.cpp b/compiler/lib/Dialect/MemRef/Transforms/MultiBufferExt.cpp deleted file mode 100644 index 639170d4d..000000000 --- a/compiler/lib/Dialect/MemRef/Transforms/MultiBufferExt.cpp +++ /dev/null @@ -1,283 +0,0 @@ -//===- MultiBufferExt.cpp -----------------------------------------*--- C++ -//-*-===// -// -// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -// Some code comes from mlir/lib/Dialect/Memref/Transforms/MultiBuffer.cpp of -// LLVM Project. -// Original license: -//===----------- MultiBuffering.cpp ---------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/MemRef/Transforms/Transforms.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" - -using namespace mlir; - -#define DEBUG_TYPE "memref-multi-buffer-ext" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") - -/// Return true if the op fully overwrite the given `buffer` value. -static bool overrideBuffer(Operation *op, Value buffer) { - auto memrefCopyOp = dyn_cast(op); - auto linalgCopyOp = dyn_cast(op); - if (memrefCopyOp) - return memrefCopyOp.getTarget() == buffer; - if (linalgCopyOp) - return linalgCopyOp.getDpsInitOperand(0)->get() == buffer; - return false; -} - -/// Replace the uses of `oldOp` with the given `val` and for subview uses -/// propagate the type change. Changing the memref type may require propagating -/// it through subview ops so we cannot just do a replaceAllUse but need to -/// propagate the type change and erase old subview ops. -static void replaceUsesAndPropagateType(RewriterBase &rewriter, - Operation *oldOp, Value val) { - SmallVector opsToDelete; - SmallVector operandsToReplace; - - // Save the operand to replace / delete later (avoid iterator invalidation). - // TODO: can we use an early_inc iterator? - for (OpOperand &use : oldOp->getUses()) { - // Non-subview ops will be replaced by `val`. - auto subviewUse = dyn_cast(use.getOwner()); - if (!subviewUse) { - operandsToReplace.push_back(&use); - continue; - } - - // `subview(old_op)` is replaced by a new `subview(val)`. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(subviewUse); - Type newType = memref::SubViewOp::inferRankReducedResultType( - subviewUse.getType().getShape(), cast(val.getType()), - subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), - subviewUse.getStaticStrides()); - Value newSubview = rewriter.create( - subviewUse->getLoc(), cast(newType), val, - subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), - subviewUse.getMixedStrides()); - - // Ouch recursion ... is this really necessary? - replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); - - opsToDelete.push_back(use.getOwner()); - } - - // Perform late replacement. - // TODO: can we use an early_inc iterator? - for (OpOperand *operand : operandsToReplace) { - Operation *op = operand->getOwner(); - rewriter.startRootUpdate(op); - operand->set(val); - rewriter.finalizeRootUpdate(op); - } - - // Perform late op erasure. - // TODO: can we use an early_inc iterator? - for (Operation *op : opsToDelete) - rewriter.eraseOp(op); -} - -namespace mlir { -namespace memref { - -// Transformation to do multi-buffering/array expansion to remove dependencies -// on the temporary allocation between consecutive loop iterations. -// Returns success if the transformation happened and failure otherwise. -// This is not a pattern as it requires propagating the new memref type to its -// uses and requires updating subview ops. -template -FailureOr -multiBufferExt(RewriterBase &rewriter, AllocOpType allocOp, - unsigned multiBufferingFactor, bool skipOverrideAnalysis) { - LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n"); - DominanceInfo dom(allocOp->getParentOp()); - LoopLikeOpInterface candidateLoop; - for (Operation *user : allocOp->getUsers()) { - auto parentLoop = user->getParentOfType(); - if (!parentLoop) { - if (isa(user)) { - // Allow dealloc outside of any loop. - // TODO: The whole precondition function here is very brittle and will - // need to rethought an isolated into a cleaner analysis. - continue; - } - LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n"); - LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n"); - return failure(); - } - if (!skipOverrideAnalysis) { - /// Make sure there is no loop-carried dependency on the allocation. - if (!overrideBuffer(user, allocOp.getResult())) { - LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n"); - continue; - } - // If this user doesn't dominate all the other users keep looking. - if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { - return !dom.dominates(user, otherUser); - })) { - LLVM_DEBUG( - DBGS() << "--Skip user: does not dominate all other users\n"); - continue; - } - } else { - if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { - return !isa(otherUser) && - !parentLoop->isProperAncestor(otherUser); - })) { - LLVM_DEBUG( - DBGS() - << "--Skip user: not all other users are in the parent loop\n"); - continue; - } - } - candidateLoop = parentLoop; - break; - } - - if (!candidateLoop) { - LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n"); - return failure(); - } - - std::optional inductionVar = candidateLoop.getSingleInductionVar(); - std::optional lowerBound = candidateLoop.getSingleLowerBound(); - std::optional singleStep = candidateLoop.getSingleStep(); - if (!inductionVar || !lowerBound || !singleStep || - !llvm::hasSingleElement(candidateLoop.getLoopRegions())) { - LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n"); - return failure(); - } - - if (!dom.dominates(allocOp.getOperation(), candidateLoop)) { - LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n"); - return failure(); - } - - LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n"); - - // 1. Construct the multi-buffered memref type. - ArrayRef originalShape = allocOp.getType().getShape(); - SmallVector multiBufferedShape{multiBufferingFactor}; - llvm::append_range(multiBufferedShape, originalShape); - LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n"); - MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType()) - .setShape(multiBufferedShape) - .setLayout(MemRefLayoutAttrInterface()); - LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n"); - - // 2. Create the multi-buffered alloc. - Location loc = allocOp->getLoc(); - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(allocOp); - auto mbAlloc = rewriter.create(loc, mbMemRefType, ValueRange{}, - allocOp->getAttrs()); - LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n"); - - // 3. Within the loop, build the modular leading index (i.e. each loop - // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor). - rewriter.setInsertionPointToStart( - &candidateLoop.getLoopRegions().front()->front()); - Value ivVal = *inductionVar; - Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound); - Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep); - AffineExpr iv, lb, step; - bindDims(rewriter.getContext(), iv, lb, step); - Value bufferIndex = affine::makeComposedAffineApply( - rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor, - {ivVal, lbVal, stepVal}); - LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n"); - - // 4. Build the subview accessing the particular slice, taking modular - // rotation into account. - int64_t mbMemRefTypeRank = mbMemRefType.getRank(); - IntegerAttr zero = rewriter.getIndexAttr(0); - IntegerAttr one = rewriter.getIndexAttr(1); - SmallVector offsets(mbMemRefTypeRank, zero); - SmallVector sizes(mbMemRefTypeRank, one); - SmallVector strides(mbMemRefTypeRank, one); - // Offset is [bufferIndex, 0 ... 0 ]. - offsets.front() = bufferIndex; - // Sizes is [1, original_size_0 ... original_size_n ]. - for (int64_t i = 0, e = originalShape.size(); i != e; ++i) - sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]); - // Strides is [1, 1 ... 1 ]. - auto dstMemref = - cast(memref::SubViewOp::inferRankReducedResultType( - originalShape, mbMemRefType, offsets, sizes, strides)); - Value subview = rewriter.create(loc, dstMemref, mbAlloc, - offsets, sizes, strides); - LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); - - // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to - // handle dealloc uses separately.. - for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) { - auto deallocOp = dyn_cast(use.getOwner()); - if (!deallocOp) - continue; - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(deallocOp); - auto newDeallocOp = - rewriter.create(deallocOp->getLoc(), mbAlloc); - (void)newDeallocOp; - LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n"); - rewriter.eraseOp(deallocOp); - } - - // 6. RAUW with the particular slice, taking modular rotation into account. - replaceUsesAndPropagateType(rewriter, allocOp, subview); - - // 7. Finally, erase the old allocOp. - rewriter.eraseOp(allocOp); - - return mbAlloc; -} - -template -FailureOr multiBufferExt(AllocOpType allocOp, - unsigned multiBufferingFactor, - bool skipOverrideAnalysis) { - IRRewriter rewriter(allocOp->getContext()); - return multiBufferExt(rewriter, allocOp, multiBufferingFactor, - skipOverrideAnalysis); -} - -template FailureOr multiBufferExt(memref::AllocOp, unsigned, - bool); -template FailureOr multiBufferExt(memref::AllocaOp, unsigned, - bool); -} // namespace memref -} // namespace mlir \ No newline at end of file