Skip to content

Commit

Permalink
AIRFuseChannels: A number of hotfixes (#683)
Browse files Browse the repository at this point in the history
* Relax unnecessary condition on the puts when fusing channels in time

* Fixup an issue where channel indices are in reverse order

* Ensure correct ordering of channel ops after channel fusion

* Correct air.channel ordering when they are generated

* Force generation of a for loop when channels are fused in NFL mode

* AIRSegmentLoopFusion pass fusing non-perfect for loop nest generated from AIRFuseChannels

* Switch to using the other hasNElements impl.

* Fixup error where air-loop-fusion was fusing for loops using alloc's ordering

* Avoid fusing NFL into for loops if NFL puts/gets have differing data access pattern

* Ensure shim dma loop splitting respects data dependency per shim dma channel

* Update (correct) the ordering of channel puts and gets in test

* Update the test to check for NFL mode generating new for loop

* Change loop nest comparison ordering, now that the ordering of air.channels has been corrected
  • Loading branch information
erwei-xilinx authored Jul 30, 2024
1 parent d849f3a commit a913d16
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 99 deletions.
26 changes: 19 additions & 7 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,21 +466,33 @@ void identifyTargetAffineForAndOps(
// Identify the target for loops and their target child ops
int index = 0;
for (auto for_op : f.getBody().getOps<affine::AffineForOp>()) {
SmallVector<StringRef> metadataVec;
// Get for_op's immediate child op
for_op.walk([&](airrt::DmaMemcpyNdOp memcpyOp) {
// Get for_op's immediate child op
target_ops_vec.push_back(SmallVector<Operation *>{});
StringRef metadata =
memcpyOp->getAttrOfType<mlir::FlatSymbolRefAttr>("metadata")
.getValue();
// Check if any operand's defining ops needs to be hoisted together.
SmallVector<Operation *> oper_def_ops;
xilinx::air::getDefiningOpsToOperands(memcpyOp.getOperation(),
oper_def_ops);
for (auto o : oper_def_ops) {
if (o->getParentOp() == memcpyOp->getParentOp()) {
push_back_if_unique<Operation *>(target_ops_vec[index], o);
}

// Ensure memcpy ops operating on the same metadata (i.e. the same shim
// dma channel) are hoisted together, to maintain data dependency.
auto it = std::find(metadataVec.begin(), metadataVec.end(), metadata);
if (it != metadataVec.end())
index = it - metadataVec.begin();
else {
metadataVec.push_back(metadata);
target_ops_vec.push_back(SmallVector<Operation *>{});
}

for (auto o : oper_def_ops)
if (o->getParentOp() == memcpyOp->getParentOp())
push_back_if_unique<Operation *>(target_ops_vec[index], o);
push_back_if_unique<Operation *>(target_ops_vec[index],
memcpyOp.getOperation());
index++;
index = target_ops_vec.size();
});
}
}
Expand Down
182 changes: 150 additions & 32 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -3389,6 +3390,12 @@ class AIRFuseChannels
} else if (std::get<1>(checkScfForMergeableRes) == "NFL") {
// Fuse air.channels temporally, if there isn't any for loop to fuse
// into.
createDummyForOpsAroundOps<air::ChannelPutOp>(
getChannelPutsFusableByFor(chanA, chanB));
createDummyForOpsAroundOps<air::ChannelGetOp>(
getChannelGetsFusableByFor(chanA, chanB));
mergeChannelOpsTemporally(chanA, chanB, "UB");
// Fuse both channel ops into the loop
chan_merge_map[chanB] = chanA;
}
}
Expand Down Expand Up @@ -3439,6 +3446,86 @@ class AIRFuseChannels
SmallVector<unsigned> targetMemorySpaces;

