Skip to content

Commit

Permalink
[Arc] Initial implementation of partition pass
Browse files Browse the repository at this point in the history
  • Loading branch information
SpriteOvO committed Oct 24, 2024
1 parent 804cdbe commit ef99d3d
Show file tree
Hide file tree
Showing 10 changed files with 621 additions and 37 deletions.
15 changes: 15 additions & 0 deletions include/circt/Dialect/Arc/ArcOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,21 @@ def InitialOp : ClockTreeLikeOp<"initial"> {
}];
}

//===----------------------------------------------------------------------===//
// Partition
//===----------------------------------------------------------------------===//

def SyncOp : ArcOp<"sync", [
RecursiveMemoryEffects, NoTerminator, NoRegionArguments, SingleBlock,
// ParentOneOf<["ClockTreeOp", "ModelOp"]>
]> {
let summary = "A region reserved after the computation of different partitions";
let assemblyFormat = [{
attr-dict-with-keyword $body
}];
let regions = (region SizedRegion<1>:$body);
}

//===----------------------------------------------------------------------===//
// Storage Allocation
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion include/circt/Dialect/Arc/ArcPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ std::unique_ptr<mlir::Pass> createLegalizeStateUpdatePass();
std::unique_ptr<mlir::Pass> createLowerArcsToFuncsPass();
std::unique_ptr<mlir::Pass> createLowerClocksToFuncsPass();
std::unique_ptr<mlir::Pass> createLowerLUTPass();
std::unique_ptr<mlir::Pass> createLowerStatePass();
std::unique_ptr<mlir::Pass> createLowerStatePass(const LowerStateOptions &options = {});
std::unique_ptr<mlir::Pass> createLowerVectorizationsPass(
LowerVectorizationsModeEnum mode = LowerVectorizationsModeEnum::Full);
std::unique_ptr<mlir::Pass> createMakeTablesPass();
std::unique_ptr<mlir::Pass> createMuxToControlFlowPass();
std::unique_ptr<mlir::Pass> createPartitionPass(const PartitionOptions &opts = {});
std::unique_ptr<mlir::Pass> createPrintCostModelPass();
std::unique_ptr<mlir::Pass> createSimplifyVariadicOpsPass();
std::unique_ptr<mlir::Pass> createSplitLoopsPass();
Expand Down
16 changes: 16 additions & 0 deletions include/circt/Dialect/Arc/ArcPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ def LowerLUT : Pass<"arc-lower-lut", "arc::DefineOp"> {
def LowerState : Pass<"arc-lower-state", "mlir::ModuleOp"> {
let summary = "Split state into read and write ops grouped by clock tree";
let constructor = "circt::arc::createLowerStatePass()";
let options = [
Option<"withTemporary", "with-temporary", "bool", "false",
"Create temporary storage for blocks and subsequently temporaries during state update legalization">
];
let dependentDialects = [
"arc::ArcDialect", "mlir::scf::SCFDialect", "mlir::func::FuncDialect",
"mlir::LLVM::LLVMDialect", "comb::CombDialect", "seq::SeqDialect"
Expand Down Expand Up @@ -303,6 +307,18 @@ def SimplifyVariadicOps : Pass<"arc-simplify-variadic-ops", "mlir::ModuleOp"> {
];
}

def Partition: Pass<"arc-partition", "arc::ModelOp"> {
let summary = "Partition the entire circuit into multiple chunks";
let options = [
Option<"chunks", "chunks", "unsigned", "1",
"Number of resulting chunks">
];
}

def PartitionClone: Pass<"arc-partition-clone"> {
let summary = "Commit the planned partition by cloning models";
}

def SplitFuncs : Pass<"arc-split-funcs", "mlir::ModuleOp"> {
let summary = "Split large funcs into multiple smaller funcs";
let dependentDialects = ["mlir::func::FuncDialect"];
Expand Down
8 changes: 5 additions & 3 deletions lib/Dialect/Arc/ArcOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,12 @@ void RootOutputOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
//===----------------------------------------------------------------------===//

LogicalResult ModelOp::verify() {
if (getBodyBlock().getArguments().size() != 1)
if (size_t argcnt = getBodyBlock().getArguments().size(); argcnt != 1 && argcnt != 2)
return emitOpError("must have exactly one argument");
if (auto type = getBodyBlock().getArgument(0).getType();
!isa<StorageType>(type))
if (llvm::any_of(getBodyBlock().getArguments(),
[](BlockArgument arg) -> bool {
return !isa<StorageType>(arg.getType());
}))
return emitOpError("argument must be of storage type");
for (const hw::ModulePort &port : getIo().getPorts())
if (port.dir == hw::ModulePort::Direction::InOut)
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Arc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_circt_dialect_library(CIRCTArcTransforms
LowerVectorizations.cpp
MakeTables.cpp
MuxToControlFlow.cpp
Partition.cpp
PrintCostModel.cpp
SimplifyVariadicOps.cpp
SplitFuncs.cpp
Expand Down
17 changes: 15 additions & 2 deletions lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,23 +361,36 @@ LogicalResult Legalizer::visitBlock(Block *block) {
// HACK: This is ugly, but we need a storage reference to allocate a state
// into. Ideally we'd materialize this later on, but the current impl of
// the alloc op requires a storage immediately. So try to find one.
auto storage = TypeSwitch<Operation *, Value>(state.getDefiningOp())
auto origStorage = TypeSwitch<Operation *, Value>(state.getDefiningOp())
.Case<AllocStateOp, RootInputOp, RootOutputOp>(
[&](auto allocOp) { return allocOp.getStorage(); })
.Default([](auto) { return Value{}; });
if (!storage) {
if (!origStorage) {
mlir::emitError(
state.getLoc(),
"cannot find storage pointer to allocate temporary into");
return failure();
}

// Check if the storage has a temporary storage
assert(isa<BlockArgument>(origStorage) && "Storage should be a block argument");
auto origStorageArg = cast<BlockArgument>(origStorage);
auto origStorageBlk = origStorageArg.getParentBlock();
auto storage = origStorageBlk->getArgument(origStorageBlk->getNumArguments() - 1);
assert(isa<TypedValue<StorageType>>(storage) && "Block argument in arc.model should be a storage");

// Allocate a temporary state, read the current value of the state we are
// legalizing, and write it to the temporary.
++numLegalizedWrites;
ImplicitLocOpBuilder builder(state.getLoc(), op);
auto tmpState =
builder.create<AllocStateOp>(state.getType(), storage, nullptr);

auto stateRole = cast<StringAttr>(state.getDefiningOp()->getAttr("partition-role"));
assert(stateRole == "state" || stateRole == "old-clock");
auto shadowRole = stateRole == "state" ? "shadow-state" : "shadow-old-clock";
tmpState->setAttr("partition-role", builder.getStringAttr(shadowRole));

auto stateValue = builder.create<StateReadOp>(state);
builder.create<StateWriteOp>(tmpState, stateValue, Value{});
locallyLegalizedStates.push_back(state);
Expand Down
48 changes: 29 additions & 19 deletions lib/Dialect/Arc/Transforms/LowerClocksToFuncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ struct LowerClocksToFuncsPass

void runOnOperation() override;
LogicalResult lowerModel(ModelOp modelOp);
LogicalResult lowerClock(Operation *clockOp, Value modelStorageArg,
LogicalResult lowerClock(Operation *clockOp,
SmallVector<Value> modelStorageArgs,
OpBuilder &funcBuilder);
LogicalResult isolateClock(Operation *clockOp, Value modelStorageArg,
Value clockStorageArg);

LogicalResult isolateClock(Operation *clockOp, IRMapping storageArgMapping);
SymbolTable *symbolTable;

Statistic numOpsCopied{this, "ops-copied", "Ops copied into clock trees"};
Expand Down Expand Up @@ -105,27 +104,37 @@ LogicalResult LowerClocksToFuncsPass::lowerModel(ModelOp modelOp) {
// Perform the actual extraction.
OpBuilder funcBuilder(modelOp);
for (auto *op : clocks)
if (failed(lowerClock(op, modelOp.getBody().getArgument(0), funcBuilder)))
if (failed(
lowerClock(op,
llvm::map_to_vector(modelOp.getBody().getArguments(),
[](auto v) -> Value { return v; }),
funcBuilder)))
return failure();

return success();
}

LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp,
Value modelStorageArg,
OpBuilder &funcBuilder) {
LogicalResult
LowerClocksToFuncsPass::lowerClock(Operation *clockOp,
SmallVector<Value> modelStorageArgs,
OpBuilder &funcBuilder) {
LLVM_DEBUG(llvm::dbgs() << "- Lowering clock " << clockOp->getName() << "\n");
assert((isa<ClockTreeOp, PassThroughOp, InitialOp>(clockOp)));

// Add a `StorageType` block argument to the clock's body block which we are
// going to use to pass the storage pointer to the clock once it has been
// pulled out into a separate function.
Region &clockRegion = clockOp->getRegion(0);
Value clockStorageArg = clockRegion.addArgument(modelStorageArg.getType(),
modelStorageArg.getLoc());

IRMapping mapping;

for (const auto &storage : modelStorageArgs) {
Value mapped = clockRegion.addArgument(storage.getType(), storage.getLoc());
mapping.map(storage, mapped);
}

// Ensure the clock tree does not use any values defined outside of it.
if (failed(isolateClock(clockOp, modelStorageArg, clockStorageArg)))
if (failed(isolateClock(clockOp, mapping)))
return failure();

// Add a return op to the end of the body.
Expand All @@ -144,9 +153,11 @@ LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp,
else
funcName.append("_clock");

SmallVector<Type> fnArgTypes = llvm::map_to_vector(
modelStorageArgs, [](Value v) { return v.getType(); });

auto funcOp = funcBuilder.create<func::FuncOp>(
clockOp->getLoc(), funcName,
builder.getFunctionType({modelStorageArg.getType()}, {}));
clockOp->getLoc(), funcName, builder.getFunctionType({fnArgTypes}, {}));
symbolTable->insert(funcOp); // uniquifies the name
LLVM_DEBUG(llvm::dbgs() << " - Created function `" << funcOp.getSymName()
<< "`\n");
Expand All @@ -159,11 +170,11 @@ LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp,
treeOp.getClock(), false);
auto builder = ifOp.getThenBodyBuilder();
builder.template create<func::CallOp>(clockOp->getLoc(), funcOp,
ValueRange{modelStorageArg});
ValueRange{modelStorageArgs});
})
.Case<PassThroughOp>([&](auto) {
builder.template create<func::CallOp>(clockOp->getLoc(), funcOp,
ValueRange{modelStorageArg});
ValueRange{modelStorageArgs});
})
.Case<InitialOp>([&](auto) {
if (modelOp.getInitialFn().has_value())
Expand Down Expand Up @@ -195,16 +206,15 @@ LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp,
/// body. Anything besides constants should no longer exist after a proper run
/// of the pipeline.
LogicalResult LowerClocksToFuncsPass::isolateClock(Operation *clockOp,
Value modelStorageArg,
Value clockStorageArg) {
IRMapping storageMapping) {
auto *clockRegion = &clockOp->getRegion(0);
auto builder = OpBuilder::atBlockBegin(&clockRegion->front());
DenseMap<Value, Value> copiedValues;
auto result = clockRegion->walk([&](Operation *op) {
for (auto &operand : op->getOpOperands()) {
// Block arguments are okay, since there's nothing we can move.
if (operand.get() == modelStorageArg) {
operand.set(clockStorageArg);
if (storageMapping.contains(operand.get())) {
operand.set(storageMapping.lookup(operand.get()));
continue;
}
if (isa<BlockArgument>(operand.get())) {
Expand Down
42 changes: 32 additions & 10 deletions lib/Dialect/Arc/Transforms/LowerState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,14 @@ struct ModuleLowering {
DenseMap<Value, GatedClockLowering> gatedClockLowerings;
std::unique_ptr<ClockLowering> initialLowering;
Value storageArg;
Value temporaryStorageArg;
OpBuilder clockBuilder;
OpBuilder stateBuilder;
bool withTemporary;

ModuleLowering(HWModuleOp moduleOp, Statistics &stats)
ModuleLowering(HWModuleOp moduleOp, Statistics &stats, bool withTemporary)
: moduleOp(moduleOp), stats(stats), context(moduleOp.getContext()),
clockBuilder(moduleOp), stateBuilder(moduleOp) {}
clockBuilder(moduleOp), stateBuilder(moduleOp), withTemporary(withTemporary) {}

GatedClockLowering getOrCreateClockLowering(Value clock);
ClockLowering &getOrCreatePassThrough();
Expand Down Expand Up @@ -313,12 +315,23 @@ GatedClockLowering ModuleLowering::getOrCreateClockLowering(Value clock) {
clockBuilder.createOrFold<seq::FromClockOp>(clock.getLoc(), clock);

// Detect a rising edge on the clock, as `(old != new) & new`.
auto clockStorage = withTemporary ? temporaryStorageArg : storageArg;
auto oldClockStorage = stateBuilder.create<AllocStateOp>(
clock.getLoc(), StateType::get(stateBuilder.getI1Type()), storageArg);
clock.getLoc(), StateType::get(stateBuilder.getI1Type()), clockStorage);
oldClockStorage->setAttr("partition-role", stateBuilder.getStringAttr("old-clock"));
auto oldClock =
clockBuilder.create<StateReadOp>(clock.getLoc(), oldClockStorage);

SyncOp sync;
if(withTemporary) {
sync = clockBuilder.create<SyncOp>(clock.getLoc());
sync.getBodyRegion().emplaceBlock();
clockBuilder.setInsertionPointToEnd(&sync.getBody().front());
}
clockBuilder.create<StateWriteOp>(clock.getLoc(), oldClockStorage, newClock,
Value{});
if(withTemporary) clockBuilder.setInsertionPointAfter(sync);

Value trigger = clockBuilder.create<comb::ICmpOp>(
clock.getLoc(), comb::ICmpPredicate::ne, oldClock, newClock);
trigger =
Expand Down Expand Up @@ -360,16 +373,19 @@ Value ModuleLowering::replaceValueWithStateRead(Value value, Value state) {

/// Add the global state as an argument to the module's body block.
void ModuleLowering::addStorageArg() {
assert(!storageArg);
assert(!storageArg && !temporaryStorageArg);
storageArg = moduleOp.getBodyBlock()->addArgument(
StorageType::get(context, {}), moduleOp.getLoc());
if (withTemporary)
temporaryStorageArg = moduleOp.getBodyBlock()->addArgument(
StorageType::get(context, {}), moduleOp.getLoc());
}

/// Lower the primary inputs of the module to dedicated ops that allocate the
/// inputs in the model's storage.
LogicalResult ModuleLowering::lowerPrimaryInputs() {
for (auto blockArg : moduleOp.getBodyBlock()->getArguments()) {
if (blockArg == storageArg)
if (blockArg == storageArg || blockArg == temporaryStorageArg)
continue;
auto name = moduleOp.getArgName(blockArg.getArgNumber());
auto argTy = blockArg.getType();
Expand Down Expand Up @@ -470,6 +486,7 @@ LogicalResult ModuleLowering::lowerStateLike(
auto stateType = StateType::get(intType);
auto state = stateBuilder.create<AllocStateOp>(stateOp->getLoc(), stateType,
storageArg);
state->setAttr("partition-role", stateBuilder.getStringAttr("state"));
if (auto names = stateOp->getAttrOfType<ArrayAttr>("names"))
state->setAttr("name", names[stateIdx]);
allocatedStates.push_back(state);
Expand Down Expand Up @@ -646,6 +663,7 @@ LogicalResult ModuleLowering::lowerState(TapOp tapOp) {
auto materializedValue = passThrough.materializeValue(tapValue);
auto state = stateBuilder.create<AllocStateOp>(
tapOp.getLoc(), StateType::get(intType), storageArg, true);
state->setAttr("partition-role", stateBuilder.getStringAttr("tap-output"));
state->setAttr("name", tapOp.getNameAttr());
passThrough.builder.create<StateWriteOp>(tapOp.getLoc(), state,
materializedValue, Value{});
Expand Down Expand Up @@ -688,6 +706,7 @@ LogicalResult ModuleLowering::lowerExtModule(InstanceOp instOp) {
auto &passThrough = getOrCreatePassThrough();
auto state = stateBuilder.create<AllocStateOp>(
instOp.getLoc(), StateType::get(intType), storageArg);
state->setAttr("partition-role", stateBuilder.getStringAttr("ext-module-input"));
state->setAttr("name", stateBuilder.getStringAttr(baseName));
passThrough.builder.create<StateWriteOp>(
instOp.getLoc(), state, passThrough.materializeValue(operand), Value{});
Expand All @@ -709,6 +728,7 @@ LogicalResult ModuleLowering::lowerExtModule(InstanceOp instOp) {
baseName += cast<StringAttr>(name).getValue();
auto state = stateBuilder.create<AllocStateOp>(
result.getLoc(), StateType::get(intType), storageArg);
state->setAttr("partition-role", stateBuilder.getStringAttr("ext-module-output"));
state->setAttr("name", stateBuilder.getStringAttr(baseName));
replaceValueWithStateRead(result, state);
}
Expand Down Expand Up @@ -786,13 +806,15 @@ LogicalResult ModuleLowering::cleanup() {

namespace {
struct LowerStatePass : public arc::impl::LowerStateBase<LowerStatePass> {
LowerStatePass() = default;
LowerStatePass(const LowerStateOptions &options = {}) : opts(options) {}
LowerStatePass(const LowerStatePass &pass) : LowerStatePass() {}

void runOnOperation() override;
LogicalResult runOnModule(HWModuleOp moduleOp, SymbolTable &symtbl);

Statistics stats{this};

LowerStateOptions opts;
};
} // namespace

Expand Down Expand Up @@ -864,7 +886,7 @@ LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp,
SymbolTable &symtbl) {
LLVM_DEBUG(llvm::dbgs() << "Lowering state in `" << moduleOp.getModuleName()
<< "`\n");
ModuleLowering lowering(moduleOp, stats);
ModuleLowering lowering(moduleOp, stats, opts.withTemporary);

// Add sentinel ops to separate state allocations from clock trees.
lowering.stateBuilder.setInsertionPointToStart(moduleOp.getBodyBlock());
Expand Down Expand Up @@ -906,7 +928,7 @@ LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp,

// Replace the `HWModuleOp` with a `ModelOp`.
moduleOp.getBodyBlock()->eraseArguments(
[&](auto arg) { return arg != lowering.storageArg; });
[&](auto arg) { return arg != lowering.storageArg && arg != lowering.temporaryStorageArg; });
ImplicitLocOpBuilder builder(moduleOp.getLoc(), moduleOp);
auto modelOp = builder.create<ModelOp>(
moduleOp.getLoc(), moduleOp.getModuleNameAttr(),
Expand All @@ -918,6 +940,6 @@ LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp,
return success();
}

std::unique_ptr<Pass> arc::createLowerStatePass() {
return std::make_unique<LowerStatePass>();
std::unique_ptr<Pass> arc::createLowerStatePass(const LowerStateOptions &options) {
return std::make_unique<LowerStatePass>(options);
}
Loading

0 comments on commit ef99d3d

Please sign in to comment.