Skip to content

Commit

Permalink
support if op when its condition check is not combinational
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 committed Oct 9, 2024
1 parent 307160c commit 4d83a5c
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 14 deletions.
151 changes: 137 additions & 14 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ using Scheduleable =

class IfLoweringStateInterface {
public:
void setCondReg(scf::IfOp op, calyx::RegisterOp regOp) {
Operation *operation = op.getOperation();
assert(condReg.count(operation) == 0 &&
"A condition register was already set for this scf::IfOp!\n");
condReg[operation] = regOp;
}

calyx::RegisterOp getCondReg(scf::IfOp op) {
auto it = condReg.find(op.getOperation());
if (it != condReg.end())
return it->second;
return nullptr;
}

void setThenGroup(scf::IfOp op, calyx::GroupOp group) {
Operation *operation = op.getOperation();
assert(thenGroup.count(operation) == 0 &&
Expand Down Expand Up @@ -172,6 +186,7 @@ class IfLoweringStateInterface {
}

private:
DenseMap<Operation *, calyx::RegisterOp> condReg;
DenseMap<Operation *, calyx::GroupOp> thenGroup;
DenseMap<Operation *, calyx::GroupOp> elseGroup;
DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>> resultRegs;
Expand Down Expand Up @@ -240,13 +255,36 @@ class ForLoopLoweringStateInterface
}
};

class PipeOpLoweringStateInterface {
public:
void setPipeResReg(Operation *op, calyx::RegisterOp reg) {
assert(isa<calyx::MultPipeLibOp>(op) || isa<calyx::DivUPipeLibOp>(op) ||
isa<calyx::DivSPipeLibOp>(op) || isa<calyx::RemUPipeLibOp>(op) ||
isa<calyx::RemSPipeLibOp>(op));
assert(resultRegs.count(op) == 0 &&
"A register was already set for this pipe operation!\n");
resultRegs[op] = reg;
}
// Get the register for a specific pipe operation
calyx::RegisterOp getPipeResReg(Operation *op) {
auto it = resultRegs.find(op);
assert(it != resultRegs.end() &&
"No register was set for this pipe operation!\n");
return it->second;
}

private:
DenseMap<Operation *, calyx::RegisterOp> resultRegs;
};

/// Handles the current state of lowering of a Calyx component. It is mainly
/// used as a key/value store for recording information during partial lowering,
/// which is required at later lowering passes.
class ComponentLoweringState : public calyx::ComponentLoweringStateInterface,
public WhileLoopLoweringStateInterface,
public ForLoopLoweringStateInterface,
public IfLoweringStateInterface,
public PipeOpLoweringStateInterface,
public calyx::SchedulerInterface<Scheduleable> {
public:
ComponentLoweringState(calyx::ComponentOp component)
Expand Down Expand Up @@ -339,7 +377,12 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
/// source operation TSrcOp.
template <typename TGroupOp, typename TCalyxLibOp, typename TSrcOp>
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op,
TypeRange srcTypes, TypeRange dstTypes) const {
TypeRange srcTypes, TypeRange dstTypes,
calyx::RegisterOp srcReg = nullptr,
calyx::RegisterOp dstReg = nullptr) const {
assert((srcReg && dstReg) || (!srcReg && !dstReg));
bool isSequential = srcReg && dstReg;

SmallVector<Type> types;
llvm::append_range(types, srcTypes);
llvm::append_range(types, dstTypes);
Expand All @@ -365,26 +408,54 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {

/// Create assignments to the inputs of the library op.
auto group = createGroupForOp<TGroupOp>(rewriter, op);

if (isSequential) {
auto groupOp = cast<calyx::GroupOp>(group);
getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
groupOp);
}

rewriter.setInsertionPointToEnd(group.getBodyBlock());
for (auto dstOp : enumerate(opInputPorts))
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
op->getOperand(dstOp.index()));

for (auto dstOp : enumerate(opInputPorts)) {
if (isSequential)
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
srcReg.getOut());
else
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
op->getOperand(dstOp.index()));
}

/// Replace the result values of the source operator with the new operator.
for (auto res : enumerate(opOutputPorts)) {
getState<ComponentLoweringState>().registerEvaluatingGroup(res.value(),
group);
op->getResult(res.index()).replaceAllUsesWith(res.value());
if (isSequential)
op->getResult(res.index()).replaceAllUsesWith(dstReg.getOut());
else
op->getResult(res.index()).replaceAllUsesWith(res.value());
}

if (isSequential) {
auto groupOp = cast<calyx::GroupOp>(group);
buildAssignmentsForRegisterWrite(
rewriter, groupOp,
getState<ComponentLoweringState>().getComponentOp(), dstReg,
calyxOp.getOut());
}

return success();
}

