Skip to content

Commit

Permalink
use a local build method
Browse files Browse the repository at this point in the history
  • Loading branch information
desmonddak committed Mar 1, 2025
1 parent 7a780c2 commit 4f101d8
Showing 1 changed file with 66 additions and 71 deletions.
137 changes: 66 additions & 71 deletions lib/src/reduction_tree.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,11 @@

import 'dart:math';

import 'package:meta/meta.dart';
import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

/// Recursive Node for Reduction Tree
class ReductionTree extends Module {
/// Operation to be performed at each node. Note that [operation] can widen
/// the output. The logic function must support the operation for [2 to radix]
/// inputs.
@protected
final Logic Function(List<Logic> inputs, {String name}) operation;

/// Specified width of input to each reduction node (e.g., binary: radix=2)
@protected
late final int radix;

/// When [signExtend] is true, use sign-extension on values,
/// otherwise use zero-extension.
@protected
late final bool signExtend;

/// Specified depth of nodes at which to flop (requires [clk]).
@protected
late final int? depthToFlop;

/// Optional [clk] input to create pipeline.
@protected
late final Logic? clk;

/// Optional [reset] input to reset pipeline.
@protected
late final Logic? reset;

/// Optional [enable] input to enable pipeline.
@protected
late final Logic? enable;

/// The final output of the tree computation.
Logic get out => output('out');

Expand All @@ -56,25 +24,48 @@ class ReductionTree extends Module {
/// The flop depth of the tree from the output to the leaves.
int get latency => _computed.flopDepth;

/// Operation to be performed at each node. Note that [_operation] can widen
/// the output. The logic function must support the operation for 2 and up to
/// [_radix] inputs.
final Logic Function(List<Logic> inputs, {String name}) _operation;

/// Specified width of input to each reduction node (e.g., binary: radix=2)
late final int _radix;

/// When [_signExtend] is true, use sign-extension on values,
/// otherwise use zero-extension.
late final bool _signExtend;

/// Specified depth of nodes at which to flop (requires [_clk]).
late final int? _depthToFlop;

/// Optional [_clk] input to create pipeline.
late final Logic? _clk;

/// Optional [_reset] input to reset pipeline.
late final Logic? _reset;

/// Optional [_enable] input to enable pipeline.
late final Logic? _enable;

/// The input sequence
@protected
late final List<Logic> sequence;
late final List<Logic> _sequence;

/// Capture the record of compute: the final value, its depth (from last
/// flop or input), and its flopDepth if pipelined.
late final ({Logic value, int depth, int flopDepth}) _computed;

/// Local conditional flop using module reset/enable
Logic localFlop(Logic d, {bool doFlop = false}) =>
condFlop(doFlop ? clk : null, reset: reset, en: enable, d);
Logic _localFlop(Logic d, {bool doFlop = false}) =>
condFlop(doFlop ? _clk : null, reset: _reset, en: _enable, d);

/// Generate a tree based on dividing the input [sequence] of a node into
/// segments, recursively constructing [radix] child nodes to operate
/// on each segment.
/// - [sequence] is the input sequence to be reduced using the tree of
/// operations.
/// - Logic Function(List<Logic> inputs, {String name}) [operation]
/// is the operation to be performed at each node. Note that [operation]
/// - Logic Function(List<Logic> inputs, {String name}) [_operation]
/// is the operation to be performed at each node. Note that [_operation]
/// can widen the output. The logic function must support the operation for
/// (2 to [radix]) inputs.
/// - [radix] is the width of reduction at each node in the tree (e.g.,
Expand All @@ -85,70 +76,74 @@ class ReductionTree extends Module {
/// Optional parameters to be used for creating a pipelined computation tree:
/// - [clk], [reset], [enable] are optionally provided to allow for flopping.
/// - [depthToFlop] specifies how many nodes deep separate flops.
ReductionTree(List<Logic> sequence, this.operation,
{this.radix = 2,
this.signExtend = false,
this.depthToFlop,
ReductionTree(List<Logic> sequence, this._operation,
{int radix = 2,
bool signExtend = false,
int? depthToFlop,
Logic? clk,
Logic? enable,
Logic? reset,
super.name = 'reduction_tree'})
: super(
: _depthToFlop = depthToFlop,
_signExtend = signExtend,
_radix = radix,
super(
definitionName: 'ReductionTreeNode_R${radix}_L${sequence.length}') {
this.sequence = [
_sequence = [
for (var i = 0; i < sequence.length; i++)
addInput('seq$i', sequence[i], width: sequence[i].width)
];
this.clk = (clk != null) ? addInput('clk', clk) : null;
this.enable = (enable != null) ? addInput('enable', enable) : null;
this.reset = (reset != null) ? addInput('reset', reset) : null;
_clk = (clk != null) ? addInput('clk', clk) : null;
_enable = (enable != null) ? addInput('enable', enable) : null;
_reset = (reset != null) ? addInput('reset', reset) : null;

if (this.sequence.length <= radix) {
final value = operation(this.sequence);
_buildLogic();
}

/// Build out the recursive tree
void _buildLogic() {
if (_sequence.length <= _radix) {
final value = _operation(_sequence);
addOutput('out', width: value.width) <= value;
_computed = (value: output('out'), depth: 0, flopDepth: 0);
} else {
final results = <({Logic value, int depth, int flopDepth})>[];
final segment = this.sequence.length ~/ radix;

// final divisor = (log(sequence.length - 1) / log(radix)).floor();
// final segment = pow(radix, divisor).toInt();
final segment = _sequence.length ~/ _radix;

var pos = 0;
for (var i = 0; i < radix; i++) {
for (var i = 0; i < _radix; i++) {
final tree = ReductionTree(
this
.sequence
_sequence
.getRange(
pos, (i < radix - 1) ? pos + segment : this.sequence.length)
pos, (i < _radix - 1) ? pos + segment : _sequence.length)
.toList(),
operation,
radix: radix,
signExtend: signExtend,
depthToFlop: depthToFlop,
clk: this.clk,
enable: this.enable,
reset: this.reset);
_operation,
radix: _radix,
signExtend: _signExtend,
depthToFlop: _depthToFlop,
clk: _clk,
enable: _enable,
reset: _reset);
results.add(tree._computed);
pos += segment;
}
final flopDepth = results.map((c) => c.flopDepth).reduce(max);
final treeDepth = results.map((c) => c.depth).reduce(max);

final alignedResults = results
.map((c) => localFlop(c.value, doFlop: c.flopDepth < flopDepth));
.map((c) => _localFlop(c.value, doFlop: c.flopDepth < flopDepth));

final depthFlop = (depthToFlop != null) &&
(treeDepth > 0) & (treeDepth % depthToFlop! == 0);
final depthFlop = (_depthToFlop != null) &&
(treeDepth > 0) & (treeDepth % _depthToFlop! == 0);
final resultsFlop =
alignedResults.map((r) => localFlop(r, doFlop: depthFlop));
alignedResults.map((r) => _localFlop(r, doFlop: depthFlop));

final alignWidth = results.map((c) => c.value.width).reduce(max);
final resultsExtend = resultsFlop.map((r) =>
signExtend ? r.signExtend(alignWidth) : r.zeroExtend(alignWidth));
_signExtend ? r.signExtend(alignWidth) : r.zeroExtend(alignWidth));

final value = operation(resultsExtend.toList(),
name: 'reduce_d${(treeDepth + 1) + flopDepth * (depthToFlop ?? 0)}');
final value = _operation(resultsExtend.toList(),
name: 'reduce_d${(treeDepth + 1) + flopDepth * (_depthToFlop ?? 0)}');

addOutput('out', width: value.width) <= value;
_computed = (
Expand Down

0 comments on commit 4f101d8

Please sign in to comment.