Skip to content

Commit

Permalink
AIRBcastDetection: Remove faulty logic leading to overlooked broadcasting opportunities (#755)
Browse files Browse the repository at this point in the history

* Trace induction variable (iv) use through non-async SSA values only

* Remove faulty logic in dependency detection which led to overlooked bcast cases

* The previous test overlooked a broadcasting opportunity; test now updated
  • Loading branch information
erwei-xilinx authored Oct 29, 2024
1 parent 2ff4b89 commit 5131fd2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 34 deletions.
15 changes: 0 additions & 15 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2459,21 +2459,6 @@ struct BroadcastDetection {
}
}
}
// If not variant wrt herd, then check for fixed row-wise or col-wise
// offset.
int src_memspace = llvm::cast<MemRefType>(dma_op.getSrcMemref().getType())
.getMemorySpaceAsInt();
auto externalOffsets = src_memspace == (int)air::MemorySpace::L1
? dma_op.getDstOffsets()
: dma_op.getSrcOffsets();
if (!hl_op && externalOffsets.size() ==
dma_op->getParentOfType<air::HerdOp>().getNumDims()) {
hl_op = dma_op->getParentOfType<air::HerdOp>();
if (getConstantIntValue(externalOffsets[0]))
isVariantWrtHerdRows = true;
if (getConstantIntValue(externalOffsets[1]))
isVariantWrtHerdCols = true;
}

if (!hl_op) {
// If dma op is completely independent of the parent herd induction
Expand Down
37 changes: 19 additions & 18 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ void traceDependentInductionVar(SmallVector<Value, 1> candidate_scalar_operands,
SmallVector<Value, 1> &loop_dep_history,
std::vector<Operation *> &op_history) {
for (auto operand : candidate_scalar_operands) {
if (!llvm::isa<IndexType>(operand.getType()))
continue; // Only tracing scalar operands
// If parent loop op is an scf.for
if (auto for_op = mlir::scf::getForInductionVarOwner(operand)) {
loop_dep_history.push_back(for_op.getInductionVar());
Expand All @@ -76,27 +78,26 @@ void traceDependentInductionVar(SmallVector<Value, 1> candidate_scalar_operands,

// Recursively trace dependency to loop induction vars
for (auto operand : candidate_scalar_operands) {
if (operand && llvm::isa<IndexType>(
operand.getType())) { // Only tracing scalar operands
if (operand.getDefiningOp() &&
mlir::dyn_cast<air::AsyncOpInterface>(operand.getDefiningOp())) {
auto ancestor_async_op =
dyn_cast<air::AsyncOpInterface>(operand.getDefiningOp());
op_history.push_back(ancestor_async_op.getOperation());
traceDependentInductionVar(ancestor_async_op, loop_dep_history,
op_history);
} else {
// Trace dependency through a for loop
if (auto for_op = getForRegionIterArgsOwner(operand)) {
for (auto iter_arg : for_op.getInitArgs()) {
if (operand == iter_arg) {
loop_dep_history.push_back(iter_arg);
}
if (!llvm::isa<IndexType>(operand.getType()))
continue; // Only tracing scalar operands
if (operand.getDefiningOp() &&
mlir::dyn_cast<air::AsyncOpInterface>(operand.getDefiningOp())) {
auto ancestor_async_op =
dyn_cast<air::AsyncOpInterface>(operand.getDefiningOp());
op_history.push_back(ancestor_async_op.getOperation());
traceDependentInductionVar(ancestor_async_op, loop_dep_history,
op_history);
} else {
// Trace dependency through a for loop
if (auto for_op = getForRegionIterArgsOwner(operand)) {
for (auto iter_arg : for_op.getInitArgs()) {
if (operand == iter_arg) {
loop_dep_history.push_back(iter_arg);
}
}
// Trace dependency through a parallel loop
// TODO: decide if parallel should exist in herd launch
}
// Trace dependency through a parallel loop
// TODO: decide if parallel should exist in herd launch
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func.func @matmul(%arg0: memref<512x512xbf16>, %arg1: memref<512x512xbf16>, %arg

// -----

// CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 3 >= 0, d1 - s0 == 0, s0 >= 0, -s0 + 3 >= 0)>
// CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 3 >= 0, d1 >= 0, -d1 + 3 >= 0, s0 >= 0, -s0 >= 0)>
// CHECK-LABEL: func.func @func0
// CHECK: %[[EVENT0:.*]] = air.dma_memcpy_nd {{.*}} {id = 1 : i32} : (memref<256x64xbf16, 1>, memref<1024x256xbf16>)
// CHECK: %[[EVENT1:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET0]]{{.*}}
Expand Down

0 comments on commit 5131fd2

Please sign in to comment.