Skip to content

Commit

Permalink
changes with add/dec ref and lastuse ops
Browse files Browse the repository at this point in the history
  • Loading branch information
drprajap committed Feb 24, 2025
1 parent 11aa54f commit a05ce5f
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 73 deletions.
8 changes: 4 additions & 4 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1879,11 +1879,11 @@ def Stream_ResourceAddRefOp : Stream_Op<"async.resource.add_ref"> {

let arguments = (ins
Stream_AnyStreamResource:$resource,
ConfinedAttr<I64Attr, [IntPositive]>:$count
I64:$count
);

let assemblyFormat = [{
$resource attr-dict `:` type($resource)
$resource attr-dict `:` type($resource) `,` $count
}];

}
Expand All @@ -1897,11 +1897,11 @@ def Stream_ResourceDecRefOp : Stream_Op<"async.resource.dec_ref"> {

let arguments = (ins
Stream_AnyStreamResource:$resource,
ConfinedAttr<I64Attr, [IntPositive]>:$count
I64:$count
);

let assemblyFormat = [{
$resource attr-dict `:` type($resource)
$resource attr-dict `:` type($resource) `,` $count
}];

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ class LivenessAnalysis {

llvm::DenseMap<Value, int> refCountMap;
llvm::DenseMap<Value, Operation *> lastUseMap;
void insertDealloca(Value, Operation *);
llvm::DenseMap<Value, Value> resourceTimepoints;

// on alloca, find all users and add reference counts
// at each execute region block resource is used
Expand All @@ -367,86 +367,149 @@ void insertDealloca(Value, Operation *);
static void performRefCountingInRegion(Region &region,
LivenessAnalysis &analysis) {

region.walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<IREE::Stream::ResourceAllocaOp>([&](auto allocaOp) {
auto resource =
cast<IREE::Stream::ResourceAllocaOp>(op)->getResult(0);
// Adds refCounting
if (!resource.getUses().empty()) {
refCountMap.insert(std::make_pair(resource, 0));
for (Operation *user : resource.getUsers()) {
refCountMap[resource]++;
if (analysis.isLastUser(resource, user))
lastUseMap[resource] = user;
}
}
})
.Case<IREE::Stream::ResourceDeallocaOp>([&](auto deallocaOp) {
auto resource =
cast<IREE::Stream::ResourceDeallocaOp>(op)->getResult(0);
// region.walk([&](Operation *op) {
for (auto &block : region) {
block.walk([&](Operation *op) {
return TypeSwitch<Operation *, WalkResult>(op)
.Case<IREE::Stream::ResourceAllocaOp>([&](auto allocaOp) {
AsmState asmState(op->getParentOp());

auto resource =
cast<IREE::Stream::ResourceAllocaOp>(op)->getResult(0);

LLVM_DEBUG({
llvm::dbgs() << "\n ----- resource: ";
resource.printAsOperand(llvm::dbgs(), asmState);
});
// Adds refCounting
if (!resource.getUses().empty()) {
SmallVector<Value> joinTimepoints;

refCountMap.insert(std::make_pair(resource, 0));
for (Operation *user : resource.getUsers()) {
if (isa<IREE::Stream::ResourceDeallocaOp,
IREE::Stream::ResourceAddRefOp,
IREE::Stream::ResourceDecRefOp>(user)) {
return WalkResult::advance();
}

refCountMap[resource]--;
LLVM_DEBUG({
llvm::dbgs() << "\n last user is dealloca, just decrement refcount "
"and do nothing: "
<< resource;
});
})
.Case<IREE::Stream::CmdExecuteOp>([&](auto executeOp) {
for (auto operand : executeOp->getOperands()) {
if (refCountMap.count(operand)) {
// Keep analysis check here before the IR mutates by add/dec ref
// Ops
if (analysis.isLastUser(resource, user)) {
OpBuilder builder(user);
auto loc = user->getLoc();
builder.setInsertionPointAfter(user);
auto lastUseOp =
builder.create<IREE::Stream::ResourceLastUseOp>(loc,
resource);
lastUseMap.insert(
std::make_pair(resource, lastUseOp.getOperation()));
}
if (auto execop = dyn_cast<IREE::Stream::CmdExecuteOp>(*user)) {
joinTimepoints.push_back(execop.getResultTimepoint());
LLVM_DEBUG({
llvm::dbgs() << "\n jointimepoint: ";
execop.getResultTimepoint().printAsOperand(llvm::dbgs(),
asmState);
});
}

// Decrement reference after last use
if (analysis.isLastUser(operand, executeOp)) {
lastUseMap.insert(std::make_pair(operand, executeOp));
OpBuilder builder(user);
auto loc = user->getLoc();
refCountMap[resource]++;

if (refCountMap[operand] > 0) {
refCountMap[operand]--;
}
builder.setInsertionPoint(user);

if (refCountMap[operand] == 0) {
insertDealloca(operand, executeOp);
Value countVal = builder.create<arith::ConstantIntOp>(
loc, refCountMap[resource], 64);
Value one = builder.create<arith::ConstantIntOp>(loc, 1, 64);

builder.create<IREE::Stream::ResourceAddRefOp>(loc, resource,
countVal);

builder.setInsertionPointAfter(user);
auto Val =
builder.createOrFold<arith::SubIOp>(loc, countVal, one);

builder.create<IREE::Stream::ResourceDecRefOp>(loc, resource,
Val);
}

{
if (lastUseMap.count(resource)) {
auto lastUseOp = cast<IREE::Stream::ResourceLastUseOp>(
lastUseMap[resource]);
OpBuilder builder(lastUseOp);
builder.setInsertionPoint(lastUseOp);

Value newAwaitTimepoint = IREE::Stream::TimepointJoinOp::join(
joinTimepoints, builder);

resourceTimepoints.insert(
std::make_pair(resource, newAwaitTimepoint));
}
} else {
refCountMap[operand]--;
}
}
}
});
});
return WalkResult::advance();
})
.Default([&](auto *op) { return WalkResult::advance(); });
});
}
}

void insertDealloca(Value resource, Operation *lastuseOp) {
auto refCount = refCountMap[resource];
auto definingOp = resource.getDefiningOp();
void insertDealloca(Region &region) {
OpBuilder builder(region);

region.walk([&](Operation *op) {
TypeSwitch<Operation *>(op)

auto allocaOp = cast<IREE::Stream::ResourceAllocaOp>(definingOp);
.Case<IREE::Stream::ResourceDecRefOp>([&](auto decRefOp) {
AsmState asmState(op->getParentOp());
auto resource = decRefOp.getResource();
LLVM_DEBUG({
llvm::dbgs() << "\n ----- resource: ";
resource.printAsOperand(llvm::dbgs(), asmState);
});

OpBuilder builder(lastuseOp);
builder.setInsertionPointAfter(lastuseOp);
auto loc = lastuseOp->getLoc();
auto timepoint = allocaOp.getResultTimepoint();
if (auto op = dyn_cast<IREE::Stream::CmdExecuteOp>(lastuseOp)) {
timepoint = op.getResultTimepoint();
}
auto loc = decRefOp.getLoc();
builder.setInsertionPoint(decRefOp);

auto allocaOp =
cast<IREE::Stream::ResourceAllocaOp>(resource.getDefiningOp());
Value newTimepoint = resourceTimepoints[resource];
if (lastUseMap.count(resource)) {
auto lastUseOp =
cast<IREE::Stream::ResourceLastUseOp>(lastUseMap[resource]);
loc = lastUseOp->getLoc();
builder.setInsertionPoint(lastUseOp);
}

Type i32Type = builder.getIntegerType(32);
Value countVal = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(i32Type, refCount));
Value zero = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(i32Type, 0));

auto cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
countVal, zero);
builder.create<mlir::scf::IfOp>(
loc, cond, [&](OpBuilder &builder, Location loc) {
builder.create<IREE::Stream::ResourceDeallocaOp>(
loc, resource, allocaOp.getResultSize(0), timepoint,
allocaOp.getAffinityAttr());

builder.create<mlir::scf::YieldOp>(loc);
});
Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 64);

auto cond = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, decRefOp.getCount(), zero);

builder.create<mlir::scf::IfOp>(
loc, cond, [&](OpBuilder &builder, Location loc) {
builder.create<IREE::Stream::ResourceDeallocaOp>(
loc, resource, allocaOp.getResultSize(0), newTimepoint,
allocaOp.getAffinityAttr());

builder.create<mlir::scf::YieldOp>(loc);
});
});
});

// TODO: erase only when needed ?
region.walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<IREE::Stream::ResourceAddRefOp>(
[&](auto addRefOp) { addRefOp.erase(); })
.Case<IREE::Stream::ResourceDecRefOp>(
[&](auto decRefOp) { decRefOp.erase(); })
.Case<IREE::Stream::ResourceLastUseOp>(
[&](auto lastUseOp) { lastUseOp.erase(); });
});
}

// This operates using a whole-program analysis to track reference counts of a
Expand Down Expand Up @@ -480,6 +543,8 @@ struct ResourceRefCountingPass
// adding reference count for each user, and
// insert dealloca when refcount reaches zero
performRefCountingInRegion(*region, analysis);

insertDealloca(*region);
}
}
};
Expand Down

0 comments on commit a05ce5f

Please sign in to comment.