private:
// Check whether every op in a_vec and b_vec accesses memory the same way as
// a_vec[0], i.e. uses the same memref, offsets, sizes and strides (as judged by
// areTheSameMemref and areTheSameSSAValueLists).
//
// Returns true only if all ops in both vectors match a_vec[0]. Returns false
// when a_vec is empty, because there is then no reference access pattern to
// compare against (and indexing a_vec[0] would be undefined behaviour).
template <typename T>
bool areConsistentMemoryAccessPattern(std::vector<T> a_vec,
                                      std::vector<T> b_vec) {
  // Without at least one op in a_vec there is nothing to use as a reference.
  if (a_vec.empty())
    return false;

  // The first op of a_vec defines the reference access pattern.
  Value memref = a_vec[0].getMemref();
  SmallVector<Value> offsets = a_vec[0].getOffsets();
  SmallVector<Value> sizes = a_vec[0].getSizes();
  SmallVector<Value> strides = a_vec[0].getStrides();

  // True iff `candidate` uses the same memref, offsets, sizes and strides as
  // the reference. Checks short-circuit in the same order as the comparisons
  // are listed.
  auto matchesReference = [&](T candidate) {
    return areTheSameMemref(memref, candidate.getMemref()) &&
           areTheSameSSAValueLists(offsets, candidate.getOffsets()) &&
           areTheSameSSAValueLists(sizes, candidate.getSizes()) &&
           areTheSameSSAValueLists(strides, candidate.getStrides());
  };

  // All remaining ops in a_vec must agree with the reference.
  for (unsigned i = 1; i < a_vec.size(); i++)
    if (!matchesReference(a_vec[i]))
      return false; // Inconsistent memory use within a_vec

  // Every op in b_vec must agree with the reference as well.
  for (unsigned i = 0; i < b_vec.size(); i++)
    if (!matchesReference(b_vec[i]))
      return false; // Inconsistent memory use between a_vec and b_vec

  return true;
}
// Collect the channel puts which can be fused using a new for loop.
// Looks up the puts of chanA and chanB through their symbols; if
// areConsistentMemoryAccessPattern accepts the two sets, the puts of chanA are
// returned, otherwise an empty vector is returned.
std::vector<air::ChannelPutOp>
getChannelPutsFusableByFor(air::ChannelOp chanA, air::ChannelOp chanB) {
  std::vector<air::ChannelPutOp> putsOfA = getChannelPutOpThroughSymbol(chanA);
  std::vector<air::ChannelPutOp> putsOfB = getChannelPutOpThroughSymbol(chanB);

  // Differing access patterns: nothing is fusable.
  if (!areConsistentMemoryAccessPattern<air::ChannelPutOp>(putsOfA, putsOfB))
    return std::vector<air::ChannelPutOp>{};

  return putsOfA;
}
// Collect the channel gets which can be fused using a new for loop.
// Looks up the gets of chanA and chanB through their symbols; if
// areConsistentMemoryAccessPattern accepts the two sets, the gets of chanA are
// returned, otherwise an empty vector is returned.
std::vector<air::ChannelGetOp>
getChannelGetsFusableByFor(air::ChannelOp chanA, air::ChannelOp chanB) {
  std::vector<air::ChannelGetOp> getsOfA = getChannelGetOpThroughSymbol(chanA);
  std::vector<air::ChannelGetOp> getsOfB = getChannelGetOpThroughSymbol(chanB);

  // Differing access patterns: nothing is fusable.
  if (!areConsistentMemoryAccessPattern<air::ChannelGetOp>(getsOfA, getsOfB))
    return std::vector<air::ChannelGetOp>{};

  return getsOfA;
}
// Create single-iteration for loops around a vector of operations of type T.
//
// For each op, an scf.for with bounds [0, 1) and step 1 is created right
// before it, the op is cloned into the loop body, and the original op is
// erased once all ops have been processed.
//
// If the op produces an async token, the loop carries a token as its single
// iter_arg: it is initialised from an air.wait_all over the op's original
// dependencies, the clone depends on the iter_arg instead, the clone's token
// is yielded, and all uses of the original token are redirected to the loop
// result.
template <typename T>
void createDummyForOpsAroundOps(std::vector<T> ops) {
  for (auto t_o : ops) {
    Operation *op = t_o.getOperation();
    OpBuilder builder(op);
    IRMapping remap;
    auto loc = op->getLoc();
    auto ctx = op->getContext();
    auto zeroIdx = builder.create<arith::ConstantIndexOp>(loc, 0);
    auto oneIdx = builder.create<arith::ConstantIndexOp>(loc, 1);
    // Looked up once; decides between the async and non-async paths below.
    auto oldAsyncToken = air::getAsyncTokenFromOp(op);
    auto newForOp = scf::ForOp();

    if (oldAsyncToken) {
      // Join the op's dependencies into one token and use it as the loop's
      // init arg.
      newForOp = builder.create<scf::ForOp>(
          loc, zeroIdx, oneIdx, oneIdx,
          builder
              .create<air::WaitAllOp>(loc, air::AsyncTokenType::get(ctx),
                                      air::getAsyncDependenciesFromOp(op))
              .getAsyncToken());
      // Inside the loop, the clone must depend on the loop-carried token.
      for (auto dep : air::getAsyncDependenciesFromOp(op))
        remap.map(dep, newForOp.getRegionIterArgs()[0]);
    } else {
      newForOp = builder.create<scf::ForOp>(loc, zeroIdx, oneIdx, oneIdx);
    }

    builder.setInsertionPointToStart(newForOp.getBody());
    auto newOp = dyn_cast<T>(builder.clone(*op, remap));

    if (oldAsyncToken) {
      // A loop built with iter_args has no terminator yet; yield the clone's
      // token and forward the loop result to the original token's users.
      builder.create<scf::YieldOp>(loc, newOp.getAsyncToken());
      oldAsyncToken.replaceAllUsesWith(newForOp->getResult(0));
    }
    // Otherwise nothing to do: when scf.for is built without iter_args and
    // without a body builder, the builder already inserts the terminating
    // scf.yield. Creating another one here would leave a terminator in the
    // middle of the block.
  }

  // Erase the originals only after every op has been wrapped.
  for (auto e : ops)
    e->erase();
}

