Skip to content

Commit

Permalink
Fix NVIDIA#781 Loop unrolling with break statement and mid-circuit me…
Browse files Browse the repository at this point in the history
…asurement.

Due to an interaction between the revised memtoreg pass and loop
unrolling, the test with a mid-circuit measurement and a break statement
still wasn't being unrolled. These changes fix that issue.

Primarily, these changes allow structured operations that accept region
arguments to straddle the fence, allowing some values to be promoted as
dominating uses (when only used) while other values are threaded as
region arguments exactly as before (when written). This change
simplifies the register-semantics IR. There may be a slight performance
loss in memtoreg as a result, however, since promoted values may later be
discovered to be written, in which case the changes must be reapplied.
  • Loading branch information
schweitzpgi committed Oct 17, 2023
1 parent b431631 commit c5e3189
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 75 deletions.
4 changes: 3 additions & 1 deletion include/cudaq/Optimizer/CodeGen/Pipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ void addPipelineToQIR(mlir::PassManager &pm,
pm.addNestedPass<mlir::func::FuncOp>(cudaq::opt::createLowerToCFGPass());
pm.addNestedPass<mlir::func::FuncOp>(cudaq::opt::createQuakeAddDeallocs());
pm.addNestedPass<mlir::func::FuncOp>(cudaq::opt::createLoopNormalize());
pm.addNestedPass<mlir::func::FuncOp>(cudaq::opt::createLoopUnroll());
cudaq::opt::LoopUnrollOptions luo;
luo.allowBreak = convertTo.equals("qir-adaptive");
pm.addNestedPass<mlir::func::FuncOp>(cudaq::opt::createLoopUnroll(luo));
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addNestedPass<mlir::func::FuncOp>(
Expand Down
106 changes: 70 additions & 36 deletions lib/Optimizer/Transforms/MemToReg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,9 @@ namespace {
/// There are 3 basic cases.
///
/// -# High-level operations that take region arguments. In this case all
/// use-def information is passed as arguments between the blocks if it is
/// live.
/// def information is passed as arguments between the blocks if it is
/// live. Use information, if only used, is passed as promoted loads,
/// otherwise it involves a def and is passed as an argument.
/// -# High-level operations that disallow region arguments. In this case
/// uses may have loads promoted to immediately before the operation.
/// -# Function operations. In this case, the body is a plain old CFG and
Expand Down Expand Up @@ -325,17 +326,23 @@ class RegionDataFlow {
/// of \p block but not have a dominating definition in \p block. In these
/// cases, the value will be passed as an argument to all blocks in the
/// operation.
std::pair<SSAReg, bool> addRegionBinding(Block *block, MemRef mr) {
std::pair<SSAReg, bool> addEscapingBinding(Block *block, MemRef mr) {
assert(block && rMap.count(block) && mr && !isFunctionBlock(block));
SSAReg result;
bool newlyAdded = false;
if (!escapes.count(mr)) {
auto off = escapes.size();
escapes[mr] = off;
newlyAdded = true;
}
bool changed = maybeAddEscapingBlockArguments(block);
result = block->getArgument(originalBlockArgs[block] + escapes[mr]);
rMap[block][mr] = result;
return {result, changed};
const auto blockArgNum = originalBlockArgs[block] + escapes[mr];
auto ba = block->getArgument(blockArgNum);
rMap[block][mr] = ba;
if (newlyAdded && hasPromotedMemRef(mr)) {
promoChange |= convertPromotedToEscapingDef(block, mr, blockArgNum);
changed |= promoChange;
}
return {ba, changed};
}

/// Is \p mr a known escaping binding?
Expand Down Expand Up @@ -485,14 +492,19 @@ class RegionDataFlow {
/// new dominating dereference. Used when \p op does not allow region
/// arguments.
Value createPromotedValue(Value memuse, Operation *op) {
Operation *parent = op->getParentOp();
if (hasPromotedMemRef(memuse))
return getPromotedMemRef(memuse);
Operation *parent = op->getParentOp();
OpBuilder builder(parent);
Value newUse = reloadMemoryReference(builder, memuse);
return addPromotedMemRef(memuse, newUse);
}

SSAReg getPromotedMemRef(MemRef mr) const {
assert(hasPromotedMemRef(mr));
return promotedMem.find(mr)->second;
}

/// Track the memory reference \p mr as being live-out of the parent
/// operation. (\p parent is passed for the assertion check only.)
void addLiveOutOfParent(Operation *parent, MemRef mr) {
Expand All @@ -504,6 +516,18 @@ class RegionDataFlow {
return SmallVector<MemRef>(liveOutSet.begin(), liveOutSet.end());
}

void cleanupIfPromoChanged(SmallPtrSetImpl<Block *> &visited, Block *block) {
assert(block);
if (promoChange) {
// A promoted load was converted to an escaping definition. We have to
// revisit all the blocks to thread the new block arguments and
// terminators.
visited.clear();
visited.insert(block);
promoChange = false;
}
}

private:
// Delete all ctors that should never be used.
RegionDataFlow() = delete;
Expand All @@ -512,11 +536,28 @@ class RegionDataFlow {

bool hasLiveOutOfParent() const { return !liveOutSet.empty(); }
unsigned getNumEscapes() const { return escapes.size(); }

bool hasPromotedMemRef(MemRef mr) const { return promotedMem.count(mr); }

SSAReg getPromotedMemRef(MemRef mr) const {
assert(hasPromotedMemRef(mr));
return promotedMem.find(mr)->second;
bool convertPromotedToEscapingDef(Block *block, MemRef mr,
unsigned blockArgNum) {
auto ssaReg = promotedMem[mr];
SmallVector<Operation *> users(ssaReg.getUsers().begin(),
ssaReg.getUsers().end());
const bool result = !users.empty();
for (auto *user : users) {
Block *b = user->getBlock();
if (b->getParentOp() != block->getParentOp()) {
// Find the block in parent to add the escaping binding to.
while (b->getParentOp() != block->getParentOp())
b = b->getParentOp()->getBlock();
}
// Add an escaping binding to block `b` for the user to use.
if (b != block)
addEscapingBinding(b, mr);
user->replaceUsesOfWith(ssaReg, b->getArgument(blockArgNum));
}
return result;
}

SSAReg addPromotedMemRef(MemRef mr, SSAReg sr) {
Expand Down Expand Up @@ -573,6 +614,7 @@ class RegionDataFlow {
/// For the body of a function, we maintain a distinct map for each block of
/// the definitions that are live-in to each block.
DenseMap<Block *, DenseMap<MemRef, SSAReg>> liveInMap;
bool promoChange = false;
};
} // namespace

Expand Down Expand Up @@ -877,9 +919,9 @@ class MemToRegPass : public cudaq::opt::impl::MemToRegBase<MemToRegPass> {

// If op is a use of a memory ref, forward the last def if there is one.
// If no def is known, then if this is a function entry raise an error,
// or if this op does not have region arguments add a dominating def
// immediately before parent, or (the default) add a block argument for
// the def.
// or if this op does not have region arguments or this use is not also
// being defined add a dominating def immediately before parent, or
// (the default) add a block argument for the def.
auto handleUse = [&]<typename T>(T useop, Value memuse) {
if (!memuse)
return;
Expand Down Expand Up @@ -909,10 +951,11 @@ class MemToRegPass : public cudaq::opt::impl::MemToRegBase<MemToRegPass> {
if (block->isEntryBlock()) {
// Create a promoted value that dominates parent.
auto newUseopVal = dataFlow.createPromotedValue(memuse, op);
if (parent->hasTrait<OpTrait::NoRegionArguments>()) {
if (!dataFlow.hasEscape(memuse)) {
// In this case, parent does not accept region arguments so the
// reference values must already be defined to dominate parent.
useop.replaceAllUsesWith(newUseopVal);
dataFlow.addBinding(block, memuse, newUseopVal);
cleanUps.insert(useop);
return;
}
Expand All @@ -927,11 +970,12 @@ class MemToRegPass : public cudaq::opt::impl::MemToRegBase<MemToRegPass> {
auto *entry = &reg.front();
bool changes = dataFlow.addBlock(entry);
auto [blockArg, changed] =
dataFlow.addRegionBinding(entry, memuse);
dataFlow.addEscapingBinding(entry, memuse);
if (useop->getParentRegion() == &reg)
useop.replaceAllUsesWith(blockArg);
if (entry == block)
dataFlow.addBinding(block, memuse, blockArg);
dataFlow.cleanupIfPromoChanged(blocksVisited, block);
if (changes || changed)
appendPredecessorsToWorklist(worklist, entry);
}
Expand All @@ -957,28 +1001,14 @@ class MemToRegPass : public cudaq::opt::impl::MemToRegBase<MemToRegPass> {
return;
}

// Does this op not allow region arguments?
if (parent->hasTrait<OpTrait::NoRegionArguments>()) {
if (!dataFlow.hasEscape(memuse)) {
// Create a promoted value that dominates parent. In this case, the
// ref value must already be defined somewhere that dominates Op
// `parent`, so we can just reload it.
auto newUseopVal = dataFlow.createPromotedValue(memuse, op);
useop.replaceAllUsesWith(newUseopVal);
dataFlow.addBinding(block, memuse, newUseopVal);
cleanUps.insert(useop);
return;
}

// We want to add the def to the arguments coming from our
// predecessor blocks.
if (!dataFlow.hasEscape(memuse)) {
// This one isn't already on the list of block arguments, so add
// and record it as a new BlockArgument.
auto [newArg, changes] = dataFlow.addRegionBinding(block, memuse);
dataFlow.addBinding(block, memuse, newArg);
useop.replaceAllUsesWith(newArg);
cleanUps.insert(op);
if (changes)
appendPredecessorsToWorklist(worklist, block);
}
};
if (auto unwrap = dyn_cast<quake::UnwrapOp>(op)) {
Expand Down Expand Up @@ -1014,12 +1044,13 @@ class MemToRegPass : public cudaq::opt::impl::MemToRegBase<MemToRegPass> {
continue;
Block *entry = &reg.front();
bool changes = dataFlow.addBlock(entry);
auto [na, changed] = dataFlow.addRegionBinding(entry, memdef);
if (changes || changed)
auto pr = dataFlow.addEscapingBinding(entry, memdef);
if (changes || pr.second)
appendPredecessorsToWorklist(worklist, entry);
}
auto [na, changes] = dataFlow.addRegionBinding(block, memdef);
if (changes)
dataFlow.cleanupIfPromoChanged(blocksVisited, block);
auto pr = dataFlow.addEscapingBinding(block, memdef);
if (pr.second)
appendPredecessorsToWorklist(worklist, block);
}
}
Expand Down Expand Up @@ -1079,6 +1110,9 @@ class MemToRegPass : public cudaq::opt::impl::MemToRegBase<MemToRegPass> {
SmallVector<Value> operands;
for (auto opndVal : parent->getOperands())
operands.push_back(opndVal);
if (!parent->hasTrait<OpTrait::NoRegionArguments>())
for (auto d : allDefs)
operands.push_back(dataFlow.getPromotedMemRef(d));
Operation *np =
Operation::create(parent->getLoc(), parent->getName(), resultTypes,
operands, parent->getAttrs(),
Expand Down
3 changes: 1 addition & 2 deletions runtime/cudaq/builder/kernel_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,18 +782,17 @@ jitCode(ImplicitLocOpBuilder &builder, ExecutionEngine *jit,

PassManager pm(context);
OpPassManager &optPM = pm.nest<func::FuncOp>();
optPM.addPass(cudaq::opt::createUnwindLoweringPass());
cudaq::opt::addAggressiveEarlyInlining(pm);
pm.addPass(createCanonicalizerPass());
pm.addPass(cudaq::opt::createApplyOpSpecializationPass());
optPM.addPass(cudaq::opt::createClassicalMemToReg());
pm.addPass(createCanonicalizerPass());
pm.addPass(cudaq::opt::createExpandMeasurementsPass());
pm.addPass(cudaq::opt::createLoopNormalize());
pm.addPass(cudaq::opt::createLoopUnroll());
pm.addPass(createCanonicalizerPass());
optPM.addPass(cudaq::opt::createQuakeAddDeallocs());
optPM.addPass(cudaq::opt::createQuakeAddMetadata());
optPM.addPass(cudaq::opt::createUnwindLoweringPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

Expand Down
5 changes: 5 additions & 0 deletions runtime/cudaq/platform/default/rest/RemoteRESTQPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ class RemoteRESTQPU : public cudaq::QPU {
postCodeGenPasses = match[1].str();
}
}
std::string allowEarlyExitSetting =
(codegenTranslation == "qir-adaptive") ? "1" : "0";
passPipelineConfig =
std::string("func.func(cc-loop-unroll{allow-early-exit=") +
allowEarlyExitSetting + "})," + passPipelineConfig;

// Set the qpu name
qpuName = mutableBackend;
Expand Down
14 changes: 7 additions & 7 deletions test/AST-Quake/adjoint-3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ struct run_circuit {
// ADJOINT: %[[VAL_33:.*]] = arith.subi %[[VAL_32]], %[[VAL_25]] : i32
// ADJOINT: %[[VAL_34:.*]] = arith.addi %[[VAL_7]], %[[VAL_33]] : i32
// ADJOINT: %[[VAL_35:.*]] = arith.constant 0 : i32
// ADJOINT: %[[VAL_36:.*]]:4 = cc.loop while ((%[[VAL_37:.*]] = %[[VAL_34]], %[[VAL_38:.*]] = %[[VAL_9]], %[[VAL_39:.*]] = %[[VAL_1]], %[[VAL_40:.*]] = %[[VAL_32]]) -> (i32, i32, f64, i32)) {
// ADJOINT: %[[VAL_36:.*]]:2 = cc.loop while ((%[[VAL_37:.*]] = %[[VAL_34]], %[[VAL_40:.*]] = %[[VAL_32]]) -> (i32, i32)) {
// ADJOINT: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_37]], %[[VAL_9]] : i32
// ADJOINT: %[[VAL_42:.*]] = arith.cmpi sgt, %[[VAL_40]], %[[VAL_35]] : i32
// ADJOINT: cc.condition %[[VAL_42]](%[[VAL_37]], %[[VAL_9]], %[[VAL_1]], %[[VAL_40]] : i32, i32, f64, i32)
// ADJOINT: cc.condition %[[VAL_42]](%[[VAL_37]], %[[VAL_40]] : i32, i32)
// ADJOINT: } do {
// ADJOINT: ^bb0(%[[VAL_43:.*]]: i32, %[[VAL_44:.*]]: i32, %[[VAL_45:.*]]: f64, %[[VAL_46:.*]]: i32):
// ADJOINT: ^bb0(%[[VAL_43:.*]]: i32, %[[VAL_46:.*]]: i32):
// ADJOINT: %[[VAL_47:.*]] = arith.subi %[[VAL_9]], %[[VAL_43]] : i32
// ADJOINT: %[[VAL_48:.*]] = arith.subi %[[VAL_47]], %[[VAL_7]] : i32
// ADJOINT: %[[VAL_50:.*]] = math.fpowi %[[VAL_2]], %[[VAL_48]] : f64, i32
Expand All @@ -139,15 +139,15 @@ struct run_circuit {
// ADJOINT: %[[VAL_56:.*]] = arith.extsi %[[VAL_55]] : i32 to i64
// ADJOINT: %[[VAL_57:.*]] = quake.extract_ref %[[VAL_0]]{{\[}}%[[VAL_56]]] : (!quake.veq<?>, i64) -> !quake.ref
// ADJOINT: %[[VAL_58:.*]] = arith.negf %[[VAL_51]] : f64
// ADJOINT: quake.ry (%[[VAL_58]]) {{\[}}%[[VAL_54]]] %[[VAL_57]] : (f64, !quake.ref, !quake.ref) -> ()
// ADJOINT: cc.continue %[[VAL_43]], %[[VAL_44]], %[[VAL_45]], %[[VAL_46]] : i32, i32, f64, i32
// ADJOINT: quake.ry (%[[VAL_58]]) [%[[VAL_54]]] %[[VAL_57]] : (f64, !quake.ref, !quake.ref) -> ()
// ADJOINT: cc.continue %[[VAL_43]], %[[VAL_46]] : i32, i32
// ADJOINT: } step {
// ADJOINT: ^bb0(%[[VAL_59:.*]]: i32, %[[VAL_60:.*]]: i32, %[[VAL_61:.*]]: f64, %[[VAL_62:.*]]: i32):
// ADJOINT: ^bb0(%[[VAL_59:.*]]: i32, %[[VAL_62:.*]]: i32):
// ADJOINT: %[[VAL_63:.*]] = arith.addi %[[VAL_59]], %[[VAL_7]] : i32
// ADJOINT: %[[VAL_64:.*]] = arith.subi %[[VAL_59]], %[[VAL_7]] : i32
// ADJOINT: %[[VAL_65:.*]] = arith.constant 1 : i32
// ADJOINT: %[[VAL_66:.*]] = arith.subi %[[VAL_62]], %[[VAL_65]] : i32
// ADJOINT: cc.continue %[[VAL_64]], %[[VAL_9]], %[[VAL_1]], %[[VAL_66]] : i32, i32, f64, i32
// ADJOINT: cc.continue %[[VAL_64]], %[[VAL_66]] : i32, i32
// ADJOINT: }
// ADJOINT: %[[VAL_67:.*]] = arith.negf %[[VAL_19]] : f64
// ADJOINT: quake.ry (%[[VAL_67]]) %[[VAL_22]] : (f64, !quake.ref) -> ()
Expand Down
7 changes: 2 additions & 5 deletions test/NVQPP/qir_test_cond_for_break.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

// clang-format off
// RUN: nvq++ --target quantinuum --emulate %s -o %basename_t.x && ./%basename_t.x | FileCheck %s
// XFAIL: *
// ^^^^^ Produces error: 'cc.loop' op not a simple counted loop
// clang-format on
// RUN: nvq++ --target quantinuum --emulate %s -o %basename_t.x && \
// RUN: ./%basename_t.x | FileCheck %s

#include <cudaq.h>
#include <iostream>
Expand Down
Loading

0 comments on commit c5e3189

Please sign in to comment.