Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Seq] Add cast operation to immutable type #7638

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/circt/Dialect/Seq/SeqOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ createConstantInitialValue(OpBuilder builder, Operation *constantLike);
// initial op.
Value unwrapImmutableValue(mlir::TypedValue<seq::ImmutableType> immutableVal);

// Helper function to merge initial ops within the block into a single initial
// op. Return failure if we cannot topologically sort the initial ops.
// Return null if there is no initial op in the block. Return the initial op
// otherwise.
FailureOr<seq::InitialOp> mergeInitialOps(Block *block);

} // namespace seq
} // namespace circt

Expand Down
29 changes: 27 additions & 2 deletions include/circt/Dialect/Seq/SeqOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def InitialOp : SeqOp<"initial", [SingleBlock,
See the Seq dialect rationale for a longer description.
}];

let arguments = (ins);
let arguments = (ins Variadic<ImmutableType>: $inputs);
let results = (outs Variadic<ImmutableType>); // seq.immutable values
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
Expand All @@ -721,7 +721,7 @@ def InitialOp : SeqOp<"initial", [SingleBlock,
];

let assemblyFormat = [{
$body attr-dict `:` type(results)
`(` $inputs `)` $body attr-dict `:` functional-type($inputs, results)
}];

let extraClassDeclaration = [{
Expand All @@ -740,3 +740,28 @@ def YieldOp : SeqOp<"yield",

let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

def FromImmutableOp : SeqOp<"from_immutable", [Pure]> {
let summary = "Cast from an immutable type to a wire type";

let arguments = (ins ImmutableType:$input);
let results = (outs AnyType:$output);

let assemblyFormat = "$input attr-dict `:` functional-type(operands, results)";
}

def GetInitialValueOp : SeqOp<"get_initial_value", [Pure]> {
let summary = "Get an initial value of the input";
let description = [{
This operation freezes a HW value while the initialization phase and
returns the frozen value as `seq.immutable` type. The input value must be valid
at the initialization and immutable through the initialization phase and otherwise
the result value is undefined. In other words time-variant values such as registers
cannot be used as an input.
}];

let arguments = (ins AnyType:$input);
let results = (outs ImmutableType:$output);

let assemblyFormat = "$input attr-dict `:` functional-type(operands, results)";
}
4 changes: 2 additions & 2 deletions integration_test/Bindings/Python/dialects/seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def top(module):
poweron_value = hw.ConstantOp.create(i32, 42).result
# CHECK: %[[INPUT_VAL:.+]] = hw.constant 45
reg_input = hw.ConstantOp.create(i32, 45).result
# CHECK-NEXT: %[[POWERON_VAL:.+]] = seq.initial {
# CHECK-NEXT: %[[POWERON_VAL:.+]] = seq.initial() {
# CHECK-NEXT: %[[C42:.+]] = hw.constant 42 : i32
# CHECK-NEXT: seq.yield %[[C42]] : i32
# CHECK-NEXT: } : !seq.immutable<i32>
# CHECK-NEXT: } : () -> !seq.immutable<i32>
# CHECK: %[[DATA_VAL:.+]] = seq.compreg %[[INPUT_VAL]], %clk reset %rst, %[[RESET_VAL]] initial %[[POWERON_VAL]]
reg = seq.CompRegOp(i32,
reg_input,
Expand Down
4 changes: 2 additions & 2 deletions integration_test/arcilator/JIT/reg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ hw.module @counter(in %clk: i1, out o1: i8, out o2: i8) {

%r0 = seq.compreg %added1, %seq_clk initial %0#0 : i8
%r1 = seq.compreg %added2, %seq_clk initial %0#1 : i8
%0:2 = seq.initial {
%0:2 = seq.initial () {
%1 = func.call @random() : () -> i32
%2 = comb.extract %1 from 0 : (i32) -> i8
%3 = hw.constant 5 : i8
seq.yield %2, %3: i8, i8
} : !seq.immutable<i8>, !seq.immutable<i8>
} : () -> (!seq.immutable<i8>, !seq.immutable<i8>)

%one = hw.constant 1 : i8
%added1 = comb.add %r0, %one : i8
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/dialects/seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self,
if power_on_value.owner is None:
assert False, "Initial value must not be port"
elif isinstance(power_on_value.owner.opview, hw.ConstantOp):
init = InitialOp([seq.ImmutableType.get(power_on_value.type)])
init = InitialOp([seq.ImmutableType.get(power_on_value.type)], [])
init.body.blocks.append()
with InsertionPoint(init.body.blocks[0]):
cloned_constant = power_on_value.owner.clone()
Expand Down
9 changes: 4 additions & 5 deletions lib/Conversion/ConvertToArcs/ConvertToArcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ static LogicalResult convertInitialValue(seq::CompRegOp reg,
if (!reg.getInitialValue())
return values.push_back({}), success();

// unrealized_conversion_cast to normal type
// Use from_immutable cast to convert the seq.immutable type to the reg's
// type.
OpBuilder builder(reg);
auto init = builder
.create<mlir::UnrealizedConversionCastOp>(
reg.getLoc(), reg.getType(), reg.getInitialValue())
.getResult(0);
auto init = builder.create<seq::FromImmutableOp>(reg.getLoc(), reg.getType(),
reg.getInitialValue());

values.push_back(init);
return success();
Expand Down
107 changes: 86 additions & 21 deletions lib/Conversion/SeqToSV/SeqToSV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "circt/Dialect/SV/SVOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Support/Naming.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
Expand Down Expand Up @@ -59,11 +60,10 @@ struct SeqToSVPass : public impl::LowerSeqToSVBase<SeqToSVPass> {
namespace {
struct ModuleLoweringState {
ModuleLoweringState(HWModuleOp module)
: initalOpLowering(module), module(module) {}
: immutableValueLowering(module), module(module) {}

struct InitialOpLowering {
InitialOpLowering(hw::HWModuleOp module)
: builder(module.getModuleBody()), module(module) {}
struct ImmutableValueLowering {
ImmutableValueLowering(hw::HWModuleOp module) : module(module) {}

// Lower initial ops.
LogicalResult lower();
Expand All @@ -82,9 +82,8 @@ struct ModuleLoweringState {
// defined in SV initial op.
MapVector<mlir::TypedValue<seq::ImmutableType>, Value> mapping;

OpBuilder builder;
hw::HWModuleOp module;
} initalOpLowering;
} immutableValueLowering;

struct FragmentInfo {
bool needsRegFragment = false;
Expand All @@ -94,21 +93,39 @@ struct ModuleLoweringState {
HWModuleOp module;
};

LogicalResult ModuleLoweringState::InitialOpLowering::lower() {
auto loweringFailed = module
.walk([&](seq::InitialOp initialOp) {
if (failed(lower(initialOp)))
return mlir::WalkResult::interrupt();
return mlir::WalkResult::advance();
})
.wasInterrupted();
return LogicalResult::failure(loweringFailed);
LogicalResult ModuleLoweringState::ImmutableValueLowering::lower() {
auto result = mergeInitialOps(module.getBodyBlock());
if (failed(result))
return failure();

auto initialOp = *result;
if (!initialOp)
return success();

return lower(initialOp);
}

LogicalResult
ModuleLoweringState::InitialOpLowering::lower(seq::InitialOp initialOp) {
ModuleLoweringState::ImmutableValueLowering::lower(seq::InitialOp initialOp) {
OpBuilder builder = OpBuilder::atBlockBegin(module.getBodyBlock());
if (!svInitialOp)
svInitialOp = builder.create<sv::InitialOp>(initialOp->getLoc());
// Replace immutable operands passed to initial op with already lowered
// values.
for (auto [blockArgument, operand] :
llvm::zip(initialOp.getBodyBlock()->getArguments(),
initialOp->getOpOperands())) {

auto immut = operand.get().getDefiningOp<seq::GetInitialValueOp>();
if (!immut)
return initialOp.emitError()
<< "invalid operand to initial op: " << operand.get();
blockArgument.replaceAllUsesWith(immut.getInput());
operand.drop();
if (immut.use_empty())
immut.erase();
}

auto loc = initialOp.getLoc();
llvm::SmallVector<Value> results;

Expand All @@ -127,10 +144,10 @@ ModuleLoweringState::InitialOpLowering::lower(seq::InitialOp initialOp) {
}

svInitialOp.getBodyBlock()->getOperations().splice(
svInitialOp.begin(), initialOp.getBodyBlock()->getOperations());
svInitialOp.end(), initialOp.getBodyBlock()->getOperations());

assert(initialOp->use_empty());
initialOp->erase();
initialOp.erase();
yieldOp->erase();
return success();
}
Expand Down Expand Up @@ -200,7 +217,7 @@ class CompRegLower : public OpConversionPattern<OpTy> {
auto module = reg->template getParentOfType<hw::HWModuleOp>();
const auto &initial =
moduleLoweringStates.find(module.getModuleNameAttr())
->second.initalOpLowering;
->second.immutableValueLowering;

Value initialValue = initial.lookupImmutableValue(init);

Expand Down Expand Up @@ -247,6 +264,49 @@ void CompRegLower<CompRegClockEnabledOp>::createAssign(
});
}

/// Lower FromImmutable to `sv.reg` and `sv.initial`.
class FromImmutableLowering : public OpConversionPattern<FromImmutableOp> {
public:
FromImmutableLowering(
TypeConverter &typeConverter, MLIRContext *context,
const MapVector<StringAttr, ModuleLoweringState> &moduleLoweringStates)
: OpConversionPattern<FromImmutableOp>(typeConverter, context),
moduleLoweringStates(moduleLoweringStates) {}

using OpAdaptor = typename OpConversionPattern<FromImmutableOp>::OpAdaptor;

LogicalResult
matchAndRewrite(FromImmutableOp fromImmutableOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = fromImmutableOp.getLoc();

auto regTy = ConversionPattern::getTypeConverter()->convertType(
fromImmutableOp.getType());
auto svReg = rewriter.create<sv::RegOp>(loc, regTy);

auto regVal = rewriter.create<sv::ReadInOutOp>(loc, svReg);

// Lower initial values.
auto module = fromImmutableOp->template getParentOfType<hw::HWModuleOp>();
const auto &initial = moduleLoweringStates.find(module.getModuleNameAttr())
->second.immutableValueLowering;

Value initialValue =
initial.lookupImmutableValue(fromImmutableOp.getInput());

OpBuilder::InsertionGuard guard(rewriter);
auto in = initial.getSVInitial();
rewriter.setInsertionPointToEnd(in.getBodyBlock());
rewriter.create<sv::BPAssignOp>(fromImmutableOp->getLoc(), svReg,
initialValue);

rewriter.replaceOp(fromImmutableOp, regVal);
return success();
}

private:
const MapVector<StringAttr, ModuleLoweringState> &moduleLoweringStates;
};
// Lower seq.clock_gate to a fairly standard clock gate implementation.
//
class ClockGateLowering : public OpConversionPattern<ClockGateOp> {
Expand Down Expand Up @@ -537,7 +597,7 @@ void SeqToSVPass::runOnOperation() {
moduleLoweringStates.try_emplace(module.getModuleNameAttr(),
ModuleLoweringState(module));

mlir::parallelForEach(
auto result = mlir::failableParallelForEach(
&getContext(), moduleLoweringStates, [&](auto &moduleAndState) {
auto &state = moduleAndState.second;
auto module = state.module;
Expand All @@ -561,9 +621,12 @@ void SeqToSVPass::runOnOperation() {
}
needsMemRandomization = true;
}
(void)state.initalOpLowering.lower();
return state.immutableValueLowering.lower();
});

if (failed(result))
return signalPassFailure();

auto randomInitFragmentName =
FlatSymbolRefAttr::get(context, "RANDOM_INIT_FRAGMENT");
auto randomInitRegFragmentName =
Expand Down Expand Up @@ -605,6 +668,8 @@ void SeqToSVPass::runOnOperation() {
moduleLoweringStates);
patterns.add<CompRegLower<CompRegClockEnabledOp>>(
typeConverter, context, lowerToAlwaysFF, moduleLoweringStates);
patterns.add<FromImmutableLowering>(typeConverter, context,
moduleLoweringStates);
patterns.add<ClockCastLowering<seq::FromClockOp>>(typeConverter, context);
patterns.add<ClockCastLowering<seq::ToClockOp>>(typeConverter, context);
patterns.add<ClockGateLowering>(typeConverter, context);
Expand Down
38 changes: 25 additions & 13 deletions lib/Dialect/Arc/Transforms/LowerState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ Value ClockLowering::materializeValue(Value value) {
return {};
if (auto mapped = materializedValues.lookupOrNull(value))
return mapped;
if (auto fromImmutable = value.getDefiningOp<seq::FromImmutableOp>())
// Immutable value is pre-materialized so directly lookup the input.
return materializedValues.lookup(fromImmutable.getInput());

if (!shouldMaterialize(value))
return value;

Expand Down Expand Up @@ -427,19 +431,27 @@ LogicalResult ModuleLowering::lowerPrimaryOutputs() {
}

LogicalResult ModuleLowering::lowerInitials() {
// Move all operations except for seq.yield to arc.initial op.
for (auto op : moduleOp.getOps<seq::InitialOp>()) {
auto terminator = cast<seq::YieldOp>(op.getBodyBlock()->getTerminator());
getInitial().builder.getBlock()->getOperations().splice(
getInitial().builder.getBlock()->begin(),
op.getBodyBlock()->getOperations());

// Map seq.initial results to operands of the seq.yield op.
for (auto [result, operand] :
llvm::zip(op.getResults(), terminator.getOperands()))
getInitial().materializedValues.map(result, operand);
terminator.erase();
}
// Merge all seq.initial ops into a single seq.initial op.
auto result = circt::seq::mergeInitialOps(moduleOp.getBodyBlock());
if (failed(result))
return moduleOp.emitError() << "initial ops cannot be topologically sorted";

auto initialOp = *result;
if (!initialOp) // There is no seq.initial op.
return success();

// Move the operations of the merged initial op into the builder's block.
auto terminator =
cast<seq::YieldOp>(initialOp.getBodyBlock()->getTerminator());
getInitial().builder.getBlock()->getOperations().splice(
getInitial().builder.getBlock()->begin(),
initialOp.getBodyBlock()->getOperations());

// Map seq.initial results to their corresponding operands.
for (auto [result, operand] :
llvm::zip(initialOp.getResults(), terminator.getOperands()))
getInitial().materializedValues.map(result, operand);
terminator.erase();

return success();
}
Expand Down
Loading
Loading