Skip to content

Commit

Permalink
Support cloning operand-less dispatches
Browse files Browse the repository at this point in the history
Any dispatch that does not consume a resource can be placed on any
device. If so we can clone this to every device to avoid transferring or
access across device memory.
  • Loading branch information
rsuderman committed Feb 18, 2025
1 parent 33a770e commit 5d59201
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 17 deletions.
20 changes: 5 additions & 15 deletions compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,21 +728,11 @@ ChangeStatus OpAffinityPVS::updateOperation(Operation *op,
op->getOperandTypes(), +[](Type type) {
return isa<IREE::Stream::AffinityTypeInterface>(type);
});
if (consumesAny) {
for (auto operand : op->getOperands()) {
if (isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
auto valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
*this, Position::forValue(operand), DFX::Resolution::REQUIRED);
newState ^= valuePVS;
}
}
} else {
for (auto result : op->getResults()) {
if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
auto valuePVS = solver.getElementFor<ValueConsumerAffinityPVS>(
*this, Position::forValue(result), DFX::Resolution::REQUIRED);
newState ^= valuePVS;
}
for (auto operand : op->getOperands()) {
if (isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
auto valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
*this, Position::forValue(operand), DFX::Resolution::REQUIRED);
newState ^= valuePVS;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,12 @@ struct ConvertDispatchOp
llvm::make_filter_range(op->getDialectAttrs(), [](NamedAttribute attr) {
return attr.getName() != "stream.affinity";
}));

// If cloned we should ignore the selected affinity attr.
if (newOp.preferCloneToConsumers()) {
newOp.setAffinityAttr(nullptr);
}

SmallVector<SmallVector<Value>> replacementsVec = llvm::map_to_vector(
llvm::zip_equal(newOp->getResults(), resultSizes), [](auto it) {
return SmallVector<Value>{std::get<0>(it), std::get<1>(it)};
Expand Down
18 changes: 18 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2183,6 +2183,15 @@ verifyDispatchSymbolUses(Operation *op, ArrayAttr entryPointsAttr,
return success();
}

bool TensorDispatchOp::preferCloneToConsumers() {
for (auto operand : getMixedOperands()) {
if (isa<Stream::ResourceType>(operand.getType())) {
return false;
}
}
return true;
}

LogicalResult
TensorDispatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyDispatchSymbolUses(getOperation(), getEntryPointsAttr(),
Expand Down Expand Up @@ -2644,6 +2653,15 @@ static void printDispatchOperands(OpAsmPrinter &p, Operation *op,
p << ")";
}

bool AsyncDispatchOp::preferCloneToConsumers() {
for (auto operand : getResourceOperands()) {
if (isa<Stream::ResourceType>(operand.getType())) {
return false;
}
}
return true;
}

LogicalResult AsyncDispatchOp::verify() {
AsyncDispatchOp op = *this;
if (failed(verifyOpValueSizes(op, op.getResourceOperands(),
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1744,7 +1744,9 @@ def Stream_TensorDispatchOp : Stream_Op<"tensor.dispatch", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
Stream_AffinityOp,
Stream_TensorPhaseOp,
Stream_StreamableOp,
DeclareOpInterfaceMethods<Stream_StreamableOp, [
"preferCloneToConsumers",
]>,
Util_SizeAwareOp,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedOperandsIndexAndLength",
Expand Down Expand Up @@ -2493,7 +2495,9 @@ def Stream_AsyncDispatchOp : Stream_Op<"async.dispatch", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
Stream_AffinityOp,
Stream_AsyncPhaseOp,
Stream_StreamableOp,
DeclareOpInterfaceMethods<Stream_StreamableOp, [
"preferCloneToConsumers",
]>,
DeclareOpInterfaceMethods<Stream_AsyncAccessOp, [
"getAsyncAccessRanges",
]>,
Expand Down

0 comments on commit 5d59201

Please sign in to comment.