/// buildLibraryOp which provides in- and output types based on the operands
/// and results of the op argument.
template <typename TGroupOp, typename TCalyxLibOp, typename TSrcOp>
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const {
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op,
calyx::RegisterOp srcReg = nullptr,
calyx::RegisterOp dstReg = nullptr) const {
return buildLibraryOp<TGroupOp, TCalyxLibOp, TSrcOp>(
rewriter, op, op.getOperandTypes(), op->getResultTypes());
rewriter, op, op.getOperandTypes(), op->getResultTypes(), srcReg,
dstReg);
}

/// Creates a group named by the basic block which the input op resides in.
Expand All @@ -411,6 +482,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
auto reg = createRegister(
op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
getState<ComponentLoweringState>().getUniqueName(opName));

// Operation pipelines are not combinational, so a GroupOp is required.
auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
OpBuilder builder(group->getRegion(0));
Expand Down Expand Up @@ -441,6 +513,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
getState<ComponentLoweringState>().registerEvaluatingGroup(
opPipe.getRight(), group);

getState<ComponentLoweringState>().setPipeResReg(out.getDefiningOp(), reg);

return success();
}

Expand Down Expand Up @@ -939,9 +1013,43 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CmpIOp op) const {
auto isPipeLibOp = [](Value val) -> bool {
if (Operation *defOp = val.getDefiningOp()) {
return isa<calyx::MultPipeLibOp, calyx::DivUPipeLibOp,
calyx::DivSPipeLibOp, calyx::RemUPipeLibOp,
calyx::RemSPipeLibOp>(defOp);
}
return false;
};

switch (op.getPredicate()) {
case CmpIPredicate::eq:
case CmpIPredicate::eq: {
StringRef opName = op.getOperationName().split(".").second;
Type width = op.getResult().getType();
auto condReg = createRegister(
op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
getState<ComponentLoweringState>().getUniqueName(opName));

for (auto *user : op->getUsers()) {
if (auto ifOp = dyn_cast<scf::IfOp>(user))
getState<ComponentLoweringState>().setCondReg(ifOp, condReg);
}

bool isSequential = isPipeLibOp(op.getLhs()) || isPipeLibOp(op.getRhs());
if (isSequential) {
calyx::RegisterOp pipeResReg;
if (isPipeLibOp(op.getLhs()))
pipeResReg = getState<ComponentLoweringState>().getPipeResReg(
op.getLhs().getDefiningOp());
else
pipeResReg = getState<ComponentLoweringState>().getPipeResReg(
op.getRhs().getDefiningOp());

return buildLibraryOp<calyx::GroupOp, calyx::EqLibOp>(
rewriter, op, pipeResReg, condReg);
}
return buildLibraryOp<calyx::CombGroupOp, calyx::EqLibOp>(rewriter, op);
}
case CmpIPredicate::ne:
return buildLibraryOp<calyx::CombGroupOp, calyx::NeqLibOp>(rewriter, op);
case CmpIPredicate::uge:
Expand Down Expand Up @@ -1535,11 +1643,16 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
Location loc = ifOp->getLoc();

auto cond = ifOp.getCondition();
auto condGroup = getState<ComponentLoweringState>()
.getEvaluatingGroup<calyx::CombGroupOp>(cond);

auto symbolAttr = FlatSymbolRefAttr::get(
StringAttr::get(getContext(), condGroup.getSymName()));
FlatSymbolRefAttr symbolAttr = nullptr;
auto condReg = getState<ComponentLoweringState>().getCondReg(ifOp);
if (!condReg) {
auto condGroup = getState<ComponentLoweringState>()
.getEvaluatingGroup<calyx::CombGroupOp>(cond);

symbolAttr = FlatSymbolRefAttr::get(
StringAttr::get(getContext(), condGroup.getSymName()));
}

bool initElse = !ifOp.getElseRegion().empty();
auto ifCtrlOp = rewriter.create<calyx::IfOp>(
Expand All @@ -1551,8 +1664,13 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
rewriter.create<calyx::SeqOp>(ifOp.getThenRegion().getLoc());
auto *thenSeqOpBlock = thenSeqOp.getBodyBlock();

rewriter.setInsertionPointToEnd(thenSeqOpBlock);
auto *thenBlock = &ifOp.getThenRegion().front();
LogicalResult res = buildCFGControl(path, rewriter, thenSeqOpBlock,
/*preBlock=*/block, thenBlock);
if (res.failed())
return res;

rewriter.setInsertionPointToEnd(thenSeqOpBlock);
calyx::GroupOp thenGroup =
getState<ComponentLoweringState>().getThenGroup(ifOp);
rewriter.create<calyx::EnableOp>(thenGroup.getLoc(),
Expand All @@ -1565,8 +1683,13 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
rewriter.create<calyx::SeqOp>(ifOp.getElseRegion().getLoc());
auto *elseSeqOpBlock = elseSeqOp.getBodyBlock();

rewriter.setInsertionPointToEnd(elseSeqOpBlock);
auto *elseBlock = &ifOp.getElseRegion().front();
res = buildCFGControl(path, rewriter, elseSeqOpBlock,
/*preBlock=*/block, elseBlock);
if (res.failed())
return res;

rewriter.setInsertionPointToEnd(elseSeqOpBlock);
calyx::GroupOp elseGroup =
getState<ComponentLoweringState>().getElseGroup(ifOp);
rewriter.create<calyx::EnableOp>(elseGroup.getLoc(),
Expand Down
119 changes: 119 additions & 0 deletions test/Conversion/SCFToCalyx/convert_controlflow.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -641,3 +641,122 @@ module {
return %1 : i32
}
}

// -----

// Test if ops with sequential condition check.

module {
// CHECK-LABEL: calyx.component @main(
// CHECK-SAME: %[[VAL_0:in0]]: i32,
// CHECK-SAME: %[[VAL_1:.*]]: i1 {clk},
// CHECK-SAME: %[[VAL_2:.*]]: i1 {reset},
// CHECK-SAME: %[[VAL_3:.*]]: i1 {go}) -> (
// CHECK-SAME: %[[VAL_4:out0]]: i32,
// CHECK-SAME: %[[VAL_5:.*]]: i1 {done}) {
// CHECK: %[[VAL_6:.*]] = hw.constant true
// CHECK: %[[VAL_7:.*]] = hw.constant false
// CHECK: %[[VAL_8:.*]] = hw.constant 1 : i32
// CHECK: %[[VAL_9:.*]] = hw.constant 2 : i32
// CHECK: %[[VAL_10:.*]], %[[VAL_11:.*]] = calyx.std_slice @std_slice_1 : i32, i7
// CHECK: %[[VAL_12:.*]], %[[VAL_13:.*]] = calyx.std_slice @std_slice_0 : i32, i7
// CHECK: %[[VAL_14:.*]], %[[VAL_15:.*]], %[[VAL_16:.*]], %[[VAL_17:.*]], %[[VAL_18:.*]], %[[VAL_19:.*]] = calyx.register @load_1_reg : i32, i1, i1, i1, i32, i1
// CHECK: %[[VAL_20:.*]], %[[VAL_21:.*]], %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]] = calyx.register @load_0_reg : i32, i1, i1, i1, i32, i1
// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]], %[[VAL_28:.*]] = calyx.std_eq @std_eq_0 : i32, i32, i1
// CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]], %[[VAL_31:.*]], %[[VAL_32:.*]], %[[VAL_33:.*]], %[[VAL_34:.*]] = calyx.register @cmpi_0_reg : i1, i1, i1, i1, i1, i1
// CHECK: %[[VAL_35:.*]], %[[VAL_36:.*]], %[[VAL_37:.*]], %[[VAL_38:.*]], %[[VAL_39:.*]], %[[VAL_40:.*]] = calyx.register @remui_0_reg : i32, i1, i1, i1, i32, i1
// CHECK: %[[VAL_41:.*]], %[[VAL_42:.*]], %[[VAL_43:.*]], %[[VAL_44:.*]], %[[VAL_45:.*]], %[[VAL_46:.*]], %[[VAL_47:.*]] = calyx.std_remu_pipe @std_remu_pipe_0 : i1, i1, i1, i32, i32, i32, i1
// CHECK: %[[VAL_48:.*]], %[[VAL_49:.*]], %[[VAL_50:.*]], %[[VAL_51:.*]], %[[VAL_52:.*]], %[[VAL_53:.*]], %[[VAL_54:.*]], %[[VAL_55:.*]] = calyx.seq_mem @mem_0 <[120] x 32> [7] {external = true} : i7, i1, i1, i1, i1, i32, i32, i1
// CHECK: %[[VAL_56:.*]], %[[VAL_57:.*]], %[[VAL_58:.*]], %[[VAL_59:.*]], %[[VAL_60:.*]], %[[VAL_61:.*]] = calyx.register @if_res_0_reg : i32, i1, i1, i1, i32, i1
// CHECK: %[[VAL_62:.*]], %[[VAL_63:.*]], %[[VAL_64:.*]], %[[VAL_65:.*]], %[[VAL_66:.*]], %[[VAL_67:.*]] = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
// CHECK: calyx.wires {
// CHECK: calyx.assign %[[VAL_4]] = %[[VAL_66]] : i32
// CHECK: calyx.group @then_br_0 {
// CHECK: calyx.assign %[[VAL_56]] = %[[VAL_24]] : i32
// CHECK: calyx.assign %[[VAL_57]] = %[[VAL_6]] : i1
// CHECK: calyx.group_done %[[VAL_61]] : i1
// CHECK: }
// CHECK: calyx.group @else_br_0 {
// CHECK: calyx.assign %[[VAL_56]] = %[[VAL_18]] : i32
// CHECK: calyx.assign %[[VAL_57]] = %[[VAL_6]] : i1
// CHECK: calyx.group_done %[[VAL_61]] : i1
// CHECK: }
// CHECK: calyx.group @bb0_0 {
// CHECK: calyx.assign %[[VAL_44]] = %[[VAL_0]] : i32
// CHECK: calyx.assign %[[VAL_45]] = %[[VAL_9]] : i32
// CHECK: calyx.assign %[[VAL_35]] = %[[VAL_46]] : i32
// CHECK: calyx.assign %[[VAL_36]] = %[[VAL_47]] : i1
// CHECK: %[[VAL_68:.*]] = comb.xor %[[VAL_47]], %[[VAL_6]] : i1
// CHECK: calyx.assign %[[VAL_43]] = %[[VAL_68]] ? %[[VAL_6]] : i1
// CHECK: calyx.group_done %[[VAL_40]] : i1
// CHECK: }
// CHECK: calyx.group @bb0_1 {
// CHECK: calyx.assign %[[VAL_26]] = %[[VAL_39]] : i32
// CHECK: calyx.assign %[[VAL_27]] = %[[VAL_39]] : i32
// CHECK: calyx.assign %[[VAL_29]] = %[[VAL_28]] : i1
// CHECK: calyx.assign %[[VAL_30]] = %[[VAL_6]] : i1
// CHECK: calyx.group_done %[[VAL_34]] : i1
// CHECK: }
// CHECK: calyx.group @bb0_2 {
// CHECK: calyx.assign %[[VAL_10]] = %[[VAL_9]] : i32
// CHECK: calyx.assign %[[VAL_48]] = %[[VAL_11]] : i7
// CHECK: calyx.assign %[[VAL_51]] = %[[VAL_6]] : i1
// CHECK: calyx.assign %[[VAL_52]] = %[[VAL_7]] : i1
// CHECK: calyx.assign %[[VAL_20]] = %[[VAL_54]] : i32
// CHECK: calyx.assign %[[VAL_21]] = %[[VAL_55]] : i1
// CHECK: calyx.group_done %[[VAL_25]] : i1
// CHECK: }
// CHECK: calyx.group @bb0_3 {
// CHECK: calyx.assign %[[VAL_12]] = %[[VAL_8]] : i32
// CHECK: calyx.assign %[[VAL_48]] = %[[VAL_13]] : i7
// CHECK: calyx.assign %[[VAL_51]] = %[[VAL_6]] : i1
// CHECK: calyx.assign %[[VAL_52]] = %[[VAL_7]] : i1
// CHECK: calyx.assign %[[VAL_14]] = %[[VAL_54]] : i32
// CHECK: calyx.assign %[[VAL_15]] = %[[VAL_55]] : i1
// CHECK: calyx.group_done %[[VAL_19]] : i1
// CHECK: }
// CHECK: calyx.group @ret_assign_0 {
// CHECK: calyx.assign %[[VAL_62]] = %[[VAL_60]] : i32
// CHECK: calyx.assign %[[VAL_63]] = %[[VAL_6]] : i1
// CHECK: calyx.group_done %[[VAL_67]] : i1
// CHECK: }
// CHECK: }
// CHECK: calyx.control {
// CHECK: calyx.seq {
// CHECK: calyx.enable @bb0_0
// CHECK: calyx.enable @bb0_1
// CHECK: calyx.if %[[VAL_33]] {
// CHECK: calyx.seq {
// CHECK: calyx.enable @bb0_2
// CHECK: calyx.enable @then_br_0
// CHECK: }
// CHECK: } else {
// CHECK: calyx.seq {
// CHECK: calyx.enable @bb0_3
// CHECK: calyx.enable @else_br_0
// CHECK: }
// CHECK: }
// CHECK: calyx.enable @ret_assign_0
// CHECK: }
// CHECK: }
// CHECK: } {toplevel}
func.func @main(%arg0 : i32) -> i32 {
%1 = memref.alloc() : memref<120xi32>
%idx_one = arith.constant 1 : index
%two = arith.constant 2: i32
%rem = arith.remui %arg0, %two : i32
%cond = arith.cmpi eq, %arg0, %rem : i32

%res = scf.if %cond -> i32 {
%idx = arith.addi %idx_one, %idx_one : index
%then_res = memref.load %1[%idx] : memref<120xi32>
scf.yield %then_res : i32
} else {
%idx = arith.muli %idx_one, %idx_one : index
%else_res = memref.load %1[%idx] : memref<120xi32>
scf.yield %else_res : i32
}

return %res : i32
}
}

0 comments on commit 4d83a5c

Please sign in to comment.