Skip to content

Commit

Permalink
AIRSplitL2Memref: Expand the pass to work with air.herds peeled from …
Browse files Browse the repository at this point in the history
…for loops (#587)

* Extend traceDependentInductionVar method to channels

* Fixup an issue where L2 memref tiling fails with high-rank offsets and memref

* Fixup issues caused by lacking dimension conversion between offset dim and memref dim

* Make affine.apply canonicalizer aware of parent execute op; reimplement hasNElements which rules out no-ops when detecting perfect loop nests

* Attempt to match the memref rank on both sides of air.channel, before tiling for L2 memref splitting

* Tests
  • Loading branch information
erwei-xilinx authored May 31, 2024
1 parent 1f0385e commit 60d20cf
Show file tree
Hide file tree
Showing 6 changed files with 435 additions and 139 deletions.
8 changes: 6 additions & 2 deletions mlir/include/air/Util/Dependency.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ namespace air {
//===----------------------------------------------------------------------===//

bool areEqualIndices(mlir::Value index_0, mlir::Value index_1);
void traceDependentInductionVar(air::DmaMemcpyNdOp async_op,
void traceDependentInductionVar(SmallVector<Value, 1> candidate_scalar_operands,
SmallVector<Value, 1> &loop_dep_history,
std::vector<Operation *> &op_history);
void traceDependentInductionVar(air::MemcpyInterface memcpyif_op,
SmallVector<Value, 1> &loop_dep_history,
std::vector<Operation *> &op_history);
void traceDependentInductionVar(air::AsyncOpInterface async_op,
Expand Down Expand Up @@ -365,7 +368,8 @@ class dependencyTracer {

private:
// Trace the defining op of sink op, RAW
template <typename T> void traceDefiningOpAsDep(Value operand, T op) {
template <typename T>
void traceDefiningOpAsDep(Value operand, T op) {
// Check memref deps
if (auto defop = operand.getDefiningOp<air::ExecuteOp>()) {
// addNewAsyncDepToGraph<T>(defop.getResult(0), op);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ class LinalgCopyToAIRDmaConversion : public OpRewritePattern<linalg::CopyOp> {
}
};

unsigned getScfParDimIdFromBCastDma(air::DmaMemcpyNdOp memcpyOp) {
unsigned getScfParDimIdFromBCastDma(air::MemcpyInterface memcpyOp) {
// Get all ops on the dependency connection between dma and herd launch
SmallVector<Value, 1> loop_dep_history;
std::vector<Operation *> op_history;
Expand Down Expand Up @@ -769,7 +769,7 @@ void replaceAIRDmaWithAIRChannelPairs(
op->getAttrOfType<mlir::IntegerSetAttr>("broadcast_pattern").getValue();
air::getSizesFromIntegerSet(ctx, int_set, lbs_int, ubs_int);
SmallVector<int64_t, 2> channel_sizes = {1, 1};
channel_sizes[getScfParDimIdFromBCastDma(dyn_cast<air::DmaMemcpyNdOp>(
channel_sizes[getScfParDimIdFromBCastDma(dyn_cast<air::MemcpyInterface>(
op.getOperation()))] = ubs_int[0] - lbs_int[0] + 1;
auto channel_op =
createChannelOpWithBCast(builder, module, cname, loc, channel_sizes);
Expand Down
65 changes: 38 additions & 27 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,9 @@ struct CanonicalizeAffineApplyOnLoopInductionVar
return failure();
if (apply.getResult().use_empty())
return failure();
if (auto exec_apply = dyn_cast<air::ExecuteOp>(apply->getParentOp()))
if (exec_apply->getResult(1).use_empty())
return failure();
auto *containingOp = ivArg.getOwner()->getParentOp();

// Apply affine map to loop step and bound
Expand Down Expand Up @@ -1667,10 +1670,15 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor

// Check if the loop is the outermost loop in a perfect loop nest
auto hasNElements = [](Block *block, unsigned N) {
  // Count only "real" operations: the block terminator (e.g. scf.yield)
  // and air.wait_all synchronization no-ops must not disqualify an
  // otherwise perfect loop nest.
  unsigned counter = 0;
  for (auto &o : block->getOperations()) {
    if (o.mightHaveTrait<OpTrait::IsTerminator>())
      continue;
    if (isa<air::WaitAllOp>(o))
      continue;
    counter++;
  }
  return counter == N;
};
if (auto parent_for = dyn_cast<scf::ForOp>(for_op->getParentOp()))
if (hasNElements(parent_for.getBody(), 1))
Expand Down Expand Up @@ -1776,10 +1784,15 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor

// Check if the loop is the outermost loop in a perfect loop nest
auto hasNElements = [](Block *block, unsigned N) {
  // Count only "real" operations: the block terminator (e.g. affine.yield)
  // and air.wait_all synchronization no-ops must not disqualify an
  // otherwise perfect loop nest.
  unsigned counter = 0;
  for (auto &o : block->getOperations()) {
    if (o.mightHaveTrait<OpTrait::IsTerminator>())
      continue;
    if (isa<air::WaitAllOp>(o))
      continue;
    counter++;
  }
  return counter == N;
};
if (auto parent_for = dyn_cast<affine::AffineForOp>(for_op->getParentOp()))
if (hasNElements(parent_for.getBody(), 1))
Expand Down Expand Up @@ -2115,24 +2128,22 @@ struct BroadcastDetection {
public:
// Trace dma ops' dependency to loop induction variables
void getDmaOpLoopDependency(func::FuncOp f) {
  // Walk every memcpy-like op (generalized from air::DmaMemcpyNdOp to the
  // air::MemcpyInterface so channel puts/gets are covered too).
  f.walk([&](MemcpyInterface memcpyif_op) {
    int src_memspace =
        llvm::cast<MemRefType>(memcpyif_op.getSrcMemref().getType())
            .getMemorySpaceAsInt();
    int dst_memspace =
        llvm::cast<MemRefType>(memcpyif_op.getDstMemref().getType())
            .getMemorySpaceAsInt();
    // Only memcpys touching L1 on either end are broadcast candidates.
    bool isL1Memcpy = (src_memspace == (int)air::MemorySpace::L1) ||
                      (dst_memspace == (int)air::MemorySpace::L1);
    if (memcpyif_op->getParentOfType<xilinx::air::HerdOp>() && isL1Memcpy) {
      // Start recursively tracing for loop induction variables; results
      // are recorded side-by-side in dma_op_history and
      // dma_op_loop_dep_history (parallel containers).
      dma_op_history.push_back(memcpyif_op);
      SmallVector<Value, 1> loop_dep_history;
      std::vector<Operation *> op_history;
      traceDependentInductionVar(memcpyif_op, loop_dep_history, op_history);
      dma_op_loop_dep_history.push_back(loop_dep_history);
    }
  });
}
Expand Down Expand Up @@ -2222,7 +2233,7 @@ struct BroadcastDetection {

private:
// DMA dependency to loop induction variables
std::vector<air::DmaMemcpyNdOp> dma_op_history;
std::vector<MemcpyInterface> dma_op_history;
SmallVector<SmallVector<Value, 1>, 1> dma_op_loop_dep_history;
};

Expand Down
98 changes: 51 additions & 47 deletions mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1087,15 +1087,6 @@ int findGCD(SmallVector<int> vec) {
return result;
}

// Check if an air.channel is single-consumer-single-producer, i.e. it has
// exactly one put op and exactly one get op referencing it by symbol.
bool hasSinglePutAndGet(air::ChannelOp chan) {
  auto module = chan->getParentOfType<ModuleOp>();
  return getChannelPutOpThroughSymbol(chan, module).size() == 1 &&
         getChannelGetOpThroughSymbol(chan, module).size() == 1;
}

// Tile air.channel put/get wrt a memref.
Value tileChannelOpByFactor(air::ChannelInterface originalChanOp, int factor,
int originalMemrefSize, int dim,
Expand Down Expand Up @@ -1389,23 +1380,41 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
targetMemrefs.push_back(allocOp);
allocOp->setAttr("split", BoolAttr::get(ctx, true));
allocOp->setAttr("split_type", StringAttr::get(ctx, "scf.parallel"));
if (lbs_spatial.size() == 1) {
// If scf.parallel has less dims than the memref, i.e. partial
// unrolling, then label the dim.
int unrollDim = 0;
for (auto index : chanOp.getIndices()) {
if (auto indexOwner =
scf::getParallelForInductionVarOwner(index)) {
if (indexOwner == parentParOp) {
allocOp->setAttr(
"split_dim",
IntegerAttr::get(IntegerType::get(ctx, 32), unrollDim));
break;
}
}
unrollDim++;

// Label the dimension on memref to be targeted for splitting, by
// analyzing the scf.parallel access pattern (as dependence from
// offsets to induction variables).
int splitDim = -1;
for (unsigned i = 0; i < chanOp.getOffsets().size(); i++) {
SmallVector<Value, 1> candidateOperands(1, chanOp.getOffsets()[i]);
SmallVector<Value, 1> loop_dep_history;
std::vector<Operation *> op_history;
air::traceDependentInductionVar(candidateOperands, loop_dep_history,
op_history);
for (auto v : loop_dep_history) {
if (scf::getParallelForInductionVarOwner(v) != parentParOp)
continue;
auto memrefDim = air::getMemrefDimFromOffsetDim(
i, chanOp.getOffsets(), chanOp.getStrides(),
air::getTensorShape(memref.getType()));
if (!memrefDim)
continue;
splitDim = *memrefDim;
break;
}
if (splitDim >= 0)
break;
}

if (allocOp->hasAttr("split_dim"))
assert(allocOp->getAttrOfType<IntegerAttr>("split_dim").getInt() ==
splitDim &&
"L2 memref tiled inconsistently across multiple data access "
"patterns. Cannot infer L2 memref tiling strategy.");
else if (splitDim >= 0)
allocOp->setAttr(
"split_dim",
IntegerAttr::get(IntegerType::get(ctx, 32), splitDim));
}
// Tiling along the first (x) dimension of scf.parallel only, as one NPU
// memtile is located at the bottom of each column.
Expand Down Expand Up @@ -1498,16 +1507,10 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
if (!isa<air::ChannelInterface>(user))
continue;
auto chanUserOp = dyn_cast<air::ChannelInterface>(user);
auto chanUserChannelDeclr =
air::getChannelDeclarationThroughSymbol(chanUserOp);
if (!hasSinglePutAndGet(chanUserChannelDeclr)) {
assert(false && "NYI");
} else if (auto par = user->getParentOfType<scf::ParallelOp>()) {
if (auto par = user->getParentOfType<scf::ParallelOp>()) {
// Case 1: Parallel access to the memref represented with scf.parallel
// op. Data access specialization method: unroll the scf.parallel
// loop.
SmallVector<int, 2> lbs_spatial, ubs_spatial;
air::getSizesFromSpatialLoop(par, lbs_spatial, ubs_spatial);
OpBuilder builder(par);
IRMapping remap;
(void)air::unrollAIRChannelPutGetInScfParallel(builder, par, user,
Expand Down Expand Up @@ -1543,19 +1546,12 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
int dim = 0;
if (allocOp->hasAttr("split_dim"))
dim = allocOp->getAttrOfType<IntegerAttr>("split_dim").getInt();
for (unsigned i = 0; i < memrefShape.size(); i++) {
if (chanUserOp.getOffsets().empty())
break;
int offsetDim =
chanUserOp.getOffsets().size() - memrefShape.size() + i;
if (getConstantIntValue(chanUserOp.getOffsets()[offsetDim])) {
dim = i;
break;
}
}
auto newWaitAll =
tileChannelOpByFactor(chanUserOp, targetColTilingFactor,
memrefShape[dim], dim, new_chan, loc, ctx);
auto offsetDimOpt = air::getOffsetDimFromMemrefDim(
dim, chanUserOp.getStrides(), memrefShape);
int offsetDim = offsetDimOpt ? *offsetDimOpt : dim;
auto newWaitAll = tileChannelOpByFactor(
chanUserOp, targetColTilingFactor, memrefShape[dim], offsetDim,
new_chan, loc, ctx);

// Update async dependency.
auto old_token =
Expand All @@ -1566,9 +1562,17 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
// Update the other channel op of the chanUserChannelDeclr.
auto theOtherChanOp =
air::getTheOtherChannelOpThroughSymbol(chanUserOp);
Value newWaitAll1 =
tileChannelOpByFactor(theOtherChanOp[0], targetColTilingFactor,
memrefShape[dim], dim, new_chan, loc, ctx);
// Note: if the memref on the other side of the air channel has
// different rank, then we check if ranks can be matched after leading
// singleton dimensions are removed. If the ranks still do not match,
// then the behaviour is unstable.
int numLeadingSingletonDims = 0;
for (auto memrefDim : memrefShape)
if (memrefDim == 1)
numLeadingSingletonDims++;
Value newWaitAll1 = tileChannelOpByFactor(
theOtherChanOp[0], targetColTilingFactor, memrefShape[dim],
dim - numLeadingSingletonDims, new_chan, loc, ctx);

// Update dependency.
auto oldToken =
Expand Down
92 changes: 31 additions & 61 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,20 @@ bool areEqualIndices(mlir::Value index_0, mlir::Value index_1) {
}
}

// Recursively check for dependency to loop induction vars arising from dma src
void traceDependentInductionVar(air::DmaMemcpyNdOp dmaNd_op,
void traceDependentInductionVar(SmallVector<Value, 1> candidate_scalar_operands,
SmallVector<Value, 1> &loop_dep_history,
std::vector<Operation *> &op_history) {
// Check for immediate dependency to loop induction vars
SmallVector<Value, 1> candidate_scalar_operands;
for (unsigned i = 0; i < dmaNd_op.getSrcOffsets().size(); i++) {
candidate_scalar_operands.push_back(dmaNd_op.getSrcOffsets()[i]);
candidate_scalar_operands.push_back(dmaNd_op.getSrcSizes()[i]);
candidate_scalar_operands.push_back(dmaNd_op.getSrcStrides()[i]);
}
for (auto operand : candidate_scalar_operands) {
// If parent loop op is an scf.for
if (auto for_op = mlir::scf::getForInductionVarOwner(operand)) {
loop_dep_history.push_back(for_op.getInductionVar());
}
// TODO: Assuming that src.parallel won't exist under herd launch
// If parent loop op is an scf.parallel
if (auto par_op = mlir::scf::getParallelForInductionVarOwner(operand)) {
for (auto ind_var : par_op.getInductionVars())
if (ind_var == operand)
loop_dep_history.push_back(ind_var);
}

// If parent loop op is an air.launch_herd
if (auto hl_op = getHerdArgOwner(operand)) {
Expand Down Expand Up @@ -104,6 +100,30 @@ void traceDependentInductionVar(air::DmaMemcpyNdOp dmaNd_op,
}
}

// Recursively check for dependency to loop induction vars arising from dma
void traceDependentInductionVar(air::MemcpyInterface memcpyif_op,
                                SmallVector<Value, 1> &loop_dep_history,
                                std::vector<Operation *> &op_history) {
  // Gather every scalar (offset, size, stride) operand of the memcpy —
  // source side first, then destination side — and delegate the actual
  // induction-variable tracing to the operand-list overload.
  SmallVector<Value, 1> scalarOperands;
  auto collect = [&scalarOperands](auto offsets, auto sizes, auto strides) {
    for (unsigned d = 0; d < offsets.size(); d++) {
      scalarOperands.push_back(offsets[d]);
      scalarOperands.push_back(sizes[d]);
      scalarOperands.push_back(strides[d]);
    }
  };
  if (memcpyif_op.getSrcMemref())
    collect(memcpyif_op.getSrcOffsets(), memcpyif_op.getSrcSizes(),
            memcpyif_op.getSrcStrides());
  if (memcpyif_op.getDstMemref())
    collect(memcpyif_op.getDstOffsets(), memcpyif_op.getDstSizes(),
            memcpyif_op.getDstStrides());
  traceDependentInductionVar(scalarOperands, loop_dep_history, op_history);
}

// Recursively check for dependency to any loop induction vars
void traceDependentInductionVar(air::AsyncOpInterface async_op,
SmallVector<Value, 1> &loop_dep_history,
Expand All @@ -123,57 +143,7 @@ void traceDependentInductionVar(air::AsyncOpInterface async_op,
} else {
op = async_op.getOperation();
}

// Check for immediate dependency to loop induction vars
for (auto operand : op->getOperands()) {
// If parent loop op is an scf.for
if (auto for_op = mlir::scf::getForInductionVarOwner(operand)) {
loop_dep_history.push_back(for_op.getInductionVar());
}
// If parent loop op is an scf.parallel
if (auto parallel_op =
mlir::scf::getParallelForInductionVarOwner(operand)) {
for (auto induction_var : parallel_op.getInductionVars()) {
if (operand == induction_var) {
loop_dep_history.push_back(induction_var);
}
}
}
// If parent loop op is an air.launch_herd
if (auto hl_op = getHerdArgOwner(operand)) {
for (auto id : hl_op.getIds()) {
if (operand == id) {
loop_dep_history.push_back(id);
}
}
}
}

// Recursively trace dependency to loop induction vars
for (auto operand : op->getOperands()) {
if (operand && llvm::isa<IndexType>(
operand.getType())) { // Only tracing scalar operands
if (operand.getDefiningOp() &&
mlir::dyn_cast<air::AsyncOpInterface>(operand.getDefiningOp())) {
auto ancestor_async_op =
dyn_cast<air::AsyncOpInterface>(operand.getDefiningOp());
op_history.push_back(ancestor_async_op.getOperation());
traceDependentInductionVar(ancestor_async_op, loop_dep_history,
op_history);
} else {
// Trace dependency through a for loop
if (auto for_op = getForRegionIterArgsOwner(operand)) {
for (auto iter_arg : for_op.getInitArgs()) {
if (operand == iter_arg) {
loop_dep_history.push_back(iter_arg);
}
}
}
// Trace dependency through a parallel loop
// TODO: decide if parallel should exist in herd launch
}
}
}
traceDependentInductionVar(op->getOperands(), loop_dep_history, op_history);
}

// Recursively check for dependency to any air.herd induction variables.
Expand Down
Loading

0 comments on commit 60d20cf

Please sign in to comment.