void sortChannelsByLoopNests(air::ChannelOp &chan_a, air::ChannelOp &chan_b) {
std::vector<air::ChannelPutOp> a_puts =
getChannelPutOpThroughSymbol(chan_a);
Expand Down Expand Up @@ -3567,31 +3654,31 @@ class AIRFuseChannels
for (int i = 0; i < (int)max_loop_nest_count; i++)
if (i != outermostScfFor)
controlLoopIndices.push_back(i);
// TODO: Assuming b_loop_nest is before a_loop_nest. Always true? TODO:
// TODO: Assuming a_loop_nest is before b_loop_nest. Always true? TODO:
// formalize using async dep.
unsigned index = 0;
if (a_loop_nest.size() > b_loop_nest.size()) {
if (a_loop_nest.size() < b_loop_nest.size()) {
for (auto i : controlLoopIndices) {
if (!areEquivalentControlLoops(a_loop_nest[i], b_loop_nest[index++]))
if (!areEquivalentControlLoops(a_loop_nest[index++], b_loop_nest[i]))
return notMergeable;
}
// Check if the skipped scf.for loop has LB >= 1. This is a sign of
// peeling, indicating opportunity of merge by unpeeling into LB.
auto outerMostScfFor =
dyn_cast<scf::ForOp>(a_loop_nest[outermostScfFor]->getParentOp());
dyn_cast<scf::ForOp>(b_loop_nest[outermostScfFor]->getParentOp());
assert(outerMostScfFor);
if (auto constLB = getConstantIntValue(outerMostScfFor.getLowerBound()))
if (*constLB < 1)
return notMergeable;
return mergeableToLB;
} else {
for (auto i : controlLoopIndices) {
if (!areEquivalentControlLoops(a_loop_nest[index++], b_loop_nest[i]))
if (!areEquivalentControlLoops(a_loop_nest[i], b_loop_nest[index++]))
return notMergeable;
}
// Merge by unpeeling into UB.
auto outerMostScfFor =
dyn_cast<scf::ForOp>(b_loop_nest[outermostScfFor]->getParentOp());
dyn_cast<scf::ForOp>(a_loop_nest[outermostScfFor]->getParentOp());
assert(outerMostScfFor);
return mergeableToUB;
}
Expand Down Expand Up @@ -3706,10 +3793,7 @@ class AIRFuseChannels
(!areTheSameSSAValueLists(bSizes, b_puts[i].getSizes())) ||
(!areTheSameSSAValueLists(bStrides, b_puts[i].getStrides())))
return notMergeable; // Inconsistent memory use for all puts
if ((!areTheSameMemref(aMemref, bMemref)) ||
(!areTheSameSSAValueLists(aOffsets, bOffsets)) ||
(!areTheSameSSAValueLists(aSizes, bSizes)) ||
(!areTheSameSSAValueLists(aStrides, bStrides)))
if ((!areTheSameMemref(aMemref, bMemref)))
return notMergeable;
aMemref = a_gets[0].getMemref();
aOffsets = a_gets[0].getOffsets();
Expand Down Expand Up @@ -3886,17 +3970,20 @@ class AIRFuseChannels
IRMapping remap;
remapAllParentLoopArgs(remap, a, b);
OpBuilder builder(a);
builder.setInsertionPoint(a->getBlock()->getTerminator());
auto new_b = cloneOpAndOperands(builder, remap, b);
eraseParentLoopIfEmpty(*b);
auto async_a = dyn_cast<air::AsyncOpInterface>(a.getOperation());
if (async_a.getAsyncToken())
async_a.addAsyncDependency(
dyn_cast<air::AsyncOpInterface>(new_b).getAsyncToken());
auto async_b = dyn_cast<air::AsyncOpInterface>(new_b);
if (async_b.getAsyncToken())
async_b.addAsyncDependency(
dyn_cast<air::AsyncOpInterface>(a.getOperation()).getAsyncToken());
}
void mergeChannelOpsTemporally(air::ChannelInterface a,
air::ChannelInterface b,
std::string mergeByLBOrUB) {
scf::ForOp parentForOp = a->getParentOfType<scf::ForOp>();
if (!parentForOp)
return;
while (parentForOp && parentForOp->getParentOfType<scf::ForOp>()) {
parentForOp = parentForOp->getParentOfType<scf::ForOp>();
}
Expand Down Expand Up @@ -3926,12 +4013,10 @@ class AIRFuseChannels
std::vector<air::ChannelGetOp> b_gets =
getChannelGetOpThroughSymbol(chan_b);
// Interleave puts and gets
for (unsigned i = 0; i < a_puts.size(); i++) {
for (unsigned i = 0; i < a_puts.size(); i++)
mergeChannelOps(a_puts[i], b_puts[i]);
}
for (unsigned i = 0; i < a_gets.size(); i++) {
for (unsigned i = 0; i < a_gets.size(); i++)
mergeChannelOps(a_gets[i], b_gets[i]);
}
}
void mergeChannelOpsTemporally(air::ChannelOp chan_a, air::ChannelOp chan_b,
std::string mergeByLBOrUB) {
Expand Down Expand Up @@ -4680,16 +4765,23 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
auto hasNElements = [](Block *block, unsigned N) {
unsigned counter = 0;
for (auto &o : block->getOperations()) {
if (o.mightHaveTrait<OpTrait::IsTerminator>())
continue;
if (isa<air::WaitAllOp>(o))
continue;
counter++;
if (isa<air::ChannelInterface>(o))
counter++;
else if (isa<LoopLikeOpInterface>(o))
counter++;
else if (isa<mlir::linalg::LinalgOp>(o))
counter++;
}
return counter == N;
};
for (auto forOp : op.getOps<scf::ForOp>()) {
if (hasNElements(forOp.getBody(), 1) &&
// Conditions for a candidate scf.for op for fusion: (1) has at most 1
// unique channel operating in the block, (2) is either perfectly nested,
// or contains only channel ops, (3) is a static for loop.
if (getNumUniqueChannelsInBlock(forOp.getBody()) <= 1 &&
hasNElements(
forOp.getBody(),
std::max(getNumChannelPutsGetsInBlock(forOp.getBody()), 1)) &&
air::getStaticScfForTripCountAsInt(forOp))
perfectlyNestedForBands.push_back(forOp);
}
Expand Down Expand Up @@ -4727,13 +4819,18 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {

// Folding memref.alloc / dealloc ops into fused loop.
SmallVector<scf::ForOp> fusableForOps;
for (auto execOpPair : alloc_dealloc_execs) {
air::ExecuteOp alloc_exec = execOpPair.first;
for (auto token_user : alloc_exec.getAsyncToken().getUsers())
if (llvm::any_of(equalIterationForOps, [&](scf::ForOp fusableForOp) {
return fusableForOp == token_user;
}))
fusableForOps.push_back(dyn_cast<scf::ForOp>(token_user));
for (auto forOp : equalIterationForOps) {
for (auto ia : forOp.getInitArgs()) {
auto iaDefOp = ia.getDefiningOp();
if (!iaDefOp)
continue;
if (llvm::any_of(
alloc_dealloc_execs,
[&](std::pair<air::ExecuteOp, air::ExecuteOp> exec_pair) {
return exec_pair.first == iaDefOp;
}))
fusableForOps.push_back(forOp);
}
}
if (fusableForOps.empty())
return failure();
Expand Down Expand Up @@ -4869,6 +4966,25 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
}

private:
// Count the air.channel puts/gets in block, i.e. the operations directly
// contained in `block` that implement air::ChannelInterface.
int getNumChannelPutsGetsInBlock(Block *block) const {
  int numChannelOps = 0;
  for (auto &nestedOp : block->getOperations()) {
    // Skip anything which is not a channel op.
    if (!dyn_cast<air::ChannelInterface>(nestedOp))
      continue;
    ++numChannelOps;
  }
  return numChannelOps;
}

// Count the distinct air.channels referenced by the air.channel puts/gets that
// are directly contained in `block`. Channels are told apart by the name
// returned from getChanName().
int getNumUniqueChannelsInBlock(Block *block) const {
  llvm::SmallSet<std::string, 1> uniqueChanNames;
  for (auto &nestedOp : block->getOperations()) {
    auto chanIf = dyn_cast<air::ChannelInterface>(nestedOp);
    // Only channel ops contribute a channel name.
    if (!chanIf)
      continue;
    uniqueChanNames.insert(chanIf.getChanName().str());
  }
  return uniqueChanNames.size();
}

// Scf.for loop tiling. This simple tiling implementation generates a new
// inner scf.for loop which starts from the original loop's lower bound. It
// may change the meaning of the original scf.for loop, therefore it requires
Expand Down Expand Up @@ -4977,7 +5093,9 @@ class AIRSegmentLoopFusion
void runPreProcPatterns(func::FuncOp funcOp) {
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet patterns(&getContext());
patterns.insert<CanonicalizeAffineApplyOnLoopInductionVar>(ctx);
patterns
.insert<CanonicalizeAffineApplyOnLoopInductionVar, UnrollScfParallel>(
ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Transform/AIRDmaToChannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ static air::ChannelOp
createChannelOpWithBCast(OpBuilder builder, ModuleOp module, std::string cname,
Location loc, SmallVector<int64_t, 2> bcast_sizes) {
auto insertionCheckpoint = builder.saveInsertionPoint();
builder.setInsertionPointToStart(module.getBody());
Operation *o = &module.getBody()->front();
while (dyn_cast_or_null<air::ChannelOp>(o))
o = o->getNextNode();
builder.setInsertionPoint(o);

auto channel_op = builder.create<air::ChannelOp>(
loc, cname, builder.getI64ArrayAttr(bcast_sizes));
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1481,8 +1481,11 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
auto loc = chanUserOp->getLoc();
auto ctx = chanUserOp->getContext();
OpBuilder builder(chanUserOp);
builder.setInsertionPointToStart(
chanUserOp->getParentOfType<ModuleOp>().getBody());
Operation *o =
&chanUserOp->getParentOfType<ModuleOp>().getBody()->front();
while (dyn_cast_or_null<air::ChannelOp>(o))
o = o->getNextNode();
builder.setInsertionPoint(o);
SmallVector<Type, 4> tys = {
air::AsyncTokenType::get(chanUserOp->getContext())};
auto cname =
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Conversion/ConvertToAIR/dma_to_channel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

#map = affine_map<()[s0] -> (s0 * 32)>
module attributes {torch.debug_module_name = "mmult"} {
// CHECK: air.channel @channel_3 [2, 2]
// CHECK: air.channel @channel_2 [2, 2]
// CHECK: air.channel @channel_1 [2, 2]
// CHECK: air.channel @channel_0 [2, 2]
// CHECK: air.channel @channel_1 [2, 2]
// CHECK: air.channel @channel_2 [2, 2]
// CHECK: air.channel @channel_3 [2, 2]
// CHECK-LABEL: func.func @mmult
func.func @mmult(%arg0: memref<64x64xi32>, %arg1: memref<64x64xi32>) -> memref<64x64xi32> {
%c2 = arith.constant 2 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 32)>
module {
// CHECK: air.channel @channel_1 [2, 2]
// CHECK: air.channel @channel_0 [1, 1]
// CHECK: air.channel @channel_1 [2, 2]
// CHECK-LABEL: func.func @mmult
func.func @mmult(%arg0: memref<512x512xbf16>) {
%c8 = arith.constant 8 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

#map = affine_map<()[s0] -> (s0 * 32)>
module {
// CHECK: air.channel @channel_1 [2, 2]
// CHECK: air.channel @channel_0 [2, 2]
// CHECK: air.channel @channel_1 [2, 2]
func.func @matmul_on_buffers(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) {
%c2 = arith.constant 2 : index
// CHECK: %[[EVENT0:.*]] = scf.parallel (%[[VALUE0:.*]], %[[VALUE1:.*]]) ={{.*}}init
Expand Down
Loading

0 comments on commit a913d16

Please sign in to comment.