From b44b85db672c6a25f437d1f0468ff54ccecffdb1 Mon Sep 17 00:00:00 2001 From: Desmond Kirkpatrick Date: Wed, 15 Jan 2025 12:10:06 -0800 Subject: [PATCH 1/3] Extensive fp cleanup, support Inf/NaN, addition of simple fp multiplier (#152) * extensive fp cleanup, support of Inf/NaN, addition of simple fp multiplier * use NativeAdder by default * cleanup clock/reset/enable init in FP base components --- doc/README.md | 3 + doc/components/adder.md | 27 +- doc/components/floating_point.md | 39 +- doc/components/shifter.md | 19 + lib/rohd_hcl.dart | 1 + lib/src/arithmetic/adder.dart | 21 +- lib/src/arithmetic/compound_adder.dart | 11 +- lib/src/arithmetic/float_to_fixed.dart | 4 +- .../floating_point/floating_point.dart | 5 +- .../floating_point/floating_point_adder.dart | 86 +++ .../floating_point_adder_round.dart | 240 +++---- .../floating_point_adder_simple.dart | 162 ++--- .../floating_point_multiplier.dart | 83 +++ .../floating_point_multiplier_simple.dart | 112 ++++ lib/src/arithmetic/multiplier.dart | 26 +- lib/src/arithmetic/ones_complement_adder.dart | 31 +- .../parallel_prefix_operations.dart | 52 +- .../signals/floating_point_logic.dart | 69 +- .../floating_point_8_value.dart | 49 +- .../floating_point_value.dart | 219 +++++-- .../components/component_registry.dart | 4 +- .../components/components.dart | 4 +- .../components/config_compound_adder.dart | 36 +- .../config_compression_tree_multiplier.dart | 32 +- .../config_floating_point_adder_round.dart | 25 +- .../config_floating_point_adder_simple.dart | 82 +++ ...nfig_floating_point_multiplier_simple.dart | 82 +++ .../config_parallel_prefix_adder.dart | 8 +- lib/src/signed_shifter.dart | 27 + lib/src/utils.dart | 33 +- test/arithmetic/adder_test.dart | 27 +- test/arithmetic/compound_adder_test.dart | 2 +- test/arithmetic/fixed_to_float_test.dart | 6 +- test/arithmetic/float_to_fixed_test.dart | 8 +- .../floating_point_adder_round_test.dart | 552 +++++++++++----- .../floating_point_adder_simple_test.dart | 614 ++++++++++-------- .../floating_point_adder_test.dart | 86 +++ .../floating_point_multiplier_test.dart | 267 ++++++++ .../floating_point_value_test.dart | 140 +++- test/arithmetic/multiplier_test.dart | 4 +- .../parallel_prefix_operations_test.dart | 20 +- test/signed_shifter_test.dart | 33 + 42 files changed, 2530 insertions(+), 821 deletions(-) create mode 100644 doc/components/shifter.md create mode 100644 lib/src/arithmetic/floating_point/floating_point_adder.dart create mode 100644 lib/src/arithmetic/floating_point/floating_point_multiplier.dart create mode 100644 lib/src/arithmetic/floating_point/floating_point_multiplier_simple.dart create mode 100644 lib/src/component_config/components/config_floating_point_adder_simple.dart create mode 100644 lib/src/component_config/components/config_floating_point_multiplier_simple.dart create mode 100644 lib/src/signed_shifter.dart create mode 100644 test/arithmetic/floating_point/floating_point_adder_test.dart create mode 100644 test/arithmetic/floating_point/floating_point_multiplier_test.dart create mode 100644 test/signed_shifter_test.dart diff --git a/doc/README.md b/doc/README.md index d1bd835e2..8dcf21fd7 100644 --- a/doc/README.md +++ b/doc/README.md @@ -42,6 +42,7 @@ Some in-development items will have opened issues, as well. Feel free to create - [Parallel Prefix Adder](./components/parallel_prefix_operations.md) - [Sign Magnitude Adder](./components/adder.md#sign-magnitude-adder) - [Compound Adder](./components/adder.md#compound-adder) + - [Native Adder](./components/adder.md#native-adder) - Subtractors - [Ones' Complement Adder Subtractor](./components/adder.md#ones-complement-adder-subtractor) - Multipliers @@ -63,11 +64,13 @@ Some in-development items will have opened issues, as well. Feel free to create - 8-bit E4/M3 and E5/M2 - [Simple Floating-Point Adder](./components/floating_point.md#floatingpointadder) - [Rounding Floating-Point Adder](./components/floating_point.md#floatingpointadder) + - [Simple Floating-Point Multiplier](./components/floating_point.md#floatingpointmultiplier) - [Fixed point](./components/fixed_point.md) - [FloatToFixed](./components/fixed_point.md#floattofixed) - [FixedToFloat](./components/fixed_point.md#fixedtofloat) - Binary-Coded Decimal (BCD) - [Rotate](./components/rotate.md) +- [SignedShifter](./components/shifter.md) - Counters - [Summation](./components/summation.md#sum) - [Binary counter](./components/summation.md#counter) diff --git a/doc/components/adder.md b/doc/components/adder.md index 0c162052f..18428ac93 100644 --- a/doc/components/adder.md +++ b/doc/components/adder.md @@ -4,9 +4,10 @@ ROHD-HCL provides a set of adder modules to get the sum from a pair of Logic. So - [Ripple Carry Adder](#ripple-carry-adder) - [Parallel Prefix Adder](#parallel-prefix-adder) -- [One's Complement Adder Subtractor](#ones-complement-adder-subtractor) +- [Ones' Complement Adder Subtractor](#ones-complement-adder-subtractor) - [Sign Magnitude Adder](#sign-magnitude-adder) - [Compound Adder](#compound-adder) +- [Native Adder](#native-adder) ## Ripple Carry Adder @@ -50,7 +51,7 @@ Here is an example of instantiating a [ParallelPrefixAdder](https://intel.github ## Ones' Complement Adder Subtractor -A ones-complement adder (and subtractor) is useful in efficient arithmetic operations as the +A ones'-complement adder (and subtractor) is useful in efficient arithmetic operations as the end-around carry can be bypassed and used later. The [OnesComplementAdder](https://intel.github.io/rohd-hcl/rohd_hcl/OnesComplementAdder-class.html) can take a subtraction command as either a `Logic` `subtractIn` or a boolean `subtract` (the Logic overrides the boolean). If Logic `carry` is provided, the end-around carry is output on `carry` and the value will be one less than expected when `carry` is high. An `adderGen` adder function can be provided that generates your favorite internal adder (such as a parallel prefix adder). @@ -76,7 +77,7 @@ Here is an example of instantiating a [OnesComplementAdder](https://intel.githu ## Sign Magnitude Adder -A sign magnitude adder is useful in situations where the sign of the addends is separated from their magnitude (e.g., not 2s complement), such as in floating point multipliers. The [SignMagnitudeAdder](https://intel.github.io/rohd-hcl/rohd_hcl/SignMagnitudeAdder-class.html) inherits from `Adder` but adds the `Logic` inputs for the two operands. +A sign magnitude adder is useful in situations where the sign of the addends is separated from their magnitude (e.g., not twos' complement), such as in floating point multipliers. The [SignMagnitudeAdder](https://intel.github.io/rohd-hcl/rohd_hcl/SignMagnitudeAdder-class.html) inherits from `Adder` but adds the `Logic` inputs for the two operands. If you can supply the largest magnitude number first, then you can disable a comparator generation inside by declaring the `largestMagnitudeFirst` option as true. @@ -137,3 +138,23 @@ final sum1 = rippleCarryAdder.sum1; final rippleCarryAdder4BitBlock = CarrySelectCompoundAdder(a, b, widthGen: CarrySelectCompoundAdder.splitSelectAdderAlgorithm4Bit); ``` + +## Native Adder + +As logic synthesis can replace a '+' in RTL with a wide variety of adder architectures on its own, we have a `NativeAdder` wrapper class that allows you to use the native '+' with any component that exposes an `Adder` functor as a parameter: + +```dart +// API definition: FloatingPointAdderRound(super.a, super.b, +// {Logic? subtract, +// super.clk, +// super.reset, +// super.enable, +// Adder Function(Logic, Logic, {Logic? carryIn}) adderGen = +// ParallelPrefixAdder.new, +// ParallelPrefix Function(List, Logic Function(Logic, Logic)) +// ppTree = KoggeStone.new, +// super.name = 'floating_point_adder_round'}) + +// Instantiate with a NativeAdder as the internal adder +final adder = FloatingPointAdderRound(a, b, adderGen: NativeAdder.new); +``` diff --git a/doc/components/floating_point.md b/doc/components/floating_point.md index fe5b6f322..3a97e5f75 100644 --- a/doc/components/floating_point.md +++ b/doc/components/floating_point.md @@ -1,6 +1,6 @@ # Floating-Point Components -Floating-point operations require meticulous precision, and have standards like [IEEE-754]() which govern them. To support floating-point components, we have created a parallel to [Logic](https://intel.github.io/rohd/rohd/Logic-class.html)/[LogicValue](https://intel.github.io/rohd/rohd/LogicValue-class.html) which are part of [ROHD](). Here, [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) is the [Logic](https://pub.dev/documentation/rohd/latest/rohd/Logic-class.html) wire in a component that carries [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html) literal values, a subclass of [LogicValue](https://intel.github.io/rohd/rohd/LogicValue-class.html). An important distinction is that these classes are parameterized to create arbitrary size floating-point values. +Floating-point operations require meticulous precision, and have standards like [IEEE-754]() which govern them. To support floating-point components, we have created a parallel to [Logic](https://intel.github.io/rohd/rohd/Logic-class.html)/[LogicValue](https://intel.github.io/rohd/rohd/LogicValue-class.html) which are part of [ROHD](). Here, [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) is the [Logic](https://intel.github.io/rohd/rohd/Logic-class.html) wire in a component that carries [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html) literal values, a subclass of [LogicValue](https://intel.github.io/rohd/rohd/LogicValue-class.html). An important distinction is that these classes are parameterized to create arbitrary size floating-point values. ## FloatingPointValue @@ -16,10 +16,31 @@ $$minExponent <= exponent <= maxExponent$$ And a mantissa in the range of $[1,2)$. Subnormal numbers are represented with a zero exponent and leading zeros in the mantissa capture the negative exponent value. -The various IEEE constants representing corner cases of the field of floating-point values for a given size of [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html): infinities, zeros, limits for normal (e.g. mantissa in the range of $[1,2)$ and sub-normal numbers (zero exponent, and mantissa <1). +Conversions from the native `double` are supported, both in rounded and unrounded forms. This is quite useful in testing narrower width floating point components leveraging the `double` native operations for validation. Appropriate string representations, comparison operations, and operators are available. The usefulness of [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html) is in the testing of [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) components, where we can leverage the abstraction of a floating-point value type to drive and compare floating-point values operated upon by floating-point components. +### Floating Point Constants + +The various IEEE constants representing corner cases of the field of floating-point values for a given size of [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html): infinities, zeros, limits for normal (e.g. mantissa in the range of $[1,2)$ and sub-normal numbers (zero exponent, and mantissa <1). + +For any basic arbitrary width `FloatingPointValue` ROHD-HCL supports the following constants in that format. + +- `negativeInfinity`: smallest possible number +- `negativeZero`: The number zero, negative form +- `positiveZero`: The number zero, positive form +- `smallestPositiveSubnormal`: Smallest possible number, most exponent negative, LSB set in mantissa +- `largestPositiveSubnormal`: Largest possible subnormal, most negative exponent, mantissa all 1s +- `smallestPositiveNormal`: Smallest possible positive number, most negative exponent, mantissa is 0 +- `largestLessThanOne`: Largest number smaller than one +- `one`: The number one +- `smallestLargerThanOne`: Smallest number greater than one +- `largestNormal`: Largest positive number, most positive exponent, full mantissa +- `infinity`: Largest possible number: all 1s in the exponent, all 0s in the mantissa +- `nan`: Not a Number, demarked by all 1s in exponent and any 1 in mantissa (we use the LSB) + +### Special subtypes + As 64-bit double-precision and 32-bit single-precision floating-point types are most common, we have [FloatingPoint32Value](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint32Value-class.html) and [FloatingPoint64Value](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint64Value-class.html) subclasses with direct converters from Dart native [Double](https://api.dart.dev/stable/3.6.0/dart-core/double-class.html). Other special widths of floating-point values supported are: @@ -34,14 +55,24 @@ Finally, we have a [random value constructor](https://intel.github.io/rohd-hcl/r ## FloatingPoint -The [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) type is a [LogicStructure](https://pub.dev/documentation/rohd/latest/rohd/LogicStructure-class.html) which comprises the [Logic](https://pub.dev/documentation/rohd/latest/rohd/Logic-class.html) bits for the sign, exponent, and mantissa used in hardware floating-point. This type is provided to simplify and abstract the declaration and manipulation of floating-point bits in hardware. This type is parameterized like [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html), for exponent and mantissa width. +The [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) type is a [LogicStructure](https://intel.github.io/rohd/rohd/LogicStructure-class.html) which comprises the [Logic](https://intel.github.io/rohd/rohd/Logic-class.html) bits for the sign, exponent, and mantissa used in hardware floating-point. This type is provided to simplify and abstract the declaration and manipulation of floating-point bits in hardware. This type is parameterized like [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html), for exponent and mantissa width. Again, like [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html), [FloatingPoint64](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint64-class.html) and [FloatingPoint32](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint32-class.html) subclasses are provided as these are the most common floating-point number types. ## FloatingPointAdder -A very basic [FloatingPointAdderSimple](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointAdderSimple-class.html) component is available which does not perform any rounding. It takes two [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) [LogicStructure](https://pub.dev/documentation/rohd/latest/rohd/LogicStructure-class.html)s and adds them, returning a normalized [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) on the output. An option on input is the type of ['ParallelPrefix'](https://intel.github.io/rohd-hcl/rohd_hcl/ParallelPrefix-class.html) used in the critical internal addition of the mantissas. +A very basic [FloatingPointAdderSimple](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointAdderSimple-class.html) component is available which does not perform any rounding. It takes two [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) [LogicStructure](https://intel.github.io/rohd/rohd/LogicStructure-class.html)s and adds them, returning a normalized [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) on the output. An option on input is the type of ['ParallelPrefix'](https://intel.github.io/rohd-hcl/rohd_hcl/ParallelPrefix-class.html) used in the critical internal addition of the mantissas. Currently, the [FloatingPointAdderSimple](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointAdderSimple-class.html) is close in accuracy (as it has no rounding) and is not optimized for circuit performance, but only provides the key functionalities of alignment, addition, and normalization. Still, this component is a starting point for more realistic floating-point components that leverage the logical [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) and literal [FloatingPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointValue-class.html) type abstractions. A second [FloatingPointAdderRound](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPointAdderRound-class.html) component is available which does perform rounding. It is based on "Delay-Optimized Implementation of IEEE Floating-Point Addition", by Peter-Michael Seidel and Guy Even, using an R-path and an N-path to process far-apart exponents and use rounding and an N-path for exponents within 2 and subtraction, which is exact. If you pass in an optional clock, a pipestage will be added to help optimize frequency; an optional reset and enable are can control the pipestage. + +## FloatingPointMultiplier + +A very basic [FloatingPointMultiplierSimple] component is available which does not perform any rounding. It takes two [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) [LogicStructure](https://intel.github.io/rohd/rohd/LogicStructure-class.html)s and multiplies them, returning a normalized [FloatingPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint-class.html) on the output 'product'. + +It has options to control its performance: + +- 'radix': used to specify the radix of the Booth encoder (default radix=4: options are [2,4,8,16])'. +- adderGen': used to specify the kind of [Adder] used for key functions like the mantissa addition. Defaults to [NativeAdder], but you can select a [ParallelPrefixAdder] of your choice. +- 'ppTree': used to specify the type of ['ParallelPrefix'](https://intel.github.io/rohd-hcl/rohd_hcl/ParallelPrefix-class.html) used in the pther critical functions like leading-one detect. diff --git a/doc/components/shifter.md b/doc/components/shifter.md new file mode 100644 index 000000000..e662a26e0 --- /dev/null +++ b/doc/components/shifter.md @@ -0,0 +1,19 @@ +# Shifter + +ROHD-HCL provides a component to perform shifting of a Logic based on an input Logic treated as signed. + +## SignedShifter + +The `SignedShifter` takes as input a Logic $shift$ and interprets $shift > 0$ as left-shift by the magnitude of $shift$ and right-shift by the magnitude of $shift$ if $shift < 0$. + +```dart + final bits = Const(16, width: 8); + print(bits.value.toRadixString()); + // Produces: 16'b1_0000 + final shift = Logic(width: 3); + final shifter = SignedShifter(bits, shift); + + shift.put(-1); + print(shifter.shifted.value.toRadixString()); + // Produces: 16'b1000 + ``` diff --git a/lib/rohd_hcl.dart b/lib/rohd_hcl.dart index 406acfde2..b12b6ead1 100644 --- a/lib/rohd_hcl.dart +++ b/lib/rohd_hcl.dart @@ -20,6 +20,7 @@ export 'src/models/models.dart'; export 'src/rotate.dart'; export 'src/serialization/serialization.dart'; export 'src/shift_register.dart'; +export 'src/signed_shifter.dart'; export 'src/sort.dart'; export 'src/summation/summation.dart'; export 'src/toggle_gate.dart'; diff --git a/lib/src/arithmetic/adder.dart b/lib/src/arithmetic/adder.dart index 01d94c08a..ca80e8ddc 100644 --- a/lib/src/arithmetic/adder.dart +++ b/lib/src/arithmetic/adder.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // adder.dart @@ -67,3 +67,22 @@ class FullAdder extends Adder { sum <= [carryIn! & (a ^ b) | a & b, (a ^ b) ^ carryIn!].swizzle(); } } + +/// A class which wraps the native '+' operator so that it can be passed +/// into other modules as a parameter for using the native operation. +class NativeAdder extends Adder { + /// The width of input [a] and [b] must be the same. + NativeAdder(super.a, super.b, {super.carryIn, super.name = 'native_adder'}) { + if (a.width != b.width) { + throw RohdHclException('inputs of a and b should have same width.'); + } + if (carryIn == null) { + sum <= a.zeroExtend(a.width + 1) + b.zeroExtend(b.width + 1); + } else { + sum <= + a.zeroExtend(a.width + 1) + + b.zeroExtend(b.width + 1) + + carryIn!.zeroExtend(a.width + 1); + } + } +} diff --git a/lib/src/arithmetic/compound_adder.dart b/lib/src/arithmetic/compound_adder.dart index 0776cf73e..19a0367a1 100644 --- a/lib/src/arithmetic/compound_adder.dart +++ b/lib/src/arithmetic/compound_adder.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // compound_adder.dart @@ -19,10 +19,14 @@ abstract class CompoundAdder extends Adder { /// Takes in input [a] and input [b] and return the [sum] of the addition /// result and [sum1] sum + 1. /// The width of input [a] and [b] must be the same. - CompoundAdder(super.a, super.b, {super.name = 'compound_adders'}) { + CompoundAdder(super.a, super.b, + {Logic? carryIn, super.name = 'compound_adders'}) { if (a.width != b.width) { throw RohdHclException('inputs of a and b should have same width.'); } + if (carryIn != null) { + throw RohdHclException("we don't support carryIn"); + } addOutput('sum1', width: a.width + 1); } } @@ -31,7 +35,7 @@ abstract class CompoundAdder extends Adder { class TrivialCompoundAdder extends CompoundAdder { /// Constructs a [CompoundAdder]. TrivialCompoundAdder(super.a, super.b, - {super.name = 'trivial_compound_adder'}) { + {super.carryIn, super.name = 'trivial_compound_adder'}) { sum <= a.zeroExtend(a.width + 1) + b.zeroExtend(b.width + 1); sum1 <= sum + 1; } @@ -68,6 +72,7 @@ class CarrySelectCompoundAdder extends CompoundAdder { CarrySelectCompoundAdder(super.a, super.b, {Adder Function(Logic a, Logic b, {Logic? carryIn, String name}) adderGen = ParallelPrefixAdder.new, + super.carryIn, super.name = 'cs_compound_adder', List Function(int) widthGen = splitSelectAdderAlgorithmSingleBlock}) { diff --git a/lib/src/arithmetic/float_to_fixed.dart b/lib/src/arithmetic/float_to_fixed.dart index 4f3721ac0..799b3417f 100644 --- a/lib/src/arithmetic/float_to_fixed.dart +++ b/lib/src/arithmetic/float_to_fixed.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // float_to_fixed.dart @@ -42,7 +42,7 @@ class FloatToFixed extends Module { n = bias + float.mantissa.width - 1; final outputWidth = m + n + 1; - final jBit = Logic(name: 'jBit')..gets(float.isNormal()); + final jBit = Logic(name: 'jBit')..gets(float.isNormal); final shift = Logic(name: 'shift', width: float.exponent.width) ..gets( mux(jBit, float.exponent - 1, Const(0, width: float.exponent.width))); diff --git a/lib/src/arithmetic/floating_point/floating_point.dart b/lib/src/arithmetic/floating_point/floating_point.dart index 3b850d9b6..299087ba6 100644 --- a/lib/src/arithmetic/floating_point/floating_point.dart +++ b/lib/src/arithmetic/floating_point/floating_point.dart @@ -1,5 +1,8 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +export 'floating_point_adder.dart'; export 'floating_point_adder_round.dart'; export 'floating_point_adder_simple.dart'; +export 'floating_point_multiplier.dart'; +export 'floating_point_multiplier_simple.dart'; diff --git a/lib/src/arithmetic/floating_point/floating_point_adder.dart b/lib/src/arithmetic/floating_point/floating_point_adder.dart new file mode 100644 index 000000000..be070d533 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_adder.dart @@ -0,0 +1,86 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_adder.dart +// An abstract base class defining the API for floating-point adders. +// +// 2025 January 3 +// Author: Desmond A Kirkpatrick + ( + toSwap.$1.clone()..gets(mux(swap, toSwap.$2, toSwap.$1)), + toSwap.$2.clone()..gets(mux(swap, toSwap.$1, toSwap.$2)) + ); + + /// Pipelining helper that uses the context for signals clk/enable/reset + Logic localFlop(Logic input) => + condFlop(clk, input, en: enable, reset: reset); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_adder_round.dart b/lib/src/arithmetic/floating_point/floating_point_adder_round.dart index 2d4c090a6..00de61b18 100644 --- a/lib/src/arithmetic/floating_point/floating_point_adder_round.dart +++ b/lib/src/arithmetic/floating_point/floating_point_adder_round.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // floating_point_adder_round.dart @@ -8,113 +8,52 @@ // Author: Desmond A Kirkpatrick - (clk == null) - ? d - : flop( - clk, - d, - en: en, - reset: reset, - resetValue: resetValue, - ); - /// An adder module for variable FloatingPoint type with rounding. // This is a Seidel/Even adder, dual-path implementation. -class FloatingPointAdderRound extends Module { - /// Must be greater than 0. - final int exponentWidth; - - /// Must be greater than 0. - final int mantissaWidth; - - /// The [clk]: if a valid clock signal is passed in, a pipestage is added to - /// the adder to help optimize frequency. - Logic? clk; - - /// Optional [reset], used only if a [clk] is not null to reset the pipeline - /// flops. - Logic? reset; - - /// Optional [enable], used only if a [clk] is not null to enable the pipeline - /// flops. - Logic? enable; - - /// Output [FloatingPoint] representing the sum of two input [FloatingPoint]s - late final FloatingPoint sum = - FloatingPoint(exponentWidth: exponentWidth, mantissaWidth: mantissaWidth) - ..gets(output('sum')); - - /// The result of [FloatingPoint] addition - @protected - late final FloatingPoint _sum = - FloatingPoint(exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - - /// Swapping two FloatingPoint structures based on a conditional - static (FloatingPoint, FloatingPoint) _swap( - Logic swap, (FloatingPoint, FloatingPoint) toSwap) => - ( - toSwap.$1.clone()..gets(mux(swap, toSwap.$2, toSwap.$1)), - toSwap.$2.clone()..gets(mux(swap, toSwap.$1, toSwap.$2)) - ); - +class FloatingPointAdderRound extends FloatingPointAdder { /// Add two floating point numbers [a] and [b], returning result in [sum]. /// [subtract] is an optional Logic input to do subtraction /// [adderGen] is an adder generator to be used in the primary adder /// functions. /// [ppTree] is an ParallelPrefix generator for use in increment /decrement /// functions. - FloatingPointAdderRound(FloatingPoint a, FloatingPoint b, + FloatingPointAdderRound(super.a, super.b, {Logic? subtract, - this.clk, - this.reset, - this.enable, - Adder Function(Logic, Logic, {Logic? carryIn}) adderGen = - ParallelPrefixAdder.new, - ParallelPrefix Function(List, Logic Function(Logic, Logic)) + super.clk, + super.reset, + super.enable, + Adder Function(Logic a, Logic b, {Logic? carryIn}) adderGen = + NativeAdder.new, + ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op) ppTree = KoggeStone.new, - super.name = 'floating_point_adder_round'}) - : exponentWidth = a.exponent.width, - mantissaWidth = a.mantissa.width { - if (b.exponent.width != exponentWidth || - b.mantissa.width != mantissaWidth) { - throw RohdHclException('FloatingPoint widths must match'); - } - if (clk != null) { - clk = addInput('clk', clk!); - } - if (reset != null) { - reset = addInput('reset', reset!); - } - if (enable != null) { - enable = addInput('enable', enable!); - } - a = a.clone()..gets(addInput('a', a, width: a.width)); - b = b.clone()..gets(addInput('b', b, width: b.width)); - addOutput('sum', width: _sum.width) <= _sum; + super.name = 'floating_point_adder_round'}) { + final outputSum = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + output('sum') <= outputSum; + + // Ensure that the larger number is wired as 'a' + final ae = this.a.exponent; + final be = this.b.exponent; + final am = this.a.mantissa; + final bm = this.b.mantissa; + final doSwap = ae.lt(be) | + (ae.eq(be) & am.lt(bm)) | + ((ae.eq(be) & am.eq(bm)) & this.a.sign); + + final FloatingPoint a; + final FloatingPoint b; + (a, b) = swap(doSwap, (super.a, super.b)); + + // Seidel: S.EFF = effectiveSubtraction + final effectiveSubtraction = a.sign ^ b.sign ^ (subtract ?? Const(0)); + final isNaN = a.isNaN | + b.isNaN | + (a.isInfinity & b.isInfinity & effectiveSubtraction); + final isInf = a.isInfinity | b.isInfinity; final exponentSubtractor = OnesComplementAdder(a.exponent, b.exponent, subtract: true, adderGen: adderGen, name: 'exponent_sub'); @@ -123,20 +62,17 @@ class FloatingPointAdderRound extends Module { final delta = exponentSubtractor.sum; // Seidel: (sl, el, fl) = larger; (ss, es, fs) = smaller - final (larger, smaller) = _swap(signDelta, (a, b)); + final (larger, smaller) = swap(signDelta, (a, b)); final fl = mux( - larger.isNormal(), - [larger.isNormal(), larger.mantissa].swizzle(), + larger.isNormal, + [larger.isNormal, larger.mantissa].swizzle(), [larger.mantissa, Const(0)].swizzle()); final fs = mux( - smaller.isNormal(), - [smaller.isNormal(), smaller.mantissa].swizzle(), + smaller.isNormal, + [smaller.isNormal, smaller.mantissa].swizzle(), [smaller.mantissa, Const(0)].swizzle()); - // Seidel: S.EFF = effectiveSubtraction - final effectiveSubtraction = a.sign ^ b.sign ^ (subtract ?? Const(0)); - // Seidel: flp larger preshift, normally in [2,4) final sigWidth = fl.width + 1; final largeShift = mux(effectiveSubtraction, fl.zeroExtend(sigWidth) << 1, @@ -162,21 +98,16 @@ class FloatingPointAdderRound extends Module { smallerAlignRPath.width - largeOperand.width); /// R Pipestage here: - final aIsNormalLatched = - condFlop(clk, a.isNormal(), en: enable, reset: reset); - final bIsNormalLatched = - condFlop(clk, b.isNormal(), en: enable, reset: reset); - final effectiveSubtractionLatched = - condFlop(clk, effectiveSubtraction, en: enable, reset: reset); - final largeOperandLatched = - condFlop(clk, largeOperand, en: enable, reset: reset); - final smallerOperandRPathLatched = - condFlop(clk, smallerOperandRPath, en: enable, reset: reset); - final smallerAlignRPathLatched = - condFlop(clk, smallerAlignRPath, en: enable, reset: reset); - final largerExpLatched = - condFlop(clk, larger.exponent, en: enable, reset: reset); - final deltaLatched = condFlop(clk, delta, en: enable, reset: reset); + final aIsNormalLatched = localFlop(a.isNormal); + final bIsNormalLatched = localFlop(b.isNormal); + final effectiveSubtractionLatched = localFlop(effectiveSubtraction); + final largeOperandLatched = localFlop(largeOperand); + final smallerOperandRPathLatched = localFlop(smallerOperandRPath); + final smallerAlignRPathLatched = localFlop(smallerAlignRPath); + final largerExpLatched = localFlop(larger.exponent); + final deltaLatched = localFlop(delta); + final isInfLatched = localFlop(isInf); + final isNaNLatched = localFlop(isNaN); final carryRPath = Logic(); final significandAdderRPath = OnesComplementAdder( @@ -278,26 +209,30 @@ class FloatingPointAdderRound extends Module { final significandNPath = significandSubtractorNPath.sum.slice(smallOperandNPath.width - 1, 0); - final leadOneNPath = mux( - significandNPath.or(), - ParallelPrefixPriorityEncoder(significandNPath.reversed, - ppGen: ppTree, name: 'npath_leadingOne') - .out - .zeroExtend(exponentWidth), - Const(15, width: exponentWidth)); + final validLeadOneNPath = Logic(); + final leadOneNPathPre = ParallelPrefixPriorityEncoder( + significandNPath.reversed, + ppGen: ppTree, + valid: validLeadOneNPath, + name: 'npath_leadingOne') + .out; + // Limit leadOne to exponent range and match widths + final leadOneNPath = (leadOneNPathPre.width > exponentWidth) + ? mux( + leadOneNPathPre + .gte(a.inf().exponent.zeroExtend(leadOneNPathPre.width)), + a.inf().exponent, + leadOneNPathPre.getRange(0, exponentWidth)) + : leadOneNPathPre.zeroExtend(exponentWidth); // N pipestage here: - final significandNPathLatched = - condFlop(clk, significandNPath, en: enable, reset: reset); - final significandSubtractorNPathSignLatched = condFlop( - clk, significandSubtractorNPath.sign, - en: enable, reset: reset); - final leadOneNPathLatched = - condFlop(clk, leadOneNPath, en: enable, reset: reset); - final largerSignLatched = - condFlop(clk, larger.sign, en: enable, reset: reset); - final smallerSignLatched = - condFlop(clk, smaller.sign, en: enable, reset: reset); + final significandNPathLatched = localFlop(significandNPath); + final significandSubtractorNPathSignLatched = + localFlop(significandSubtractorNPath.sign); + final leadOneNPathLatched = localFlop(leadOneNPath); + final validLeadOneNPathLatched = localFlop(validLeadOneNPath); + final largerSignLatched = localFlop(larger.sign); + final smallerSignLatched = localFlop(smaller.sign); final expCalcNPath = OnesComplementAdder( largerExpLatched, leadOneNPathLatched.zeroExtend(exponentWidth), @@ -307,7 +242,8 @@ class FloatingPointAdderRound extends Module { final preExpNPath = expCalcNPath.sum.slice(exponentWidth - 1, 0); - final posExpNPath = preExpNPath.or() & ~expCalcNPath.sign; + final posExpNPath = + preExpNPath.or() & ~expCalcNPath.sign & validLeadOneNPathLatched; final exponentNPath = mux(posExpNPath, preExpNPath, zeroExp); @@ -330,14 +266,26 @@ class FloatingPointAdderRound extends Module { final isR = deltaLatched.gte(Const(2, width: delta.width)) | ~effectiveSubtractionLatched; - _sum <= - mux( - isR, - [ - largerSignLatched, - exponentRPath, - mantissaRPath.slice(mantissaRPath.width - 2, 1) - ].swizzle(), - [signNPath, exponentNPath, finalSignificandNPath].swizzle()); + + Combinational([ + If(isNaNLatched, then: [ + outputSum < outputSum.nan, + ], orElse: [ + If(isInfLatched, then: [ + outputSum < outputSum.inf(sign: largerSignLatched), + ], orElse: [ + If(isR, then: [ + outputSum.sign < largerSignLatched, + outputSum.exponent < exponentRPath, + outputSum.mantissa < + mantissaRPath.slice(mantissaRPath.width - 2, 1), + ], orElse: [ + outputSum.sign < signNPath, + outputSum.exponent < exponentNPath, + outputSum.mantissa < finalSignificandNPath, + ]) + ]) + ]) + ]); } } diff --git a/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart b/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart index 4051e219c..b91ae5eb2 100644 --- a/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart +++ b/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // floating_point_adder_simple.dart @@ -7,98 +7,114 @@ // 2024 August 30 // Author: Desmond A Kirkpatrick inps, Logic Function(Logic term1, Logic term2) op) + ppTree = KoggeStone.new, + super.name = 'floatingpoint_adder_simple'}) + : super() { + final outputSum = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + output('sum') <= outputSum; - /// Output [FloatingPoint] computed - late final FloatingPoint sum = - FloatingPoint(exponentWidth: exponentWidth, mantissaWidth: mantissaWidth) - ..gets(output('sum')); + // Ensure that the larger number is wired as 'a' + final ae = this.a.exponent; + final be = this.b.exponent; + final am = this.a.mantissa; + final bm = this.b.mantissa; + final doSwap = ae.lt(be) | + (ae.eq(be) & am.lt(bm)) | + ((ae.eq(be) & am.eq(bm)) & super.a.sign); + final FloatingPoint a; + final FloatingPoint b; + (a, b) = swap(doSwap, (super.a, super.b)); - /// The result of [FloatingPoint] addition - @protected - late final FloatingPoint _sum = - FloatingPoint(exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final isInf = a.isInfinity | b.isInfinity; + final isNaN = + a.isNaN | b.isNaN | (a.isInfinity & b.isInfinity & (a.sign ^ b.sign)); - /// Swapping two FloatingPoint structures based on a conditional - static (FloatingPoint, FloatingPoint) _swap( - Logic swap, (FloatingPoint, FloatingPoint) toSwap) => - ( - toSwap.$1.clone()..gets(mux(swap, toSwap.$2, toSwap.$1)), - toSwap.$2.clone()..gets(mux(swap, toSwap.$1, toSwap.$2)) - ); + // Align and add mantissas + final expDiff = a.exponent - b.exponent; + final aMantissa = mux( + a.isNormal, + [Const(1), a.mantissa, Const(0, width: mantissaWidth + 1)].swizzle(), + [a.mantissa, Const(0, width: mantissaWidth + 2)].swizzle()); + final bMantissa = mux( + b.isNormal, + [Const(1), b.mantissa, Const(0, width: mantissaWidth + 1)].swizzle(), + [b.mantissa, Const(0, width: mantissaWidth + 2)].swizzle()); - /// Add two floating point numbers [a] and [b], returning result in [sum] - FloatingPointAdderSimple(FloatingPoint a, FloatingPoint b, - {ParallelPrefix Function(List, Logic Function(Logic, Logic)) - ppGen = KoggeStone.new, - super.name = 'floatingpoint_adder_simple'}) - : exponentWidth = a.exponent.width, - mantissaWidth = a.mantissa.width { - if (b.exponent.width != exponentWidth || - b.mantissa.width != mantissaWidth) { - throw RohdHclException('FloatingPoint widths must match'); - } - a = a.clone()..gets(addInput('a', a, width: a.width)); - b = b.clone()..gets(addInput('b', b, width: b.width)); - addOutput('sum', width: _sum.width) <= _sum; + final adder = SignMagnitudeAdder( + a.sign, aMantissa, b.sign, bMantissa >>> expDiff, adderGen); - // Ensure that the larger number is wired as 'a' - final doSwap = a.exponent.lt(b.exponent) | - (a.exponent.eq(b.exponent) & a.mantissa.lt(b.mantissa)) | - ((a.exponent.eq(b.exponent) & a.mantissa.eq(b.mantissa)) & b.sign); + final intSum = adder.sum.slice(adder.sum.width - 1, 0); - (a, b) = _swap(doSwap, (a, b)); + final aSignLatched = localFlop(a.sign); + final aExpLatched = localFlop(a.exponent); + final sumLatched = localFlop(intSum); + final isInfLatched = localFlop(isInf); + final isNaNLatched = localFlop(isNaN); - final aExp = - a.exponent + mux(a.isNormal(), a.zeroExponent(), a.oneExponent()); - final bExp = - b.exponent + mux(b.isNormal(), b.zeroExponent(), b.oneExponent()); + final mantissa = + sumLatched.reversed.getRange(0, min(intSum.width, intSum.width)); + final leadOneValid = Logic(); + final leadOnePre = ParallelPrefixPriorityEncoder(mantissa, + ppGen: ppTree, valid: leadOneValid) + .out; + // Limit leadOne to exponent range and match widths + final infExponent = outputSum.inf(sign: aSignLatched).exponent; + final leadOne = (leadOnePre.width > exponentWidth) + ? mux(leadOnePre.gte(infExponent.zeroExtend(leadOnePre.width)), + infExponent, leadOnePre.getRange(0, exponentWidth)) + : leadOnePre.zeroExtend(exponentWidth); - // Align and add mantissas - final expDiff = aExp - bExp; - final adder = SignMagnitudeAdder( - a.sign, - [a.isNormal(), a.mantissa].swizzle(), - b.sign, - [b.isNormal(), b.mantissa].swizzle() >>> expDiff, - (a, b, {carryIn}) => - ParallelPrefixAdder(a, b, carryIn: carryIn, ppGen: ppGen)); + final leadOneDominates = leadOne.gt(aExpLatched) | ~leadOneValid; + final outExp = + mux(leadOneDominates, a.zeroExponent, aExpLatched - leadOne + 1); - final sum = adder.sum.slice(adder.sum.width - 2, 0); - final leadOneE = - ParallelPrefixPriorityEncoder(sum.reversed, ppGen: ppGen).out; - final leadOne = leadOneE.zeroExtend(exponentWidth); + final realIsInf = isInfLatched | outExp.eq(infExponent); - // Assemble the output FloatingPoint - _sum.sign <= adder.sign; Combinational([ If.block([ - Iff(adder.sum[-1] & a.sign.eq(b.sign), [ - _sum.mantissa < (sum >> 1).slice(mantissaWidth - 1, 0), - _sum.exponent < a.exponent + 1 + Iff(isNaNLatched, [ + outputSum < outputSum.nan, ]), - ElseIf(a.exponent.gt(leadOne) & sum.or(), [ - _sum.mantissa < (sum << leadOne).slice(mantissaWidth - 1, 0), - _sum.exponent < a.exponent - leadOne + ElseIf(realIsInf, [ + // ROHD 0.6.0 trace error if we use the following + outputSum < outputSum.inf(sign: aSignLatched), + // outputSum.sign < aSignLatched, + // outputSum.exponent < infExponent, + // outputSum.mantissa < Const(0, width: mantissaWidth, fill: true), ]), - ElseIf(leadOne.eq(0) & sum.or(), [ - _sum.mantissa < (sum << leadOne).slice(mantissaWidth - 1, 0), - _sum.exponent < a.exponent - leadOne + 1 + ElseIf(leadOneDominates, [ + outputSum.sign < aSignLatched, + outputSum.exponent < a.zeroExponent, + outputSum.mantissa < + (sumLatched << aExpLatched + 1) + .getRange(intSum.width - mantissaWidth, intSum.width), ]), Else([ - // subnormal result - _sum.mantissa < sum.slice(mantissaWidth - 1, 0), - _sum.exponent < _sum.zeroExponent() + outputSum.sign < aSignLatched, + outputSum.exponent < aExpLatched - leadOne + 1, + outputSum.mantissa < + (sumLatched << leadOne + 1) + .getRange(intSum.width - mantissaWidth, intSum.width), ]) ]) ]); diff --git a/lib/src/arithmetic/floating_point/floating_point_multiplier.dart b/lib/src/arithmetic/floating_point/floating_point_multiplier.dart new file mode 100644 index 000000000..2a3f07408 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_multiplier.dart @@ -0,0 +1,83 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_multiplier_simple.dart +// Implementation of non-rounding floating-point multiplier +// +// 2025 January 3 +// Author: Desmond A Kirkpatrick , Logic Function(Logic, Logic)) ppGen = + KoggeStone.new, + super.name = 'floating_point_multiplier'}) + : exponentWidth = a.exponent.width, + mantissaWidth = a.mantissa.width { + if (b.exponent.width != exponentWidth || + b.mantissa.width != mantissaWidth) { + throw RohdHclException('FloatingPoint widths must match'); + } + this.clk = (clk != null) ? addInput('clk', clk) : clk; + this.enable = (enable != null) ? addInput('enable', enable) : enable; + this.reset = (reset != null) ? addInput('clk', reset) : reset; + + this.a = a.clone()..gets(addInput('a', a, width: a.width)); + this.b = b.clone()..gets(addInput('b', b, width: b.width)); + addOutput('product', width: a.exponent.width + a.mantissa.width + 1); + } + + /// Pipelining helper that uses the context for signals clk/enable/reset + Logic localFlop(Logic input) => + condFlop(clk, input, en: enable, reset: reset); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_multiplier_simple.dart b/lib/src/arithmetic/floating_point/floating_point_multiplier_simple.dart new file mode 100644 index 000000000..ba5a7c722 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_multiplier_simple.dart @@ -0,0 +1,112 @@ +// Copyright (C) 2024-2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_multiplier_simple.dart +// Implementation of a non-rounding floating-point multiplier. +// +// 2024 December 30 +// Author: Desmond A Kirkpatrick inps, Logic Function(Logic term1, Logic term2) op) + ppTree = KoggeStone.new, + super.name}) { + final product = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + output('product') <= product; + final a = this.a; + final b = this.b; + + final aMantissa = mux(a.isNormal, [a.isNormal, a.mantissa].swizzle(), + [a.mantissa, Const(0)].swizzle()); + final bMantissa = mux(b.isNormal, [b.isNormal, b.mantissa].swizzle(), + [b.mantissa, Const(0)].swizzle()); + + final productExp = a.exponent.zeroExtend(exponentWidth + 2) + + b.exponent.zeroExtend(exponentWidth + 2) - + a.bias.zeroExtend(exponentWidth + 2); + + final pp = PartialProductGeneratorCompactRectSignExtension( + aMantissa, bMantissa, RadixEncoder(radix)); + final compressor = + ColumnCompressor(pp, clk: clk, reset: reset, enable: enable) + ..compress(); + final adder = adderGen(compressor.extractRow(0), compressor.extractRow(1)); + // Input mantissas have implicit lead: product mantissa width is (mw+1)*2) + final mantissa = adder.sum.getRange(0, (mantissaWidth + 1) * 2); + + final isInf = a.isInfinity | b.isInfinity; + final isNaN = a.isNaN | + b.isNaN | + ((a.isInfinity | b.isInfinity) & (a.isZero | b.isZero)); + + final productExpLatch = localFlop(productExp); + final aSignLatch = localFlop(a.sign); + final bSignLatch = localFlop(b.sign); + final isInfLatch = localFlop(isInf); + final isNaNLatch = localFlop(isNaN); + + final leadingOnePos = ParallelPrefixPriorityEncoder(mantissa.reversed, + ppGen: ppTree, name: 'leading_one_encoder') + .out + .zeroExtend(exponentWidth + 2); + + final shifter = SignedShifter( + mantissa, + mux(productExpLatch[-1] | productExpLatch.lt(leadingOnePos), + productExpLatch, leadingOnePos), + name: 'mantissa_shifter'); + + final remainingExp = productExpLatch - leadingOnePos + 1; + + final overFlow = isInfLatch | + (~remainingExp[-1] & + remainingExp.abs().gte(Const(1, width: exponentWidth, fill: true) + .zeroExtend(exponentWidth + 2))); + + Combinational([ + If(isNaNLatch, then: [ + product < product.nan, + ], orElse: [ + If(overFlow, then: [ + // TODO(desmonddak): use this line after trace issue is resolved + // product < product.inf(inSign: aSignLatch ^ bSignLatch), + product.sign < aSignLatch ^ bSignLatch, + product.exponent < product.nan.exponent, + product.mantissa < Const(0, width: mantissaWidth, fill: true), + ], orElse: [ + product.sign < aSignLatch ^ bSignLatch, + If(remainingExp[-1], then: [ + product.exponent < Const(0, width: exponentWidth) + ], orElse: [ + product.exponent < remainingExp.getRange(0, exponentWidth), + ]), + // Remove the leading one for implicit representation + product.mantissa < + shifter.shifted.getRange(-mantissaWidth - 1, mantissa.width - 1) + ]) + ]) + ]); + } +} diff --git a/lib/src/arithmetic/multiplier.dart b/lib/src/arithmetic/multiplier.dart index 88d903f16..2a6aca753 100644 --- a/lib/src/arithmetic/multiplier.dart +++ b/lib/src/arithmetic/multiplier.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // multiplier.dart @@ -204,7 +204,7 @@ class CompressionTreeMultiplier extends Multiplier { Logic get product => output('product'); /// Construct a compression tree integer multiplier with a given [radix] - /// and prefix tree functor [ppTree] for the compressor and final adder. + /// and an [Adder] generator functor [adderGen] for the final adder. /// /// Sign extension methodology is defined by the partial product generator /// supplied via [ppGen]. @@ -238,9 +238,9 @@ class CompressionTreeMultiplier extends Multiplier { super.signedMultiplier = false, super.selectSignedMultiplicand, super.selectSignedMultiplier, - ParallelPrefix Function(List, Logic Function(Logic, Logic)) - ppTree = KoggeStone.new, - PartialProductGenerator Function(Logic, Logic, RadixEncoder, + Adder Function(Logic a, Logic b, {Logic? carryIn}) adderGen = + NativeAdder.new, + PartialProductGenerator Function(Logic a, Logic b, RadixEncoder encoder, {required bool signedMultiplier, required bool signedMultiplicand, Logic? selectSignedMultiplier, @@ -265,9 +265,7 @@ class CompressionTreeMultiplier extends Multiplier { final compressor = ColumnCompressor(clk: clk, reset: reset, enable: enable, pp) ..compress(); - final adder = ParallelPrefixAdder( - compressor.extractRow(0), compressor.extractRow(1), - ppGen: ppTree); + final adder = adderGen(compressor.extractRow(0), compressor.extractRow(1)); product <= adder.sum.slice(a.width + b.width - 1, 0); } } @@ -291,7 +289,7 @@ class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { Logic get accumulate => output('accumulate'); /// Construct a compression tree integer multiply-add with a given [radix] - /// and prefix tree functor [ppTree] for the compressor and final adder. + /// and an [Adder] generator functor [adderGen] for the final adder. /// /// [a] and [b] are the product terms, [c] is the accumulate term which /// must be the sum of the widths plus 1. @@ -334,9 +332,9 @@ class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { super.selectSignedMultiplicand, super.selectSignedMultiplier, super.selectSignedAddend, - ParallelPrefix Function(List, Logic Function(Logic, Logic)) - ppTree = KoggeStone.new, - PartialProductGenerator Function(Logic, Logic, RadixEncoder, + Adder Function(Logic a, Logic b, {Logic? carryIn}) adderGen = + NativeAdder.new, + PartialProductGenerator Function(Logic a, Logic b, RadixEncoder encoder, {required bool signedMultiplier, required bool signedMultiplicand, Logic? selectSignedMultiplier, @@ -383,9 +381,7 @@ class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { final compressor = ColumnCompressor(clk: clk, reset: reset, enable: enable, pp) ..compress(); - final adder = ParallelPrefixAdder( - compressor.extractRow(0), compressor.extractRow(1), - ppGen: ppTree); + final adder = adderGen(compressor.extractRow(0), compressor.extractRow(1)); accumulate <= adder.sum.slice(a.width + b.width - 1 + 1, 0); } } diff --git a/lib/src/arithmetic/ones_complement_adder.dart b/lib/src/arithmetic/ones_complement_adder.dart index deec89ff6..09b918798 100644 --- a/lib/src/arithmetic/ones_complement_adder.dart +++ b/lib/src/arithmetic/ones_complement_adder.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // ones_complement_adder.dart @@ -25,17 +25,19 @@ class OnesComplementAdder extends Adder { @protected Logic _sign = Logic(); - /// [OnesComplementAdder] constructor with an adder functor [adderGen] - /// Either a Logic [subtractIn] or a boolean [subtract] can enable - /// subtraction, with [subtractIn] overriding [subtract]. If Logic [carryOut] - /// is provided as not null, then the end-around carry is not performed and is - /// left to the caller via the output [carryOut]. + /// [OnesComplementAdder] constructor with an adder functor [adderGen]. + /// - Either an optional Logic [subtractIn] or a boolean [subtract] can enable + /// subtraction, but providing both non-null will result in an exception. + /// - If Logic [carryOut] is provided as not null, then the end-around carry + /// is not performed and is provided as value on [carryOut]. + /// - [carryIn] allows for another adder to chain into this one. OnesComplementAdder(super.a, super.b, {Adder Function(Logic, Logic, {Logic? carryIn}) adderGen = ParallelPrefixAdder.new, Logic? subtractIn, Logic? carryOut, - bool subtract = false, + Logic? carryIn, + bool? subtract, super.name = 'ones_complement_adder'}) { if (subtractIn != null) { subtractIn = addInput('subtractIn', subtractIn); @@ -45,22 +47,25 @@ class OnesComplementAdder extends Adder { addOutput('carryOut'); carryOut <= this.carryOut!; } - if ((subtractIn != null) & subtract) { + if ((subtractIn != null) & (subtract != null)) { throw RohdHclException( - 'Subtraction is controlled by a non-null subtractIn: ' - 'subtract boolean is ignored'); + "either provide a Logic signal 'subtractIn' for runtime " + " configuration, or a boolean parameter 'subtract' for " + 'generation time configuration, but not both.'); } - final doSubtract = subtractIn ?? (subtract ? Const(1) : Const(0)); + final doSubtract = + subtractIn ?? (subtract != null ? Const(subtract) : Const(0)); final ax = a.zeroExtend(a.width); final bx = b.zeroExtend(b.width); - final adder = adderGen(ax, mux(doSubtract, ~bx, bx)); + final adder = + adderGen(ax, mux(doSubtract, ~bx, bx), carryIn: carryIn ?? Const(0)); if (this.carryOut != null) { this.carryOut! <= adder.sum[-1]; } - final endAround = mux(doSubtract, adder.sum[-1], Const(0)); + final endAround = adder.sum[-1]; final magnitude = adder.sum.slice(a.width - 1, 0); sum <= diff --git a/lib/src/arithmetic/parallel_prefix_operations.dart b/lib/src/arithmetic/parallel_prefix_operations.dart index bbbcdf0ee..12d60f3f8 100644 --- a/lib/src/arithmetic/parallel_prefix_operations.dart +++ b/lib/src/arithmetic/parallel_prefix_operations.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // parallel_prefix_operations.dart @@ -162,7 +162,8 @@ class ParallelPrefixOrScan extends Module { /// OrScan constructor ParallelPrefixOrScan(Logic inp, - {ParallelPrefix Function(List, Logic Function(Logic, Logic)) + {ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op) ppGen = KoggeStone.new, super.name = 'parallel_prefix_orscan'}) { inp = addInput('inp', inp, width: inp.width); @@ -179,7 +180,8 @@ class ParallelPrefixPriorityFinder extends Module { /// Priority Finder constructor ParallelPrefixPriorityFinder(Logic inp, - {ParallelPrefix Function(List, Logic Function(Logic, Logic)) + {ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op) ppGen = KoggeStone.new, super.name = 'parallel_prefix_finder'}) { inp = addInput('inp', inp, width: inp.width); @@ -190,19 +192,42 @@ class ParallelPrefixPriorityFinder extends Module { /// Priority Encoder based on ParallelPrefix tree class ParallelPrefixPriorityEncoder extends Module { - /// Output [out] is the bit position of the first '1' in the Logic input - /// Search is counted from the LSB + /// Output [out] is the bit position of the first '1' in the Logic input. + /// Search starts from the LSB. Logic get out => output('out'); + /// Optional output that says the encoded position is valid. + Logic? get valid => tryOutput('valid'); + /// PriorityEncoder constructor + /// - [ppGen] is the type of [ParallelPrefix] tree to use + /// - [valid] is an optional Logic output to raise if no '1' is found + /// + /// If there is a '1' in the [inp], the [ParallelPrefixPriorityEncoder] + /// sets [out] to the index of the position of the first '1' starting from + /// the LSb (and optionally sets [valid] to true). + /// + /// If there is no 1' in the [inp], it sets [out] to [inp].width + 1, + /// as well as setting optional [valid] to false. ParallelPrefixPriorityEncoder(Logic inp, - {ParallelPrefix Function(List, Logic Function(Logic, Logic)) + {ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op) ppGen = KoggeStone.new, + Logic? valid, super.name = 'parallel_prefix_encoder'}) { inp = addInput('inp', inp, width: inp.width); - addOutput('out', width: log2Ceil(inp.width)); + final sz = log2Ceil(inp.width + 1); + addOutput('out', width: sz); + if (valid != null) { + addOutput('valid'); + valid <= this.valid!; + } final u = ParallelPrefixPriorityFinder(inp, ppGen: ppGen); - out <= OneHotToBinary(u.out).binary; + final pos = OneHotToBinary(u.out).binary.zeroExtend(sz); + if (this.valid != null) { + this.valid! <= pos.or() | inp[0]; + } + out <= mux(pos.or() | inp[0], pos, Const(inp.width + 1, width: sz)); } } @@ -211,8 +236,9 @@ class ParallelPrefixAdder extends Adder { /// Adder constructor ParallelPrefixAdder(super.a, super.b, {super.carryIn, - ParallelPrefix Function(List, Logic Function(Logic, Logic)) ppGen = - KoggeStone.new, + ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op) + ppGen = KoggeStone.new, super.name = 'parallel_prefix_adder'}) { final l = List.generate(a.width - 1, (i) => [a[i + 1] & b[i + 1], a[i + 1] | b[i + 1]].swizzle()); @@ -243,7 +269,8 @@ class ParallelPrefixIncr extends Module { /// Increment constructor ParallelPrefixIncr(Logic inp, - {ParallelPrefix Function(List, Logic Function(Logic, Logic)) + {ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op) ppGen = KoggeStone.new, super.name = 'parallel_prefix_incr'}) { inp = addInput('inp', inp, width: inp.width); @@ -262,7 +289,8 @@ class ParallelPrefixDecr extends Module { /// Decrement constructor ParallelPrefixDecr(Logic inp, - {ParallelPrefix Function(List, Logic Function(Logic, Logic)) + {ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op) ppGen = KoggeStone.new, super.name = 'parallel_prefix_decr'}) { inp = addInput('inp', inp, width: inp.width); diff --git a/lib/src/arithmetic/signals/floating_point_logic.dart b/lib/src/arithmetic/signals/floating_point_logic.dart index 1d9d4f6cc..ebe33ae94 100644 --- a/lib/src/arithmetic/signals/floating_point_logic.dart +++ b/lib/src/arithmetic/signals/floating_point_logic.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // floating_point_logic.dart @@ -47,13 +47,53 @@ class FloatingPoint extends LogicStructure { /// Return a Logic true if this FloatingPoint contains a normal number, /// defined as having mantissa in the range [1,2) - Logic isNormal() => exponent.neq(LogicValue.zero.zeroExtend(exponent.width)); + late final Logic isNormal = Logic(name: 'isNormal', naming: Naming.mergeable) + ..gets(exponent.neq(LogicValue.zero.zeroExtend(exponent.width))); + + /// Return a Logic true if this FloatingPoint is Not a Number (NaN) + /// by having its exponent field set to the NaN value (typically all + /// ones) and a non-zero mantissa. + late final isNaN = Logic(name: 'isNaN', naming: Naming.mergeable) + ..gets(exponent.eq(floatingPointValue.nan.exponent) & mantissa.or()); + + /// Return a Logic true if this FloatingPoint is an infinity + /// by having its exponent field set to the NaN value (typically all + /// ones) and a zero mantissa. + late final isInfinity = Logic(name: 'isInfinity', naming: Naming.mergeable) + ..gets(exponent.eq(floatingPointValue.infinity.exponent) & ~mantissa.or()); + + /// Return a Logic true if this FloatingPoint is an zero + /// by having its exponent field set to the NaN value (typically all + /// ones) and a zero mantissa. + late final isZero = Logic(name: 'isZero', naming: Naming.mergeable) + ..gets(exponent.eq(floatingPointValue.zero.exponent) & ~mantissa.or()); /// Return the zero exponent representation for this type of FloatingPoint - Logic zeroExponent() => Const(LogicValue.zero).zeroExtend(exponent.width); + late final zeroExponent = Logic( + name: 'zeroExponent', naming: Naming.mergeable, width: exponent.width) + ..gets(Const(LogicValue.zero, width: exponent.width)); /// Return the one exponent representation for this type of FloatingPoint - Logic oneExponent() => Const(LogicValue.one).zeroExtend(exponent.width); + late final oneExponent = Logic( + name: 'oneExponent', naming: Naming.mergeable, width: exponent.width) + ..gets(Const(LogicValue.one, width: exponent.width)); + + /// Return the exponent Logic value representing the true zero exponent + /// 2^0 = 1 often termed [bias] or the offset of the stored exponent. + late final bias = + Logic(name: 'bias', naming: Naming.mergeable, width: exponent.width) + ..gets(Const((1 << exponent.width - 1) - 1, width: exponent.width)); + + /// Construct a FloatingPoint that represents infinity for this FP type. + FloatingPoint inf({Logic? sign, bool negative = false}) => FloatingPoint.inf( + exponentWidth: exponent.width, + mantissaWidth: mantissa.width, + sign: sign, + negative: negative); + + /// Construct a FloatingPoint that represents NaN for this FP type. + late final nan = FloatingPoint.nan( + exponentWidth: exponent.width, mantissaWidth: mantissa.width); @override void put(dynamic val, {bool fill = false}) { @@ -63,6 +103,27 @@ class FloatingPoint extends LogicStructure { super.put(val, fill: fill); } } + + /// Construct a FloatingPoint that represents infinity. + factory FloatingPoint.inf( + {required int exponentWidth, + required int mantissaWidth, + Logic? sign, + bool negative = false}) { + final signLogic = Logic()..gets(sign ?? Const(negative)); + final exponent = Const(1, width: exponentWidth, fill: true); + final mantissa = Const(0, width: mantissaWidth, fill: true); + return FloatingPoint._(signLogic, exponent, mantissa); + } + + /// Construct a FloatingPoint that represents NaN. + factory FloatingPoint.nan( + {required int exponentWidth, required int mantissaWidth}) { + final signLogic = Const(0); + final exponent = Const(1, width: exponentWidth, fill: true); + final mantissa = Const(1, width: mantissaWidth); + return FloatingPoint._(signLogic, exponent, mantissa); + } } /// Single floating point representation diff --git a/lib/src/arithmetic/values/floating_point_values/floating_point_8_value.dart b/lib/src/arithmetic/values/floating_point_values/floating_point_8_value.dart index 2d61f6d59..c8f5bd66f 100644 --- a/lib/src/arithmetic/values/floating_point_values/floating_point_8_value.dart +++ b/lib/src/arithmetic/values/floating_point_values/floating_point_8_value.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // floating_point_8_value.dart @@ -32,10 +32,15 @@ class FloatingPoint8E4M3Value extends FloatingPointValue { int get constrainedMantissaWidth => mantissaWidth; /// The maximum value representable by the E4M3 format - static double get maxValue => 448.toDouble(); + static double get maxValue => + FloatingPoint8E4M3Value.getFloatingPointConstant( + FloatingPointConstants.largestNormal) + .toDouble(); /// The minimum value representable by the E4M3 format - static double get minValue => pow(2, -9).toDouble(); + static double get minValue => FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.smallestPositiveSubnormal, 4, 3) + .toDouble(); /// Constructor for a double precision floating point value FloatingPoint8E4M3Value( @@ -70,6 +75,22 @@ class FloatingPoint8E4M3Value extends FloatingPointValue { : super.ofInts( exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + /// Inf is not representable in this format + @override + bool get isAnInfinity => false; + + @override + bool get isNaN => (exponent.toInt() == 15) && (mantissa.toInt() == 7); + + /// Override the toDouble to avoid NaN + @override + double toDouble() { + if (exponent.toInt() == 15) { + return 448; + } + return super.toDouble(); + } + /// Numeric conversion of a [FloatingPoint8E4M3Value] from a host double factory FloatingPoint8E4M3Value.ofDouble(double inDouble) { if ((inDouble.abs() > maxValue) | @@ -86,6 +107,28 @@ class FloatingPoint8E4M3Value extends FloatingPointValue { factory FloatingPoint8E4M3Value.ofLogicValue(LogicValue val) => FloatingPointValue.buildOfLogicValue( FloatingPoint8E4M3Value.new, exponentWidth, mantissaWidth, val); + + /// Return the [FloatingPointValue] representing the constant specified. + /// Special case for 8E4M3 type. + factory FloatingPoint8E4M3Value.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint) { + switch (constantFloatingPoint) { + /// Largest positive number, most positive exponent, full mantissa + case FloatingPointConstants.largestNormal: + return FloatingPoint8E4M3Value.ofBinaryStrings( + '0', '1' * exponentWidth, '${'1' * (mantissaWidth - 1)}0'); + case FloatingPointConstants.nan: + return FloatingPoint8E4M3Value.ofBinaryStrings( + '0', '${'1' * (exponentWidth - 1)}1', '1' * mantissaWidth); + case FloatingPointConstants.infinity: + case FloatingPointConstants.negativeInfinity: + throw RohdHclException('Infinity is not representable in this format'); + case _: + return FloatingPointValue.getFloatingPointConstant( + constantFloatingPoint, exponentWidth, mantissaWidth) + as FloatingPoint8E4M3Value; + } + } } /// The E5M2 representation of a 8-bit floating point value as defined in diff --git a/lib/src/arithmetic/values/floating_point_values/floating_point_value.dart b/lib/src/arithmetic/values/floating_point_values/floating_point_value.dart index 825758b34..1e1c1052f 100644 --- a/lib/src/arithmetic/values/floating_point_values/floating_point_value.dart +++ b/lib/src/arithmetic/values/floating_point_values/floating_point_value.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // floating_point_value.dart @@ -48,6 +48,9 @@ enum FloatingPointConstants { /// Largest possible number infinity, + + /// Not a Number, demarked by all 1s in exponent and any 1 in mantissa + nan, } /// IEEE Floating Point Rounding Modes @@ -299,6 +302,32 @@ class FloatingPointValue implements Comparable { mantissa: val.slice(mantissaWidth - 1, 0)); } + /// Abbreviation Functions for common constants + + /// Return the Infinity value for this FloatingPointValue size. + FloatingPointValue get infinity => + FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.infinity, exponent.width, mantissa.width); + + /// Return the Negative Infinity value for this FloatingPointValue size. + FloatingPointValue get negativeInfinity => + FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.negativeInfinity, + exponent.width, + mantissa.width); + + /// Return the Negative Infinity value for this FloatingPointValue size. + FloatingPointValue get nan => FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.nan, exponent.width, mantissa.width); + + /// Return the value one for this FloatingPointValue size. + FloatingPointValue get one => FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.one, exponent.width, mantissa.width); + + /// Return the Negative Infinity value for this FloatingPointValue size. + FloatingPointValue get zero => FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.positiveZero, exponent.width, mantissa.width); + /// Return the [FloatingPointValue] representing the constant specified factory FloatingPointValue.getFloatingPointConstant( FloatingPointConstants constantFloatingPoint, @@ -353,12 +382,17 @@ class FloatingPointValue implements Comparable { /// Largest positive number, most positive exponent, full mantissa case FloatingPointConstants.largestNormal: return FloatingPointValue.ofBinaryStrings( - '0', '0' * exponentWidth, '1' * mantissaWidth); + '0', '${'1' * (exponentWidth - 1)}0', '1' * mantissaWidth); /// Largest possible number case FloatingPointConstants.infinity: return FloatingPointValue.ofBinaryStrings( '0', '1' * exponentWidth, '0' * mantissaWidth); + + /// Not a Number (NaN) + case FloatingPointConstants.nan: + return FloatingPointValue.ofBinaryStrings( + '0', '1' * exponentWidth, '${'0' * (mantissaWidth - 1)}1'); } } @@ -375,6 +409,19 @@ class FloatingPointValue implements Comparable { return FloatingPoint64Value.ofDouble(inDouble); } + if (inDouble.isNaN) { + return FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.nan, exponentWidth, mantissaWidth); + } + if (inDouble.isInfinite) { + return FloatingPointValue.getFloatingPointConstant( + inDouble < 0.0 + ? FloatingPointConstants.negativeInfinity + : FloatingPointConstants.infinity, + exponentWidth, + mantissaWidth); + } + if (roundingMode != FloatingPointRoundingMode.roundNearestEven && roundingMode != FloatingPointRoundingMode.truncate) { throw UnimplementedError( @@ -457,16 +504,12 @@ class FloatingPointValue implements Comparable { } else if ((exponentWidth == 11) && (mantissaWidth == 52)) { return FloatingPoint64Value.ofDouble(inDouble); } - - var doubleVal = inDouble; if (inDouble.isNaN) { - return FloatingPointValue( - exponent: - LogicValue.ofInt(pow(2, exponentWidth).toInt() - 1, exponentWidth), - mantissa: LogicValue.zero, - sign: LogicValue.zero, - ); + return FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.nan, exponentWidth, mantissaWidth); } + + var doubleVal = inDouble; LogicValue sign; if (inDouble < 0.0) { doubleVal = -doubleVal; @@ -474,6 +517,14 @@ class FloatingPointValue implements Comparable { } else { sign = LogicValue.zero; } + if (inDouble.isInfinite) { + return FloatingPointValue.getFloatingPointConstant( + sign.toBool() + ? FloatingPointConstants.negativeInfinity + : FloatingPointConstants.infinity, + exponentWidth, + mantissaWidth); + } // If we are dealing with a really small number we need to scale it up var scaleToWhole = (doubleVal != 0) ? (-log(doubleVal) / log(2)).ceil() : 0; @@ -512,6 +563,15 @@ class FloatingPointValue implements Comparable { ? fullLength - mantissaWidth - scaleToWhole : FloatingPointValue.computeMinExponent(exponentWidth); + if (e > FloatingPointValue.computeMaxExponent(exponentWidth) + 1) { + return FloatingPointValue.getFloatingPointConstant( + sign.toBool() + ? FloatingPointConstants.negativeInfinity + : FloatingPointConstants.infinity, + exponentWidth, + mantissaWidth); + } + if (e <= -FloatingPointValue.computeBias(exponentWidth)) { fullValue = fullValue >>> (scaleToWhole - FloatingPointValue.computeBias(exponentWidth)); @@ -533,10 +593,7 @@ class FloatingPointValue implements Comparable { .reversed; return FloatingPointValue( - exponent: exponent, - mantissa: mantissa, - sign: sign, - ); + exponent: exponent, mantissa: mantissa, sign: sign); } @override @@ -566,47 +623,69 @@ class FloatingPointValue implements Comparable { return 0; } - /// Return the bias of this FP format - // int bias() => FloatingPointValue.computeBias(exponent.width); - @override bool operator ==(Object other) { if (other is! FloatingPointValue) { return false; } - if ((exponent.width != other.exponent.width) | (mantissa.width != other.mantissa.width)) { return false; } + if (isNaN != other.isNaN) { + return false; + } + if (isAnInfinity != other.isAnInfinity) { + return false; + } + if (isAnInfinity) { + return sign == other.sign; + } // IEEE 754: -0 an +0 are considered equal if ((exponent.isZero && mantissa.isZero) && (other.exponent.isZero && other.mantissa.isZero)) { return true; } - return (sign == other.sign) & (exponent == other.exponent) & (mantissa == other.mantissa); } - // TODO(desmonddak): figure out the difference with Infinity /// Return true if the represented floating point number is considered - /// NaN or 'Not a Number' due to overflow - bool isNaN() { - if ((exponent.width == 4) & (mantissa.width == 3)) { - // FP8 E4M3 does not support infinities - final cond1 = (1 + exponent.toInt()) == pow(2, exponent.width).toInt(); - final cond2 = (1 + mantissa.toInt()) == pow(2, mantissa.width).toInt(); - return cond1 & cond2; - } else { - return exponent.toInt() == - computeMaxExponent(exponent.width) + computeBias(exponent.width) + 1; - } - } + /// NaN or 'Not a Number' + bool get isNaN => + (exponent.toInt() == + computeMaxExponent(exponent.width) + + computeBias(exponent.width) + + 1) & + !mantissa.or().isZero; + + /// Return true if the represented floating point number is considered + /// infinity or negative infinity + bool get isAnInfinity => + (exponent.toInt() == + computeMaxExponent(exponent.width) + + computeBias(exponent.width) + + 1) & + mantissa.or().isZero; + + /// Return true if the represented floating point number is zero. Note + /// that the equality operator will treat + /// [FloatingPointConstants.positiveZero] + /// and [FloatingPointConstants.negativeZero] as equal. + bool get isZero => + this == + FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.positiveZero, exponent.width, mantissa.width); /// Return the value of the floating point number in a Dart [double] type. double toDouble() { + if (isNaN) { + return double.nan; + } + if (isAnInfinity) { + return sign.isZero ? double.infinity : double.negativeInfinity; + } var doubleVal = double.nan; if (value.isValid) { if (exponent.toInt() == 0) { @@ -614,7 +693,7 @@ class FloatingPointValue implements Comparable { pow(2.0, computeMinExponent(exponent.width)) * mantissa.toBigInt().toDouble() / pow(2.0, mantissa.width); - } else if (!isNaN()) { + } else if (!isNaN) { doubleVal = (sign.toBool() ? -1.0 : 1.0) * (1.0 + mantissa.toBigInt().toDouble() / pow(2.0, mantissa.width)) * pow( @@ -658,26 +737,84 @@ class FloatingPointValue implements Comparable { throw RohdHclException('FloatingPointValue: ' 'multiplicand must have the same mantissa and exponent widths'); } + if (isNaN | other.isNaN) { + return FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.nan, exponent.width, mantissa.width); + } return FloatingPointValue.ofDouble(op(toDouble(), other.toDouble()), mantissaWidth: mantissa.width, exponentWidth: exponent.width); } /// Multiply operation for [FloatingPointValue] - FloatingPointValue operator *(FloatingPointValue multiplicand) => - _performOp(multiplicand, (a, b) => a * b); + FloatingPointValue operator *(FloatingPointValue multiplicand) { + if (isAnInfinity) { + if (multiplicand.isAnInfinity) { + return sign != multiplicand.sign ? negativeInfinity : infinity; + } else if (multiplicand.isZero) { + return nan; + } else { + return this; + } + } else if (multiplicand.isAnInfinity) { + if (isZero) { + return nan; + } else { + return multiplicand; + } + } + return _performOp(multiplicand, (a, b) => a * b); + } /// Addition operation for [FloatingPointValue] - FloatingPointValue operator +(FloatingPointValue addend) => - _performOp(addend, (a, b) => a + b); + FloatingPointValue operator +(FloatingPointValue addend) { + if (isAnInfinity) { + if (addend.isAnInfinity) { + if (sign != addend.sign) { + return nan; + } else { + return sign.toBool() ? negativeInfinity : infinity; + } + } else { + return this; + } + } else if (addend.isAnInfinity) { + return addend; + } + return _performOp(addend, (a, b) => a + b); + } /// Divide operation for [FloatingPointValue] - FloatingPointValue operator /(FloatingPointValue divisor) => - _performOp(divisor, (a, b) => a / b); + FloatingPointValue operator /(FloatingPointValue divisor) { + if (isAnInfinity) { + if (divisor.isAnInfinity | divisor.isZero) { + return nan; + } else { + return this; + } + } else { + if (divisor.isZero) { + return sign != divisor.sign ? negativeInfinity : infinity; + } + } + return _performOp(divisor, (a, b) => a / b); + } /// Subtract operation for [FloatingPointValue] - FloatingPointValue operator -(FloatingPointValue subend) => - _performOp(subend, (a, b) => a - b); + FloatingPointValue operator -(FloatingPointValue subend) { + if (isAnInfinity & subend.isAnInfinity) { + if (sign == subend.sign) { + return nan; + } else { + return this; + } + } else if (subend.isAnInfinity) { + return subend.negate(); + } else if (isAnInfinity) { + return this; + } + return _performOp(subend, (a, b) => a - b); + } /// Negate operation for [FloatingPointValue] FloatingPointValue negate() => FloatingPointValue( diff --git a/lib/src/component_config/components/component_registry.dart b/lib/src/component_config/components/component_registry.dart index 81064fb50..83afc436e 100644 --- a/lib/src/component_config/components/component_registry.dart +++ b/lib/src/component_config/components/component_registry.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // component_registry.dart @@ -26,6 +26,8 @@ List get componentRegistry => [ EdgeDetectorConfigurator(), FindConfigurator(), FloatingPointAdderRoundConfigurator(), + FloatingPointAdderSimpleConfigurator(), + FloatingPointMultiplierSimpleConfigurator(), ParallelPrefixAdderConfigurator(), CompressionTreeMultiplierConfigurator(), ExtremaConfigurator(), diff --git a/lib/src/component_config/components/components.dart b/lib/src/component_config/components/components.dart index 22a438669..bfcb643ac 100644 --- a/lib/src/component_config/components/components.dart +++ b/lib/src/component_config/components/components.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause export 'config_carry_save_multiplier.dart'; @@ -13,6 +13,8 @@ export 'config_fixed_to_float.dart'; export 'config_float8_to_fixed.dart'; export 'config_float_to_fixed.dart'; export 'config_floating_point_adder_round.dart'; +export 'config_floating_point_adder_simple.dart'; +export 'config_floating_point_multiplier_simple.dart'; export 'config_one_hot.dart'; export 'config_parallel_prefix_adder.dart'; export 'config_priority_arbiter.dart'; diff --git a/lib/src/component_config/components/config_compound_adder.dart b/lib/src/component_config/components/config_compound_adder.dart index b4eb734f2..dadc08868 100644 --- a/lib/src/component_config/components/config_compound_adder.dart +++ b/lib/src/component_config/components/config_compound_adder.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // compound_adder.dart @@ -13,10 +13,35 @@ import 'package:rohd_hcl/rohd_hcl.dart'; /// A [Configurator] for [CompoundAdder]. class CompoundAdderConfigurator extends Configurator { + /// Map from Type to Function for Adder generator + static Map + adderGeneratorMap = { + Ripple: (a, b, {carryIn, name}) => + ParallelPrefixAdder(a, b, ppGen: Ripple.new, name: name!), + Sklansky: (a, b, {carryIn, name}) => + ParallelPrefixAdder(a, b, ppGen: Sklansky.new, name: name!), + KoggeStone: (a, b, {carryIn, name}) => + ParallelPrefixAdder(a, b, name: name!), + BrentKung: (a, b, {carryIn, name}) => + ParallelPrefixAdder(a, b, ppGen: BrentKung.new, name: name!), + NativeAdder: (a, b, {carryIn, name}) => + NativeAdder(a, b, carryIn: carryIn, name: name!) + }; + + /// Controls the type of [Adder] used for internal adders. + static final adderTypeKnob = + ChoiceConfigKnob(adderGeneratorMap.keys.toList(), value: NativeAdder); + /// Map from Type to Adder generator static Map generatorMap = { TrivialCompoundAdder: TrivialCompoundAdder.new, - CarrySelectCompoundAdder: CarrySelectCompoundAdder.new + CarrySelectCompoundAdder: (a, b, {Logic? carryIn}) => + CarrySelectCompoundAdder( + a, + b, + carryIn: carryIn, + adderGen: adderGeneratorMap[adderTypeKnob.value]!, + ) }; /// A knob controlling the width of the inputs to the adder. @@ -29,8 +54,11 @@ class CompoundAdderConfigurator extends Configurator { final String name = 'Compound Adder'; @override - late final Map> knobs = UnmodifiableMapView( - {'Width': logicWidthKnob, 'Adder Type': moduleTypeKnob}); + late final Map> knobs = UnmodifiableMapView({ + 'Width': logicWidthKnob, + 'Compound Adder Type': moduleTypeKnob, + 'Internal Adder Type': adderTypeKnob, + }); @override Module createModule() => generatorMap[moduleTypeKnob.value]!( diff --git a/lib/src/component_config/components/config_compression_tree_multiplier.dart b/lib/src/component_config/components/config_compression_tree_multiplier.dart index 5b9e77649..400a594ce 100644 --- a/lib/src/component_config/components/config_compression_tree_multiplier.dart +++ b/lib/src/component_config/components/config_compression_tree_multiplier.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // config_compression_tree_multiplier.dart @@ -14,26 +14,28 @@ import 'package:rohd_hcl/rohd_hcl.dart'; /// A [Configurator] for [CompressionTreeMultiplier]s. class CompressionTreeMultiplierConfigurator extends Configurator { - /// Map from Type to Function for Parallel Prefix generator - static Map, Logic Function(Logic, Logic))> - generatorMap = { - Ripple: Ripple.new, - Sklansky: Sklansky.new, - KoggeStone: KoggeStone.new, - BrentKung: BrentKung.new + /// Map from Type to Function for Adder generator + static Map + adderGeneratorMap = { + Ripple: (a, b, {carryIn}) => ParallelPrefixAdder(a, b, ppGen: Ripple.new), + Sklansky: (a, b, {carryIn}) => + ParallelPrefixAdder(a, b, ppGen: Sklansky.new), + KoggeStone: ParallelPrefixAdder.new, + BrentKung: (a, b, {carryIn}) => + ParallelPrefixAdder(a, b, ppGen: BrentKung.new), + NativeAdder: (a, b, {carryIn}) => NativeAdder(a, b, carryIn: carryIn) }; - /// Controls the type of [ParallelPrefix] tree used in the adder. - final prefixTreeKnob = - ChoiceConfigKnob(generatorMap.keys.toList(), value: KoggeStone); - /// Controls the Booth encoding radix of the multiplier.! final radixKnob = ChoiceConfigKnob( [2, 4, 8, 16], value: 4, ); + /// Controls the type of [Adder] used for internal adders. + final adderTypeKnob = + ChoiceConfigKnob(adderGeneratorMap.keys.toList(), value: NativeAdder); + /// Controls the width of the multiplicand.! final IntConfigKnob multiplicandWidthKnob = IntConfigKnob(value: 5); @@ -63,11 +65,11 @@ class CompressionTreeMultiplierConfigurator extends Configurator { signMultiplicandValueKnob.value == 'selected' ? Logic() : null, selectSignedMultiplier: signMultiplierValueKnob.value == 'selected' ? Logic() : null, - ppTree: generatorMap[prefixTreeKnob.value]!); + adderGen: adderGeneratorMap[adderTypeKnob.value]!); @override late final Map> knobs = UnmodifiableMapView({ - 'Tree type': prefixTreeKnob, + 'Adder type': adderTypeKnob, 'Radix': radixKnob, 'Multiplicand width': multiplicandWidthKnob, 'Multiplicand sign': signMultiplicandValueKnob, diff --git a/lib/src/component_config/components/config_floating_point_adder_round.dart b/lib/src/component_config/components/config_floating_point_adder_round.dart index 2ca5a5283..6ad431662 100644 --- a/lib/src/component_config/components/config_floating_point_adder_round.dart +++ b/lib/src/component_config/components/config_floating_point_adder_round.dart @@ -1,8 +1,8 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // -// config_floating_point_adder.dart -// Configurator for a Floating-Point Adder. +// config_floating_point_adder_round.dart +// Configurator for a rounding Floating-Point adder. // // 2024 October 11 // Author: Desmond Kirkpatrick @@ -22,12 +22,15 @@ class FloatingPointAdderRoundConfigurator extends Configurator { ParallelPrefixAdder(a, b, ppGen: Sklansky.new), KoggeStone: ParallelPrefixAdder.new, BrentKung: (a, b, {carryIn}) => - ParallelPrefixAdder(a, b, ppGen: BrentKung.new) + ParallelPrefixAdder(a, b, ppGen: BrentKung.new), + NativeAdder: (a, b, {carryIn}) => NativeAdder(a, b, carryIn: carryIn) }; /// Map from Type to Function for Parallel Prefix generator - static Map, Logic Function(Logic, Logic))> + static Map< + Type, + ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op)> treeGeneratorMap = { Ripple: Ripple.new, Sklansky: Sklansky.new, @@ -35,9 +38,9 @@ class FloatingPointAdderRoundConfigurator extends Configurator { BrentKung: BrentKung.new }; - /// Controls the type of [ParallelPrefix] tree used in internal adders. - final adderTreeKnob = - ChoiceConfigKnob(adderGeneratorMap.keys.toList(), value: KoggeStone); + /// Controls the type of [Adder] used for internal adders. + final adderTypeKnob = + ChoiceConfigKnob(adderGeneratorMap.keys.toList(), value: NativeAdder); /// Controls the type of [ParallelPrefix] tree used in the other functions. final prefixTreeKnob = @@ -62,13 +65,13 @@ class FloatingPointAdderRoundConfigurator extends Configurator { FloatingPoint( exponentWidth: exponentWidthKnob.value, mantissaWidth: mantissaWidthKnob.value), - adderGen: adderGeneratorMap[adderTreeKnob.value]!, + adderGen: adderGeneratorMap[adderTypeKnob.value]!, ppTree: treeGeneratorMap[prefixTreeKnob.value]!); @override late final Map> knobs = UnmodifiableMapView({ 'Prefix tree type': prefixTreeKnob, - 'Adder tree type': adderTreeKnob, + 'Adder tree type': adderTypeKnob, 'Exponent width': exponentWidthKnob, 'Mantissa width': mantissaWidthKnob, 'Pipelined': pipelinedKnob, diff --git a/lib/src/component_config/components/config_floating_point_adder_simple.dart b/lib/src/component_config/components/config_floating_point_adder_simple.dart new file mode 100644 index 000000000..c6f8de393 --- /dev/null +++ b/lib/src/component_config/components/config_floating_point_adder_simple.dart @@ -0,0 +1,82 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// config_floating_point_adder_simple.dart +// Configurator for a simple Floating-Point adder. +// +// 2025 January 9 +// Author: Desmond Kirkpatrick + +import 'dart:collection'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A [Configurator] for [FloatingPointAdderSimple]s. +class FloatingPointAdderSimpleConfigurator extends Configurator { + /// Map from Type to Function for Adder generator + static Map + adderGeneratorMap = { + Ripple: (a, b, {carryIn}) => ParallelPrefixAdder(a, b, ppGen: Ripple.new), + Sklansky: (a, b, {carryIn}) => + ParallelPrefixAdder(a, b, ppGen: Sklansky.new), + KoggeStone: ParallelPrefixAdder.new, + BrentKung: (a, b, {carryIn}) => + ParallelPrefixAdder(a, b, ppGen: BrentKung.new), + NativeAdder: (a, b, {carryIn}) => NativeAdder(a, b, carryIn: carryIn) + }; + + /// Map from Type to Function for Parallel Prefix generator + static Map< + Type, + ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op)> + treeGeneratorMap = { + Ripple: Ripple.new, + Sklansky: Sklansky.new, + KoggeStone: KoggeStone.new, + BrentKung: BrentKung.new + }; + + /// Controls the type of [Adder] used for internal adders. + final adderTypeKnob = + ChoiceConfigKnob(adderGeneratorMap.keys.toList(), value: NativeAdder); + + /// Controls the type of [ParallelPrefix] tree used in the other functions. + final prefixTreeKnob = + ChoiceConfigKnob(treeGeneratorMap.keys.toList(), value: KoggeStone); + + /// Controls the width of the exponent. + final IntConfigKnob exponentWidthKnob = IntConfigKnob(value: 4); + + /// Controls the width of the mantissa. + final IntConfigKnob mantissaWidthKnob = IntConfigKnob(value: 5); + + /// Controls whether the adder is pipelined + final ToggleConfigKnob pipelinedKnob = ToggleConfigKnob(value: false); + + @override + Module createModule() => FloatingPointAdderSimple( + clk: pipelinedKnob.value ? Logic() : null, + FloatingPoint( + exponentWidth: exponentWidthKnob.value, + mantissaWidth: mantissaWidthKnob.value, + ), + FloatingPoint( + exponentWidth: exponentWidthKnob.value, + mantissaWidth: mantissaWidthKnob.value), + adderGen: adderGeneratorMap[adderTypeKnob.value]!, + ppTree: treeGeneratorMap[prefixTreeKnob.value]!); + + @override + late final Map> knobs = UnmodifiableMapView({ + 'Prefix tree type': prefixTreeKnob, + 'Adder tree type': adderTypeKnob, + 'Exponent width': exponentWidthKnob, + 'Mantissa width': mantissaWidthKnob, + 'Pipelined': pipelinedKnob, + }); + + @override + final String name = 'Floating-Point Simple Adder'; +} diff --git a/lib/src/component_config/components/config_floating_point_multiplier_simple.dart b/lib/src/component_config/components/config_floating_point_multiplier_simple.dart new file mode 100644 index 000000000..63620ca6c --- /dev/null +++ b/lib/src/component_config/components/config_floating_point_multiplier_simple.dart @@ -0,0 +1,82 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// config_floating_point_multiplier_simple.dart +// Configurator for a simple Floating-Point multiplier. +// +// 2025 January 6 +// Author: Desmond Kirkpatrick + +import 'dart:collection'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A [Configurator] for [FloatingPointMultiplierSimple]s. +class FloatingPointMultiplierSimpleConfigurator extends Configurator { + /// Map from Type to Function for Adder generator + static Map + adderGeneratorMap = { + Ripple: (a, b, {carryIn}) => ParallelPrefixAdder(a, b, ppGen: Ripple.new), + Sklansky: (a, b, {carryIn}) => + ParallelPrefixAdder(a, b, ppGen: Sklansky.new), + KoggeStone: ParallelPrefixAdder.new, + BrentKung: (a, b, {carryIn}) => + ParallelPrefixAdder(a, b, ppGen: BrentKung.new), + NativeAdder: (a, b, {carryIn}) => NativeAdder(a, b, carryIn: carryIn) + }; + + /// Map from Type to Function for Parallel Prefix generator + static Map< + Type, + ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op)> + treeGeneratorMap = { + Ripple: Ripple.new, + Sklansky: Sklansky.new, + KoggeStone: KoggeStone.new, + BrentKung: BrentKung.new + }; + + /// Controls the type of [Adder] used for internal adders. + final adderTypeKnob = + ChoiceConfigKnob(adderGeneratorMap.keys.toList(), value: NativeAdder); + + /// Controls the type of [ParallelPrefix] tree used in the internal functions. + final prefixTreeKnob = + ChoiceConfigKnob(treeGeneratorMap.keys.toList(), value: KoggeStone); + + /// Controls the width of the exponent. + final IntConfigKnob exponentWidthKnob = IntConfigKnob(value: 4); + + /// Controls the width of the mantissa. + final IntConfigKnob mantissaWidthKnob = IntConfigKnob(value: 5); + + /// Controls whether the multiplier is pipelined + final ToggleConfigKnob pipelinedKnob = ToggleConfigKnob(value: false); + + @override + Module createModule() => FloatingPointMultiplierSimple( + clk: pipelinedKnob.value ? Logic() : null, + FloatingPoint( + exponentWidth: exponentWidthKnob.value, + mantissaWidth: mantissaWidthKnob.value, + ), + FloatingPoint( + exponentWidth: exponentWidthKnob.value, + mantissaWidth: mantissaWidthKnob.value), + adderGen: adderGeneratorMap[adderTypeKnob.value]!, + ppTree: treeGeneratorMap[prefixTreeKnob.value]!); + + @override + late final Map> knobs = UnmodifiableMapView({ + 'Adder type': adderTypeKnob, + 'Prefix tree type': prefixTreeKnob, + 'Exponent width': exponentWidthKnob, + 'Mantissa width': mantissaWidthKnob, + 'Pipelined': pipelinedKnob, + }); + + @override + final String name = 'Floating-Point Simple Multiplier'; +} diff --git a/lib/src/component_config/components/config_parallel_prefix_adder.dart b/lib/src/component_config/components/config_parallel_prefix_adder.dart index 5a8998e22..cd30d987b 100644 --- a/lib/src/component_config/components/config_parallel_prefix_adder.dart +++ b/lib/src/component_config/components/config_parallel_prefix_adder.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // config_parallel-prefix_adder.dart @@ -15,8 +15,10 @@ import 'package:rohd_hcl/rohd_hcl.dart'; /// A [Configurator] for [ParallelPrefixAdder]s. class ParallelPrefixAdderConfigurator extends Configurator { /// Map from Type to Function for Parallel Prefix generator - static Map, Logic Function(Logic, Logic))> + static Map< + Type, + ParallelPrefix Function( + List inps, Logic Function(Logic term1, Logic term2) op)> generatorMap = { Ripple: Ripple.new, Sklansky: Sklansky.new, diff --git a/lib/src/signed_shifter.dart b/lib/src/signed_shifter.dart new file mode 100644 index 000000000..8f3820e99 --- /dev/null +++ b/lib/src/signed_shifter.dart @@ -0,0 +1,27 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// signed_shifter.dart +// Implementation of bidirectional shifter. +// +// 2025 January 8 +// Author: Desmond Kirkpatrick + +import 'package:rohd/rohd.dart'; + +/// A bit shifter that takes a positive or negative shift amount +class SignedShifter extends Module { + /// The output [shifted] bits + Logic get shifted => output('shifted'); + + /// Create a [SignedShifter] that treats shift as signed + /// - [bits] is the input to be shifted + /// - [shift] is the signed amount to be shifted + SignedShifter(Logic bits, Logic shift, {super.name = 'shifter'}) { + bits = addInput('bits', bits, width: bits.width); + shift = addInput('shift', shift, width: shift.width); + + addOutput('shifted', width: bits.width); + shifted <= mux(shift[-1], bits >>> shift.abs(), bits << shift); + } +} diff --git a/lib/src/utils.dart b/lib/src/utils.dart index 102149e64..585cae441 100644 --- a/lib/src/utils.dart +++ b/lib/src/utils.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // utils.dart @@ -63,3 +63,34 @@ extension SignedBigInt on BigInt { ? BigInt.from(value).toSigned(width) : BigInt.from(value).toUnsigned(width); } + +/// Conditionally constructs a positive edge triggered flip condFlop on [clk]. +/// +/// It returns either [FlipFlop.q] if [clk] is non-null or [d] if not. +/// +/// When the optional [en] is provided, an additional input will be created for +/// condFlop. If optional [en] is high or not provided, output will vary as per +/// input[d]. For low [en], output remains frozen irrespective of input [d]. +/// +/// - When the optional [reset] is provided, the condFlop will be reset +/// (active-high). +/// - If no [resetValue] is provided, the reset value is always `0`. Otherwise, +/// it will reset to the provided [resetValue]. +/// - If [asyncReset] is true, the [reset] signal (if provided) will be treated +/// as an async reset. If [asyncReset] is false, the reset signal will be +/// treated as synchronous. +Logic condFlop( + Logic? clk, + Logic d, { + Logic? en, + Logic? reset, + dynamic resetValue, + bool asyncReset = false, +}) => + (clk == null) + ? d + : flop(clk, d, + en: en, + reset: reset, + resetValue: resetValue, + asyncReset: asyncReset); diff --git a/test/arithmetic/adder_test.dart b/test/arithmetic/adder_test.dart index 3aad74554..1b8c157b0 100644 --- a/test/arithmetic/adder_test.dart +++ b/test/arithmetic/adder_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // adder_test.dart @@ -194,10 +194,17 @@ void main() { final generators = [Ripple.new, Sklansky.new, KoggeStone.new, BrentKung.new]; + final adders = [ + RippleCarryAdder.new, + NativeAdder.new, + ]; + group('adder random', () { for (final n in [63, 64, 65]) { for (final testCin in [false, true]) { - testAdderRandom(n, 30, RippleCarryAdder.new, testCarryIn: testCin); + for (final adder in adders) { + testAdderRandom(n, 30, adder, testCarryIn: testCin); + } for (final ppGen in generators) { testAdderRandom( n, @@ -311,6 +318,22 @@ void main() { } }); + test('ones complement subtractor', () { + const width = 5; + final a = Logic(width: width); + final b = Logic(width: width); + + const subtract = true; + const av = 1; + const bv = 6; + + a.put(av); + b.put(bv); + final adder = OnesComplementAdder(a, b, subtract: subtract); + expect(adder.sum.value.toInt(), equals(bv - av)); + expect(adder.sign.value, LogicValue.one); + }); + test('ones complement with Logic subtract', () { const width = 2; final a = Logic(width: width); diff --git a/test/arithmetic/compound_adder_test.dart b/test/arithmetic/compound_adder_test.dart index bcf71aebb..1fef91d59 100644 --- a/test/arithmetic/compound_adder_test.dart +++ b/test/arithmetic/compound_adder_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // compound_adder_test.dart diff --git a/test/arithmetic/fixed_to_float_test.dart b/test/arithmetic/fixed_to_float_test.dart index 3f5c73fcf..8564a9174 100644 --- a/test/arithmetic/fixed_to_float_test.dart +++ b/test/arithmetic/fixed_to_float_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // fixed_to_float_test.dart @@ -139,8 +139,8 @@ void main() async { } }); - // Test is skipped as FloatingPointValue.ofDouble does not handle infinities. - // TODO(desmonddak): + // TODO(desmonddak): complete this test as now + // FloatingPointValue.ofDouble handles infinities. test('Signed Q7.0 to E3M2', () async { final fixed = FixedPoint(signed: true, m: 7, n: 0); final dut = FixedToFloat(fixed, exponentWidth: 3, mantissaWidth: 2); diff --git a/test/arithmetic/float_to_fixed_test.dart b/test/arithmetic/float_to_fixed_test.dart index 239147fb5..aad8008ba 100644 --- a/test/arithmetic/float_to_fixed_test.dart +++ b/test/arithmetic/float_to_fixed_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // float_to_fixed_test.dart @@ -20,7 +20,7 @@ void main() async { for (var val = 0; val < pow(2, 8); val++) { final fpv = FloatingPointValue.ofLogicValue( 5, 2, LogicValue.ofInt(val, float.width)); - if (!fpv.isNaN()) { + if (!fpv.isAnInfinity & !fpv.isNaN) { float.put(fpv); final fxp = dut.fixed; final fxpExp = FixedPointValue.ofDouble(fpv.toDouble(), @@ -41,7 +41,7 @@ void main() async { for (var val = 0; val < pow(2, 8); val++) { final fp8 = FloatingPointValue.ofLogicValue( 4, 3, LogicValue.ofInt(val, float.width)); - if (!fp8.isNaN()) { + if (!fp8.isNaN & !fp8.isAnInfinity) { float.put(fp8.value); final fx8 = FixedPointValue.ofDouble(fp8.toDouble(), signed: true, m: 23, n: 9); @@ -55,7 +55,7 @@ void main() async { for (var val = 0; val < pow(2, 8); val++) { final fp8 = FloatingPointValue.ofLogicValue( 5, 2, LogicValue.ofInt(val, float.width)); - if (!fp8.isNaN()) { + if (!fp8.isNaN & !fp8.isAnInfinity) { float.put(fp8.value); final fx8 = FixedPointValue.ofDouble(fp8.toDouble(), signed: true, m: 16, n: 16); diff --git a/test/arithmetic/floating_point/floating_point_adder_round_test.dart b/test/arithmetic/floating_point/floating_point_adder_round_test.dart index e454a4ae6..dc962351a 100644 --- a/test/arithmetic/floating_point/floating_point_adder_round_test.dart +++ b/test/arithmetic/floating_point/floating_point_adder_round_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // floating_point_rnd_test.dart @@ -17,78 +17,77 @@ void main() { tearDown(() async { await Simulator.reset(); }); - test('FP: singleton N path', () async { - final clk = SimpleClockGenerator(10).clk; - - const eWidth = 4; - const mWidth = 5; - final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - - final fva = FloatingPointValue.ofInts(14, 31, - exponentWidth: eWidth, mantissaWidth: mWidth); - final fvb = FloatingPointValue.ofInts(13, 7, - exponentWidth: eWidth, mantissaWidth: mWidth, sign: true); - - fa.put(fva); - fb.put(fvb); + test('FP: rounding adder singleton N path', () async { + const exponentWidth = 4; + const mantissawidth = 5; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissawidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissawidth); + + final fv1 = FloatingPointValue.ofInts(14, 31, + exponentWidth: exponentWidth, mantissaWidth: mantissawidth); + final fv2 = FloatingPointValue.ofInts(13, 7, + exponentWidth: exponentWidth, mantissaWidth: mantissawidth, sign: true); + + fp1.put(fv1); + fp2.put(fv2); final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( - fva.toDouble() + fvb.toDouble(), - exponentWidth: eWidth, - mantissaWidth: mWidth); + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissawidth); final expected = expectedNoRound; - final adder = FloatingPointAdderRound(fa, fb, clk: clk); + final adder = FloatingPointAdderRound(fp1, fp2); unawaited(Simulator.run()); - await clk.nextNegedge; - fa.put(0); - fb.put(0); final computed = adder.sum.floatingPointValue; - expect(computed.isNaN(), equals(expected.isNaN())); expect(computed, equals(expected)); await Simulator.endSimulation(); }); - test('FP: N path, subtraction, delta < 2', () async { - const eWidth = 3; - const mWidth = 5; + test('FP: rounding adder N path, subtraction, delta < 2', () async { + const exponentWidth = 3; + const mantissaWidth = 5; final one = FloatingPointValue.getFloatingPointConstant( - FloatingPointConstants.one, eWidth, mWidth); - final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - fa.put(one); - fb.put(one); - final adder = FloatingPointAdderRound(fa, fb); + FloatingPointConstants.one, exponentWidth, mantissaWidth); + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(one); + fp2.put(one); + final adder = FloatingPointAdderRound(fp1, fp2); await adder.build(); unawaited(Simulator.run()); - final largestExponent = FloatingPointValue.computeBias(eWidth) + - FloatingPointValue.computeMaxExponent(eWidth); - final largestMantissa = pow(2, mWidth).toInt() - 1; - for (var i = 0; i <= largestExponent; i++) { - for (var j = 0; j <= largestExponent; j++) { - if ((i - j).abs() < 2) { - for (var ii = 0; ii <= largestMantissa; ii++) { - for (var jj = 0; jj <= largestMantissa; jj++) { - final fva = FloatingPointValue.ofInts(i, ii, - exponentWidth: eWidth, mantissaWidth: mWidth); - final fvb = FloatingPointValue.ofInts(j, jj, - exponentWidth: eWidth, mantissaWidth: mWidth, sign: true); - - fa.put(fva); - fb.put(fvb); + final largestExponent = FloatingPointValue.computeBias(exponentWidth) + + FloatingPointValue.computeMaxExponent(exponentWidth); + final largestMantissa = pow(2, mantissaWidth).toInt() - 1; + for (var e1 = 0; e1 <= largestExponent; e1++) { + for (var e2 = 0; e2 <= largestExponent; e2++) { + if ((e1 - e2).abs() < 2) { + for (var m1 = 0; m1 <= largestMantissa; m1++) { + final fv1 = FloatingPointValue.ofInts(e1, m1, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + for (var m2 = 0; m2 <= largestMantissa; m2++) { + final fv2 = FloatingPointValue.ofInts(e2, m2, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + sign: true); + + fp1.put(fv1); + fp2.put(fv2); // No rounding final expected = FloatingPointValue.ofDoubleUnrounded( - fva.toDouble() + fvb.toDouble(), - exponentWidth: eWidth, - mantissaWidth: mWidth); + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); final computed = adder.sum.floatingPointValue; - expect(computed.isNaN(), equals(expected.isNaN())); expect(computed, equals(expected)); } } @@ -98,68 +97,78 @@ void main() { await Simulator.endSimulation(); }); - test('FP: singleton R path', () async { + test('FP: rounding adder singleton R path', () async { final clk = SimpleClockGenerator(10).clk; - const eWidth = 4; - const mWidth = 5; - final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - fa.put(0); - fb.put(0); - - final fva = FloatingPointValue.ofInts(3, 11, - exponentWidth: eWidth, mantissaWidth: mWidth); - final fvb = FloatingPointValue.ofInts(11, 25, - exponentWidth: eWidth, mantissaWidth: mWidth, sign: true); - - fa.put(fva); - fb.put(fvb); - - final expected = fva + fvb; - final adder = FloatingPointAdderRound(clk: clk, fa, fb); + const exponentWidth = 4; + const mantissaWidth = 5; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + + final fv1 = FloatingPointValue.ofInts(3, 11, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.ofInts(11, 25, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth, sign: true); + + fp1.put(fv1); + fp2.put(fv2); + + final expected = fv1 + fv2; + final adder = FloatingPointAdderRound(clk: clk, fp1, fp2); await adder.build(); unawaited(Simulator.run()); await clk.nextNegedge; - fa.put(0); - fb.put(0); + fp1.put(0); + fp2.put(0); final computed = adder.sum.floatingPointValue; - expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed.isNaN, equals(expected.isNaN)); expect(computed, equals(expected)); await Simulator.endSimulation(); }); - test('FP: R path, strict subnormal', () async { - const eWidth = 4; - const mWidth = 5; - - final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - fa.put(0); - fb.put(0); - final adder = FloatingPointAdderRound(fa, fb); + test('FP: rounding adder R path, strict subnormal', () async { + const exponentWidth = 4; + const mantissaWidth = 5; + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final adder = FloatingPointAdderRound(fp1, fp2); await adder.build(); unawaited(Simulator.run()); - final largestMantissa = pow(2, mWidth).toInt() - 1; - for (final sign in [false]) { - for (var i = 0; i <= 1; i++) { - for (var j = 0; j <= 1; j++) { - if (!sign || (i - j).abs() >= 2) { - for (var ii = 0; ii <= largestMantissa; ii++) { - for (var jj = 0; jj <= largestMantissa; jj++) { - final fva = FloatingPointValue.ofInts(i, ii, - exponentWidth: eWidth, mantissaWidth: mWidth); - final fvb = FloatingPointValue.ofInts(j, jj, - exponentWidth: eWidth, mantissaWidth: mWidth, sign: sign); - - fa.put(fva); - fb.put(fvb); - final expected = fva + fvb; + final largestMantissa = pow(2, mantissaWidth).toInt() - 1; + for (final sign in [false, true]) { + for (var e1 = 0; e1 <= 1; e1++) { + for (var e2 = 0; e2 <= 1; e2++) { + if (!sign || (e1 - e2).abs() >= 2) { + for (var m1 = 0; m1 <= largestMantissa; m1++) { + final fv1 = FloatingPointValue.ofInts(e1, m1, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + for (var m2 = 0; m2 <= largestMantissa; m2++) { + final fv2 = FloatingPointValue.ofInts(e2, m2, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + sign: sign); + + fp1.put(fv1); + fp2.put(fv2); + final expected = fv1 + fv2; final computed = adder.sum.floatingPointValue; - expect(computed.isNaN(), equals(expected.isNaN())); - expect(computed, equals(expected)); + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); } } } @@ -169,112 +178,333 @@ void main() { await Simulator.endSimulation(); }); - test('FP: R path, full random', () async { + test('FP: rounding adder R path, full random', () async { final clk = SimpleClockGenerator(10).clk; - const eWidth = 3; - const mWidth = 5; + const exponentWidth = 3; + const mantissaWidth = 5; - final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - fa.put(0); - fb.put(0); - final adder = FloatingPointAdderRound(clk: clk, fa, fb); + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final adder = FloatingPointAdderRound(clk: clk, fp1, fp2); await adder.build(); unawaited(Simulator.run()); final value = Random(47); var cnt = 200; while (cnt > 0) { - final fva = FloatingPointValue.random(value, - exponentWidth: eWidth, mantissaWidth: mWidth); - final fvb = FloatingPointValue.random(value, - exponentWidth: eWidth, mantissaWidth: mWidth); - fa.put(fva); - fb.put(fvb); - if ((fva.exponent.toInt() - fvb.exponent.toInt()).abs() >= 2) { + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(fv1); + fp2.put(fv2); + if ((fv1.exponent.toInt() - fv2.exponent.toInt()).abs() >= 2) { cnt--; - final expected = fva + fvb; + final expected = fv1 + fv2; await clk.nextNegedge; - fa.put(0); - fb.put(0); + fp1.put(0); + fp2.put(0); final computed = adder.sum.floatingPointValue; - expect(computed.isNaN(), equals(expected.isNaN())); - expect(computed, equals(expected)); + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); } } await Simulator.endSimulation(); }); - test('FP: singleton merged path', () async { + test('FP: rounding adder singleton merged pipelined path', () async { final clk = SimpleClockGenerator(10).clk; - const eWidth = 3; - const mWidth = 5; - final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - fa.put(0); - fb.put(0); - final fva = FloatingPointValue.ofInts(14, 31, - exponentWidth: eWidth, mantissaWidth: mWidth); - final fvb = FloatingPointValue.ofInts(13, 7, - exponentWidth: eWidth, mantissaWidth: mWidth, sign: true); - fa.put(fva); - fb.put(fvb); + const exponentWidth = 3; + const mantissaWidth = 5; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final fv1 = FloatingPointValue.ofInts(14, 31, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.ofInts(13, 7, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth, sign: true); + fp1.put(fv1); + fp2.put(fv2); final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( - fva.toDouble() + fvb.toDouble(), - exponentWidth: eWidth, - mantissaWidth: mWidth); + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); final FloatingPointValue expected; - final expectedRound = fva + fvb; - if (((fva.exponent.toInt() - fvb.exponent.toInt()).abs() < 2) & - (fva.sign.toInt() != fvb.sign.toInt())) { + final expectedRound = fv1 + fv2; + if (((fv1.exponent.toInt() - fv2.exponent.toInt()).abs() < 2) & + (fv1.sign.toInt() != fv2.sign.toInt())) { expected = expectedNoRound; } else { expected = expectedRound; } - final adder = FloatingPointAdderRound(clk: clk, fa, fb); + final adder = FloatingPointAdderRound(clk: clk, fp1, fp2); await adder.build(); unawaited(Simulator.run()); await clk.nextNegedge; - fa.put(0); - fb.put(0); + fp1.put(0); + fp2.put(0); final computed = adder.sum.floatingPointValue; - expect(computed.isNaN(), equals(expected.isNaN())); - expect(computed, equals(expected)); + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); await Simulator.endSimulation(); }); - test('FP: full random wide', () async { - const eWidth = 11; - const mWidth = 52; - - final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - fa.put(0); - fb.put(0); - final adder = FloatingPointAdderRound(fa, fb); + test('FP: rounding adder full random wide', () async { + const exponentWidth = 11; + const mantissaWidth = 52; + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final adder = FloatingPointAdderRound(fp1, fp2); await adder.build(); unawaited(Simulator.run()); final value = Random(51); var cnt = 100; while (cnt > 0) { - final fva = FloatingPointValue.random(value, - exponentWidth: eWidth, mantissaWidth: mWidth); - final fvb = FloatingPointValue.random(value, - exponentWidth: eWidth, mantissaWidth: mWidth); - fa.put(fva); - fb.put(fvb); - final expected = fva + fvb; + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(fv1); + fp2.put(fv2); + final expected = fv1 + fv2; final computed = adder.sum.floatingPointValue; - expect(computed.isNaN(), equals(expected.isNaN())); - expect(computed, equals(expected)); + expect(computed.isNaN, equals(expected.isNaN)); + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); cnt--; } await Simulator.endSimulation(); }); + + test('FP: rounding adder singleton merged path', () async { + const exponentWidth = 3; + const mantissaWidth = 5; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final fv1 = FloatingPointValue.ofInts(14, 31, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.ofInts(13, 7, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth, sign: true); + fp1.put(fv1); + fp2.put(fv2); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + final FloatingPointValue expected; + final expectedRound = fv1 + fv2; + if (((fv1.exponent.toInt() - fv2.exponent.toInt()).abs() < 2) & + (fv1.sign.toInt() != fv2.sign.toInt())) { + expected = expectedNoRound; + } else { + expected = expectedRound; + } + final adder = FloatingPointAdderRound(fp1, fp2); + + final computed = adder.sum.floatingPointValue; + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + }); + + test('FP: rounding adder singleton', () async { + const exponentWidth = 4; + const mantissaWidth = 4; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final fv1 = FloatingPointValue.ofBinaryStrings('0', '1100', '0000'); + final fv2 = FloatingPointValue.ofBinaryStrings('1', '1100', '0000'); + + fp1.put(fv1); + fp2.put(fv2); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + final FloatingPointValue expected; + final expectedRound = fv1 + fv2; + if (((fv1.exponent.toInt() - fv2.exponent.toInt()).abs() < 2) & + (fv1.sign.toInt() != fv2.sign.toInt())) { + expected = expectedNoRound; + } else { + expected = expectedRound; + } + final adder = FloatingPointAdderRound(fp1, fp2); + + final computed = adder.sum.floatingPointValue; + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + }); + + test('FP: rounding adder exhaustive', () { + const exponentWidth = 4; + const mantissaWidth = 4; + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final adder = FloatingPointAdderRound(fp1, fp2); + + final expLimit = pow(2, exponentWidth); + final mantLimit = pow(2, mantissaWidth); + for (final subtract in [0, 1]) { + for (var e1 = 0; e1 < expLimit; e1++) { + for (var m1 = 0; m1 < mantLimit; m1++) { + final fv1 = FloatingPointValue.ofInts(e1, m1, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + for (var e2 = 0; e2 < expLimit; e2++) { + for (var m2 = 0; m2 < mantLimit; m2++) { + final fv2 = FloatingPointValue.ofInts(e2, m2, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + sign: subtract == 1); + + fp1.put(fv1.value); + fp2.put(fv2.value); + final computed = adder.sum.floatingPointValue; + final expectedDouble = fv1.toDouble() + fv2.toDouble(); + + final FloatingPointValue expected; + if ((subtract == 1) & + ((fv1.exponent.toInt() - fv2.exponent.toInt()).abs() < 2)) { + expected = FloatingPointValue.ofDoubleUnrounded(expectedDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + } else { + expected = FloatingPointValue.ofDouble(expectedDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + } + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + } + } + } + } + } + }); + test('FP: rounding adder general singleton test', () { + FloatingPointValue ofString(String s) => + FloatingPointValue.ofSpacedBinaryString(s); + + final fv1 = ofString('0 001 111111'); + final fv2 = ofString('1 010 000000'); + + final fp1 = FloatingPoint( + exponentWidth: fv1.exponent.width, mantissaWidth: fv1.mantissa.width); + final fp2 = FloatingPoint( + exponentWidth: fv2.exponent.width, mantissaWidth: fv2.mantissa.width); + fp1.put(fv1); + fp2.put(fv2); + final adder = FloatingPointAdderRound(fp1, fp2); + final exponentWidth = adder.sum.exponent.width; + final mantissaWidth = adder.sum.mantissa.width; + + final expectedDouble = + fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded(expectedDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + expect(adder.sum.floatingPointValue, equals(expectedNoRound)); + }); + test('FP: rounding with prefix adder', () async { + final clk = SimpleClockGenerator(10).clk; + + const eWidth = 3; + const mWidth = 5; + final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(0); + fb.put(0); + final fv1 = FloatingPointValue.ofInts(14, 31, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fv2 = FloatingPointValue.ofInts(13, 7, + exponentWidth: eWidth, mantissaWidth: mWidth, sign: true); + fa.put(fv1); + fb.put(fv2); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: eWidth, + mantissaWidth: mWidth); + + final FloatingPointValue expected; + final expectedRound = fv1 + fv2; + if (((fv1.exponent.toInt() - fv2.exponent.toInt()).abs() < 2) & + (fv1.sign.toInt() != fv2.sign.toInt())) { + expected = expectedNoRound; + } else { + expected = expectedRound; + } + final adder = FloatingPointAdderRound( + clk: clk, fa, fb, adderGen: ParallelPrefixAdder.new); + await adder.build(); + unawaited(Simulator.run()); + await clk.nextNegedge; + fa.put(0); + fb.put(0); + + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN, equals(expected.isNaN)); + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + await Simulator.endSimulation(); + }); } diff --git a/test/arithmetic/floating_point/floating_point_adder_simple_test.dart b/test/arithmetic/floating_point/floating_point_adder_simple_test.dart index c5b0f17ee..4ce667333 100644 --- a/test/arithmetic/floating_point/floating_point_adder_simple_test.dart +++ b/test/arithmetic/floating_point/floating_point_adder_simple_test.dart @@ -1,7 +1,7 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // -// floating_point_smple test.dart +// floating_point_simple test.dart // Tests of FloatingPointAdderSimple -- non-rounding FP adder // // 2024 April 1 @@ -10,309 +10,391 @@ // Desmond A Kirkpatrick + FloatingPointValue.ofSpacedBinaryString(s); + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + test('FP: simple adder narrow corner tests', () { + final testCases = [ + (ofString('0 0001 0000'), ofString('0 0000 0000')), + // subnormal from ae=1 s1=1, chop + (ofString('0 0000 0001'), ofString('1 0001 0000')), + // ae=0, l1=0 -- don't chop the leading digit + (ofString('0 0000 0000'), ofString('1 0000 1000')), + // requires unrounded comparison + (ofString('0 0000 0001'), ofString('1 0010 0010')), + // fix for shifting by l1 + (ofString('0 0000 0010'), ofString('1 0010 0000')), + // circle back ae=1 l1=1, shift, do not chop + (ofString('0 0000 0001'), ofString('1 0001 0000')), + // Large exponent difference requires rounding? + (ofString('0 0000 0001'), ofString('1 0111 0000')), + // This one wants no rounding + (ofString('0 0000 0001'), ofString('1 0011 0000')), + // wants rounding + (ofString('0 0000 0001'), ofString('1 0101 0000')), + // here a=7, l1=0, we need to add 1 + (ofString('0 0111 0000'), ofString('0 0111 0000')), + // Needs a shift of 1 when ae = 0 and l1 > ae and subnormal + (ofString('0 0000 0000'), ofString('0 0000 0001')), + // needs to shift 1 more and add to exponent a = 0 l1=0 when adding + (ofString('0 0000 0010'), ofString('0 0000 1110')), + // counterexample to adding 1 to exponent a = 0 l1=14 + (ofString('0 0000 0000'), ofString('0 0000 0000')), + //another counterexample: adding 1 to many to exp + (ofString('0 0000 0001'), ofString('0 0000 0001')), + // catastrophic cancellation + (ofString('0 1100 0000'), ofString('0 1100 0000')), + ]; + final adder = FloatingPointAdderSimple(fp1, fp2); - test('FP: addersmall numbers test', () { - final val = FloatingPoint32Value.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveSubnormal) - .toDouble(); - final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveSubnormal) - .value); - final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveSubnormal) - .negate() - .value); - final out = FloatingPoint32Value.ofDouble(val - val); + for (final test in testCases) { + final fv1 = test.$1; + final fv2 = test.$2; + fp1.put(fv1.value); + fp2.put(fv2.value); + final expectedDouble = fp1.floatingPointValue.toDouble() + + fp2.floatingPointValue.toDouble(); + + final expectedRound = FloatingPointValue.ofDouble(expectedDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + expectedDouble, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + final expected = expectedNoRound; + + final computed = adder.sum.floatingPointValue; + if ((computed != expectedNoRound) && (computed != expectedRound)) { + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expectedNoRound (${expectedNoRound.toDouble()})\texpected +'''); + } + } + }); + test('FP: simple adder narrow singleton test', () { + fp1.put(ofString('0 1100 0000')); + fp2.put(ofString('1 1100 0000')); + final adder = FloatingPointAdderSimple(fp1, fp2); - final adder = FloatingPointAdderSimple(fp1, fp2); + final expectedDouble = + fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + expectedDouble, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + expect(adder.sum.floatingPointValue, equals(expectedNoRound)); + }); + test('FP: simple adder singleton pipelined path', () async { + final clk = SimpleClockGenerator(10).clk; + fp1.put(ofString('0 0000 0000')); + fp2.put(ofString('0 0001 0000')); + + final expectedDouble = + fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + expectedDouble, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + final FloatingPointValue expected; + expected = expectedNoRound; + final adder = FloatingPointAdderSimple(clk: clk, fp1, fp2); + await adder.build(); + unawaited(Simulator.run()); + await clk.nextNegedge; + fp1.put(0); + fp2.put(0); - final fpSuper = adder.sum.floatingPointValue; - final fpStr = fpSuper.toDouble().abs().toStringAsPrecision(7); - final valStr = out.toDouble().toStringAsPrecision(7); - expect(fpStr, valStr); + final computed = adder.sum.floatingPointValue; + expect(computed, equals(expected)); + await Simulator.endSimulation(); + }); + + test('FP: adder simple pipeline random', () async { + final clk = SimpleClockGenerator(10).clk; + + final adder = FloatingPointAdderSimple(clk: clk, fp1, fp2); + await adder.build(); + unawaited(Simulator.run()); + + final value = Random(513); + + for (var i = 0; i < 500; i++) { + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + normal: true); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + normal: true); + + fp1.put(fv1.value); + fp2.put(fv2.value); + await clk.nextNegedge; + fp1.put(0); + fp2.put(0); + + final computed = adder.sum.floatingPointValue; + + final expectedRound = FloatingPointValue.ofDouble( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + if ((computed != expectedNoRound) & (computed != expectedRound)) { + expect(computed, equals(expectedNoRound), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expectedNoRound (${expectedNoRound.toDouble()})\texpected +'''); + } + } + await Simulator.endSimulation(); + }); }); - test('FP: adder carry numbers test', () { - final val = pow(2.5, -12).toDouble(); - final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pow(2.5, -12).toDouble()).value); - final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pow(2.5, -12).toDouble()).value); - final out = FloatingPoint32Value.ofDouble(val + val); + test('FP: adder simple wide mantissa random', () async { + const exponentWidth = 2; + const mantissaWidth = 20; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); final adder = FloatingPointAdderSimple(fp1, fp2); + await adder.build(); + unawaited(Simulator.run()); - final fpSuper = adder.sum.floatingPointValue; - final fpStr = fpSuper.toDouble().toStringAsPrecision(7); - final valStr = out.toDouble().toStringAsPrecision(7); - expect(fpStr, valStr); - }); - - test('FP: adder basic loop test', () { - final input = [(3.25, 1.5), (4.5, 3.75)]; - - for (final pair in input) { - final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pair.$1).value); - final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pair.$2).value); - final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); - - final adder = FloatingPointAdderSimple(fp1, fp2); - - final fpSuper = adder.sum.floatingPointValue; - final fpStr = fpSuper.toDouble().toStringAsPrecision(7); - final valStr = out.toDouble().toStringAsPrecision(7); - expect(fpStr, valStr); - } - }); + final value = Random(513); -// if you name two tests the same they get run together -// RippleCarryAdder: cannot access inputs from outside -- super.a issue - test('FP: adder basic loop test - negative numbers', () { - final input = [(4.5, 3.75), (9.0, -3.75), (-9.0, 3.9375), (-3.9375, 9.0)]; + for (var i = 0; i < 500; i++) { + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - for (final pair in input) { - final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pair.$1).value); - final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pair.$2).value); - final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); + fp1.put(fv1.value); + fp2.put(fv2.value); - final adder = FloatingPointAdderSimple(fp1, fp2); + final computed = adder.sum.floatingPointValue; - final fpSuper = adder.sum.floatingPointValue; - final fpStr = fpSuper.toDouble().toStringAsPrecision(7); - final valStr = out.toDouble().toStringAsPrecision(7); - expect(fpStr, valStr); + final expectedRound = FloatingPointValue.ofDouble( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + if ((computed != expectedNoRound) & (computed != expectedRound)) { + expect(computed, equals(expectedNoRound), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expectedNoRound (${expectedNoRound.toDouble()})\texpected +'''); + } } }); - test('FP: adder basic subnormal test', () { - final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveNormal) - .value); - final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveSubnormal) - .negate() - .value); - - final out = FloatingPoint32Value.ofDouble( - fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble()); - final adder = FloatingPointAdderSimple(fp1, fp2); - - final fpSuper = adder.sum.floatingPointValue; - final fpStr = fpSuper.toDouble().toStringAsPrecision(7); - final valStr = out.toDouble().toStringAsPrecision(7); - expect(fpStr, valStr); - }); - - test('FP: tiny subnormal test', () { - const ew = 4; - const mw = 4; - final fp1 = FloatingPoint(exponentWidth: ew, mantissaWidth: mw) - ..put(FloatingPointValue.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveNormal, ew, mw) - .value); - final fp2 = FloatingPoint(exponentWidth: ew, mantissaWidth: mw) - ..put(FloatingPointValue.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveSubnormal, ew, mw) - .negate() - .value); - - final outDouble = - fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); - final out = FloatingPointValue.ofDoubleUnrounded(outDouble, - exponentWidth: ew, mantissaWidth: mw); - final adder = FloatingPointAdderSimple(fp1, fp2); - - expect(adder.sum.floatingPointValue.compareTo(out), 0); - }); + test('FP: adder simple wide exponent random', () async { + const exponentWidth = 10; + const mantissaWidth = 2; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - test('FP: addernegative number requiring a carryOut', () { - const pair = (9.0, -3.75); - const ew = 3; - const mw = 5; - - final fp1 = FloatingPoint(exponentWidth: ew, mantissaWidth: mw) - ..put(FloatingPointValue.ofDouble(pair.$1, - exponentWidth: ew, mantissaWidth: mw) - .value); - final fp2 = FloatingPoint(exponentWidth: ew, mantissaWidth: mw) - ..put(FloatingPointValue.ofDouble(pair.$2, - exponentWidth: ew, mantissaWidth: mw) - .value); - - final out = FloatingPointValue.ofDouble(pair.$1 + pair.$2, - exponentWidth: ew, mantissaWidth: mw); final adder = FloatingPointAdderSimple(fp1, fp2); + await adder.build(); - expect(adder.sum.floatingPointValue.compareTo(out), 0); - }); - - test('FP: adder subnormal cancellation', () { - const ew = 4; - const mw = 4; - final fp1 = FloatingPoint(exponentWidth: ew, mantissaWidth: mw) - ..put(FloatingPointValue.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveSubnormal, ew, mw) - .negate() - .value); - final fp2 = FloatingPoint(exponentWidth: ew, mantissaWidth: mw) - ..put(FloatingPointValue.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveSubnormal, ew, mw) - .value); - - final out = fp2.floatingPointValue + fp1.floatingPointValue; - - final adder = FloatingPointAdderSimple(fp1, fp2); - expect(adder.sum.floatingPointValue.abs().compareTo(out), 0); - }); + final value = Random(513); - test('FP: adder adder basic loop adder test2', () { - final input = [(4.5, 3.75), (9.0, -3.75), (-9.0, 3.9375), (-3.9375, 9.0)]; + for (var i = 0; i < 500; i++) { + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - for (final pair in input) { - final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pair.$1).value); - final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pair.$2).value); - final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); + fp1.put(fv1.value); + fp2.put(fv2.value); - final adder = FloatingPointAdderSimple(fp1, fp2); + final computed = adder.sum.floatingPointValue; - final fpSuper = adder.sum.floatingPointValue; - final fpStr = fpSuper.toDouble().toStringAsPrecision(7); - final valStr = out.toDouble().toStringAsPrecision(7); - expect(fpStr, valStr); + final expectedRound = FloatingPointValue.ofDouble( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() + fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + if ((computed != expectedNoRound) & (computed != expectedRound)) { + expect(computed, equals(expectedNoRound), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expectedNoRound (${expectedNoRound.toDouble()})\texpected +'''); + } } }); - test('FP: adder singleton', () { - const pair = (9.0, -3.75); - { - final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pair.$1).value); - final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.ofDouble(pair.$2).value); - final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); + test('FP: simple adder general singleton test', () { + FloatingPointValue ofString(String s) => + FloatingPointValue.ofSpacedBinaryString(s); + + final fv1 = ofString('0 001 111111'); + final fv2 = ofString('1 010 000000'); + + final fp1 = FloatingPoint( + exponentWidth: fv1.exponent.width, mantissaWidth: fv1.mantissa.width); + final fp2 = FloatingPoint( + exponentWidth: fv2.exponent.width, mantissaWidth: fv2.mantissa.width); + fp1.put(fv1); + fp2.put(fv2); + final adder = FloatingPointAdderSimple(fp1, fp2); + final exponentWidth = adder.sum.exponent.width; + final mantissaWidth = adder.sum.mantissa.width; - final adder = FloatingPointAdderSimple(fp1, fp2); + final expectedDouble = + fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); - final fpSuper = adder.sum.floatingPointValue; - final fpStr = fpSuper.toDouble().toStringAsPrecision(7); - final valStr = out.toDouble().toStringAsPrecision(7); - expect(fpStr, valStr); - } - }); - test('FP: adder random', () { - const eWidth = 5; - const mWidth = 20; - - final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); - final fpv = FloatingPointValue.ofInts(0, 0, - exponentWidth: eWidth, mantissaWidth: mWidth); - final smallest = FloatingPointValue.getFloatingPointConstant( - FloatingPointConstants.smallestPositiveNormal, eWidth, mWidth); - fa.put(0); - fb.put(0); - final adder = FloatingPointAdderSimple(fa, fb); - final value = Random(513); - for (var i = 0; i < 50; i++) { - final fva = FloatingPointValue.random(value, - exponentWidth: eWidth, mantissaWidth: mWidth, normal: true); - final fvb = FloatingPointValue.random(value, - exponentWidth: eWidth, mantissaWidth: mWidth, normal: true); - fa.put(fva); - fb.put(fvb); - // fromDoubleIter does not round like '+' would - final expected = FloatingPointValue.ofDoubleUnrounded( - fva.toDouble() + fvb.toDouble(), - exponentWidth: fpv.exponent.width, - mantissaWidth: fpv.mantissa.width); - final computed = adder.sum.floatingPointValue; - final ulp = FloatingPointValue.ofInts( - max(expected.exponent.toInt(), 1), 1, - exponentWidth: eWidth, mantissaWidth: mWidth); - final diff = (expected.toDouble() - computed.toDouble()).abs(); - if (expected.isNormal()) { - expect(expected.isNaN(), equals(computed.isNaN())); - if (!expected.isNaN()) { - expect(diff, lessThan(ulp.toDouble() * smallest.toDouble())); - } - } - } + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded(expectedDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + expect(adder.sum.floatingPointValue, equals(expectedNoRound)); }); } diff --git a/test/arithmetic/floating_point/floating_point_adder_test.dart b/test/arithmetic/floating_point/floating_point_adder_test.dart new file mode 100644 index 000000000..2c5edfbc6 --- /dev/null +++ b/test/arithmetic/floating_point/floating_point_adder_test.dart @@ -0,0 +1,86 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_adder_test.dart +// Basic tests for all floating-point adders. +// +// 2025 January 3 +// Author: Desmond A Kirkpatrick + FloatingPointValue.ofSpacedBinaryString(s); + final testCases = [ + (ofString('0 0001 0000'), ofString('0 0000 0000')), + (ofString('0 0111 0010'), ofString('0 1110 1111')), + (ofString('0 1010 0000'), ofString('0 1011 0100')), + (fv.infinity, fv.infinity), + (fv.negativeInfinity, fv.negativeInfinity), + (fv.infinity, fv.negativeInfinity), + (fv.infinity, fv.zero), + (fv.negativeInfinity, fv.zero), + (fv.infinity, fv.one), + (fv.zero, fv.one), + (fv.negativeInfinity, fv.one), + ]; + + for (final test in testCases) { + final fv1 = test.$1; + final fv2 = test.$2; + + final expected = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() * fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + fp1.put(fv1.value); + fp2.put(fv2.value); + final multiply = FloatingPointMultiplierSimple(fp1, fp2); + final computed = multiply.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + } + }); + + test('FP: simple multiplier exhaustive', () { + const exponentWidth = 3; + const mantissaWidth = 3; + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final multiply = FloatingPointMultiplierSimple(fp1, fp2); + + final expLimit = pow(2, exponentWidth) - 1; + final mantLimit = pow(2, mantissaWidth); + for (final subtract in [0, 1]) { + for (var e1 = 0; e1 < expLimit; e1++) { + for (var m1 = 0; m1 < mantLimit; m1++) { + final fv1 = FloatingPointValue.ofInts(e1, m1, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + for (var e2 = 0; e2 < expLimit; e2++) { + for (var m2 = 0; m2 < mantLimit; m2++) { + final fv2 = FloatingPointValue.ofInts(e2, m2, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + sign: subtract == 1); + + final expected = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() * fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + fp1.put(fv1.value); + fp2.put(fv2.value); + final computed = multiply.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + } + } + } + } + } + }); + + test('FP: simple multiplier full random', () async { + const exponentWidth = 4; + const mantissaWidth = 4; + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final multiplier = FloatingPointMultiplierSimple(fp1, fp2); + final value = Random(51); + + var cnt = 1000; + while (cnt > 0) { + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(fv1); + fp2.put(fv2); + + final expected = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() * fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + final computed = multiplier.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + cnt--; + } + }); + + test('FP: simple multiplier singleton', () { + const exponentWidth = 4; + const mantissaWidth = 4; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv1 = FloatingPointValue.ofBinaryStrings('1', '1100', '0111'); + + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.ofBinaryStrings('0', '1100', '0000'); + + final doubleProduct = fv1.toDouble() * fv2.toDouble(); + final expected = FloatingPointValue.ofDoubleUnrounded(doubleProduct, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + + fp1.put(fv1.value); + fp2.put(fv2.value); + + final multiply = FloatingPointMultiplierSimple(fp1, fp2); + final computed = multiply.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + }); + }); + test('FP: simple multiplier singleton pipelined', () async { + final clk = SimpleClockGenerator(10).clk; + + const exponentWidth = 4; + const mantissaWidth = 4; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv1 = FloatingPointValue.ofBinaryStrings('0', '0111', '0000'); + + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.ofBinaryStrings('0', '1101', '0101'); + + final expected = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() * fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + fp1.put(fv1.value); + fp2.put(fv2.value); + + final multiply = FloatingPointMultiplierSimple(fp1, fp2, clk: clk); + + unawaited(Simulator.run()); + await clk.nextNegedge; + fp1.put(0); + fp2.put(0); + final computed = multiply.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t+ + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + await Simulator.endSimulation(); + }); +} diff --git a/test/arithmetic/floating_point/floating_point_value_test.dart b/test/arithmetic/floating_point/floating_point_value_test.dart index 36933ce81..9dd624aad 100644 --- a/test/arithmetic/floating_point/floating_point_value_test.dart +++ b/test/arithmetic/floating_point/floating_point_value_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // floating_point_test.dart @@ -17,27 +17,25 @@ import 'package:test/test.dart'; void main() { test('FPV: exhaustive round-trip', () { - const signStr = '0'; const exponentWidth = 4; const mantissaWidth = 4; - var exponent = LogicValue.zero.zeroExtend(exponentWidth); - var mantissa = LogicValue.zero.zeroExtend(mantissaWidth); - for (var k = 0; k < pow(2.0, exponentWidth).toInt() - 1; k++) { - final expStr = exponent.bitString; - for (var i = 0; i < pow(2.0, mantissaWidth).toInt(); i++) { - final mantStr = mantissa.bitString; - final fp = FloatingPointValue.ofBinaryStrings(signStr, expStr, mantStr); - final dbl = fp.toDouble(); - final fp2 = FloatingPointValue.ofDouble(dbl, - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - if (fp != fp2) { - if (fp.isNaN() != fp2.isNaN()) { - expect(fp, equals(fp2)); - } + for (final signStr in ['0', '1']) { + var exponent = LogicValue.zero.zeroExtend(exponentWidth); + var mantissa = LogicValue.zero.zeroExtend(mantissaWidth); + for (var k = 0; k < pow(2.0, exponentWidth).toInt() - 1; k++) { + final expStr = exponent.bitString; + for (var i = 0; i < pow(2.0, mantissaWidth).toInt(); i++) { + final mantStr = mantissa.bitString; + final fp = + FloatingPointValue.ofBinaryStrings(signStr, expStr, mantStr); + final dbl = fp.toDouble(); + final fp2 = FloatingPointValue.ofDouble(dbl, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + expect(fp, equals(fp2)); + mantissa = mantissa + 1; } - mantissa = mantissa + 1; + exponent = exponent + 1; } - exponent = exponent + 1; } }); @@ -164,10 +162,7 @@ void main() { for (var c = 0; c < corners.length; c++) { final val = corners[c][1] as double; final str = corners[c][0] as String; - final fp = - FloatingPointValue.ofDouble(val, exponentWidth: 4, mantissaWidth: 3); - expect(val, fp.toDouble()); - expect(str, fp.toString()); + final fp8 = FloatingPoint8E4M3Value.ofDouble(val); expect(val, fp8.toDouble()); expect(str, fp8.toString()); @@ -313,4 +308,105 @@ void main() { fp2.compareTo(FloatingPointValue.ofSpacedBinaryString('0 0000 0000')), equals(0)); }); + test('FPV: infinity/NaN conversion tests', () async { + const exponentWidth = 4; + const mantissaWidth = 4; + final infinity = FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.infinity, exponentWidth, mantissaWidth); + final negativeInfinity = FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.negativeInfinity, exponentWidth, mantissaWidth); + + final tooLargeNumber = FloatingPointValue.ofDouble(257, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + + expect(infinity.toDouble(), equals(double.infinity)); + expect(negativeInfinity.toDouble(), equals(double.negativeInfinity)); + + expect(tooLargeNumber.toDouble(), equals(double.infinity)); + + expect(tooLargeNumber.negate().toDouble(), equals(double.negativeInfinity)); + + expect( + FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.nan, exponentWidth, mantissaWidth) + .toDouble() + .isNaN, + equals(true)); + }); + test('FPV: infinity/NaN unrounded conversion tests', () async { + const exponentWidth = 4; + const mantissaWidth = 4; + final infinity = FloatingPointValue.ofDoubleUnrounded(double.infinity, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final negativeInfinity = FloatingPointValue.ofDoubleUnrounded( + double.negativeInfinity, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + final tooLargeNumber = FloatingPointValue.ofDoubleUnrounded(257, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + expect(tooLargeNumber.toDouble(), equals(double.infinity)); + expect(infinity.toDouble(), equals(double.infinity)); + expect(tooLargeNumber.negate().toDouble(), equals(double.negativeInfinity)); + expect(negativeInfinity.toDouble(), equals(double.negativeInfinity)); + }); + + test('FPV: infinity operation tests', () { + const exponentWidth = 4; + const mantissaWidth = 4; + final one = FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.one, exponentWidth, mantissaWidth); + final zero = FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.positiveZero, exponentWidth, mantissaWidth); + final infinity = FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.infinity, exponentWidth, mantissaWidth); + final negativeInfinity = FloatingPointValue.getFloatingPointConstant( + FloatingPointConstants.negativeInfinity, exponentWidth, mantissaWidth); + + for (final f in [infinity, negativeInfinity]) { + for (final s in [infinity, negativeInfinity]) { + // Addition + if (f == s) { + expect((f + s).toDouble(), equals(f.toDouble() + s.toDouble())); + } else { + expect((f + s).toDouble().isNaN, + equals((f.toDouble() + s.toDouble()).isNaN)); + } + // Subtraction + if (f != s) { + expect((f - s).toDouble(), equals(f.toDouble())); + } else { + expect((f - s).toDouble().isNaN, + equals((f.toDouble() - s.toDouble()).isNaN)); + } + // Multiplication + expect((f * s).toDouble(), equals(f.toDouble() * s.toDouble())); + // Division + expect((f / s).toDouble().isNaN, + equals((f.toDouble() / s.toDouble()).isNaN)); + } + } + for (final f in [infinity, negativeInfinity]) { + for (final s in [zero, one]) { + // Addition + expect((f + s).toDouble(), equals(f.toDouble() + s.toDouble())); + // Subtraction + expect((f - s).toDouble(), equals(f.toDouble())); + expect((s - f).toDouble(), equals(-f.toDouble())); + // Multiplication + if (s == zero) { + expect((f * s).toDouble().isNaN, + equals((f.toDouble() * s.toDouble()).isNaN)); + } else { + expect((f * s).toDouble(), equals(f.toDouble())); + } + // Division + if (s == zero) { + expect((f / s).toDouble().isNaN, + equals((f.toDouble() * s.toDouble()).isNaN)); + } else { + expect((f / s).toDouble(), equals(f.toDouble())); + } + } + } + }); } diff --git a/test/arithmetic/multiplier_test.dart b/test/arithmetic/multiplier_test.dart index 629804bfe..722e1e38d 100644 --- a/test/arithmetic/multiplier_test.dart +++ b/test/arithmetic/multiplier_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // multiplier_test.dart @@ -210,7 +210,6 @@ void main() { ' SelM=${(selectSignedMultiplier != null) ? 1 : 0}'; return (a, b, {selectSignedMultiplicand, selectSignedMultiplier}) => CompressionTreeMultiplier(a, b, radix, - ppTree: ppTree, ppGen: ppGen, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, @@ -277,7 +276,6 @@ void main() { ' SelM=${(selectSignedMultiplier != null) ? 1 : 0}'; return (a, b, c) => CompressionTreeMultiplyAccumulate(a, b, c, radix, - ppTree: ppTree, ppGen: ppGen, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, diff --git a/test/arithmetic/parallel_prefix_operations_test.dart b/test/arithmetic/parallel_prefix_operations_test.dart index b42365f90..e982cf999 100644 --- a/test/arithmetic/parallel_prefix_operations_test.dart +++ b/test/arithmetic/parallel_prefix_operations_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // parallel_prefix_operations.dart @@ -89,7 +89,7 @@ void testPriorityEncoder( // put/expect testing - for (var j = 0; j < (1 << n); ++j) { + for (var j = 1; j < (1 << n); ++j) { final golden = computePriorityEncoding(j); inp.put(j); final result = mod.out.value.toInt(); @@ -202,6 +202,22 @@ void main() { expect(ParallelPrefixPriorityEncoder(val).out.value.toInt(), equals(0)); expect(ParallelPrefixPriorityEncoder(val.reversed).out.value.toInt(), equals(3)); + + final valid = Logic(); + ParallelPrefixPriorityEncoder(val, valid: valid); + expect(valid.value.toBool(), equals(true)); + }); + test('priority encoder return beyond width if zero', () { + final val = Logic(width: 5); + // ignore: cascade_invocations + val.put(0); + expect(ParallelPrefixPriorityEncoder(val).out.value.toInt(), + equals(val.width + 1)); + expect(ParallelPrefixPriorityEncoder(val.reversed).out.value.toInt(), + equals(val.width + 1)); + final valid = Logic(); + ParallelPrefixPriorityEncoder(val, valid: valid); + expect(valid.value.toBool(), equals(false)); }); // Note: all ParallelPrefixAdders are tested in adder_test.dart diff --git a/test/signed_shifter_test.dart b/test/signed_shifter_test.dart new file mode 100644 index 000000000..08c350104 --- /dev/null +++ b/test/signed_shifter_test.dart @@ -0,0 +1,33 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// signed_shifter_test.dart +// Tests for signed shifter +// +// 2025 January 8 +// Author: Desmond Kirkpatrick + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; +import 'package:test/test.dart'; + +void main() { + test('sigend shifter test', () { + final bits = Const(16, width: 32); + final shift = Logic(width: 3); + + final shifter = SignedShifter(bits, shift); + var expected = 16; + for (var i = 0; i < 4; i++) { + shift.put(i); + expect(shifter.shifted.value.toInt(), equals(expected)); + expected = expected << 1; + } + expected = 1; + for (var i = 4; i < 8; i++) { + shift.put(i); + expect(shifter.shifted.value.toInt(), equals(expected)); + expected = expected << 1; + } + }); +} From eb9de15dbf9f33554a658da3a68e8a5c3e90c037 Mon Sep 17 00:00:00 2001 From: Desmond Kirkpatrick Date: Wed, 15 Jan 2025 18:16:49 -0800 Subject: [PATCH 2/3] Refactor signextend (#154) * refactoring of sign extension classes to be separate from product generation. --- doc/components/adder.md | 32 +- doc/components/multiplier.md | 51 ++-- doc/components/multiplier_components.md | 39 ++- lib/src/arithmetic/addend_compressor.dart | 2 +- .../arithmetic/evaluate_partial_product.dart | 4 +- lib/src/arithmetic/multiplier.dart | 32 +- .../parallel_prefix_operations.dart | 6 +- .../arithmetic/partial_product_generator.dart | 81 ++--- .../partial_product_sign_extend.dart | 282 +++++++++++++++--- lib/src/arithmetic/sign_magnitude_adder.dart | 2 +- test/arithmetic/addend_compressor_test.dart | 13 +- test/arithmetic/multiplier_encoder_test.dart | 135 ++++++--- test/arithmetic/multiplier_test.dart | 41 +-- 13 files changed, 468 insertions(+), 252 deletions(-) diff --git a/doc/components/adder.md b/doc/components/adder.md index 18428ac93..6f2bd287b 100644 --- a/doc/components/adder.md +++ b/doc/components/adder.md @@ -13,7 +13,7 @@ ROHD-HCL provides a set of adder modules to get the sum from a pair of Logic. So A ripple carry adder is a digital circuit used for binary addition. It consists of a series of [FullAdder](https://intel.github.io/rohd-hcl/rohd_hcl/FullAdder-class.html)s connected in a chain, with the carry output of each adder linked to the carry input of the next one. Starting from the least significant bit (LSB) to most significant bit (MSB), the adder sequentially adds corresponding bits of two binary numbers. -The [RippleCarryAdder](https://intel.github.io/rohd-hcl/rohd_hcl/RippleCarryAdder-class.html) module in ROHD-HCL accept input `Logic`s a and b as the input pin and the name of the module `name`. Note that the width of the inputs must be the same or a [RohdHclException](https://intel.github.io/rohd-hcl/rohd_hcl/RohdHclException-class.html) will be thrown. +The [adder](https://intel.github.io/rohd-hcl/rohd_hcl/Adder-class.html) module in ROHD-HCL accept input `Logic`s a and b as the input pin and the name of the module `name`. Note that the width of the inputs must be the same or a [RohdHclException](https://intel.github.io/rohd-hcl/rohd_hcl/RohdHclException-class.html) will be thrown. An example is shown below to add two inputs of signals that have 8-bits of width. @@ -24,8 +24,8 @@ final b = Logic(name: 'b', width: 8); a.put(5); b.put(5); -final rippleCarryAdder = RippleCarryAdder(a, b); -final sum = rippleCarryAdder.sum; +final adder = adder(a, b); +final sum = adder.sum; ``` ## Parallel Prefix Adder @@ -69,7 +69,7 @@ Here is an example of instantiating a [OnesComplementAdder](https://intel.githu b.put(bv); final carry = Logic(); final adder = OnesComplementAdder( - a, b, carryOut: carry, adderGen: RippleCarryAdder.new, + a, b, carryOut: carry, adderGen: adder.new, subtract: true); final mag = adder.sum.value.toInt() + (carry.value.isZero ? 0 : 1)); final out = (adder.sign.value.toInt() == 1 ? -mag : mag); @@ -97,7 +97,7 @@ Here is an example of instantiating a [SignMagnitudeAdder](https://intel.github. b.put(18); bSign.put(0); - final adder = SignMagnitudeAdder(aSign, a, bSign, b, adderGen: RippleCarryAdder.new, + final adder = SignMagnitudeAdder(aSign, a, bSign, b, adderGen: adder.new, largestMagnitudeFirst: true); final sum = adder.sum; @@ -113,14 +113,18 @@ The [`CarrySelectCompoundAdder`](https://intel.github.io/rohd-hcl/rohd_hcl/Carry The delay of the adder is defined by the combination of the sub-adders and the accumulated carry-select chain delay. The [CarrySelectCompoundAdder](https://intel.github.io/rohd-hcl/rohd_hcl/CarrySelectCompoundAdder-class.html) module in ROHD-HCL accepts input `Logic`s a and b as the input pin and the name of the module `name`. Note that the width of the inputs must be the same or a [RohdHclException](https://intel.github.io/rohd-hcl/rohd_hcl/RohdHclException-class.html) will be thrown. -The compound adder generator provides two alogithms for splitting the adder into adder sub-blocks: -- The [CarrySelectCompoundAdder.splitSelectAdderAlgorithm4Bit](https://intel.github.io/rohd-hcl/rohd_hcl/CarrySelectCompoundAdder/splitSelectAdderAlgorithm4Bit.html) algoritm splits the adder into blocks of 4-bit ripple-carry adders with the first one width adjusted down. -- The [CarrySelectCompoundAdder.splitSelectAdderAlgorithmSingleBlock](https://intel.github.io/rohd-hcl/rohd_hcl/CarrySelectCompoundAdder/splitSelectAdderAlgorithmSingleBlock.html) algorithm generates only one sub=block with the full bitwidth of the adder. +The compound adder forms a select chain around a set of adders specified by: -Input `List Function(int adderFullWidth) widthGen` should be used to specify the custom adder splitting algorithm that returns a list of sub-adders width. The default one is [CarrySelectCompoundAdder.splitSelectAdderAlgorithmSingleBlock](). +- `addergen`: an adder generator functor option to build the block adders with the default being `ParallelPrefixAdder`. -The `adderGen` input selects the type of sub-adder used, with the default being `ParallelPrefixAdder`. +The compound adder generator provides two algorithms for splitting the adder into adder sub-blocks: + +- `splitSelectAdderAlgorithmSingleBlock: + - The [CarrySelectCompoundAdder.splitSelectAdderAlgorithm4Bit](https://intel.github.io/rohd-hcl/rohd_hcl/CarrySelectCompoundAdder/splitSelectAdderAlgorithm4Bit.html) algoritm splits the adder into blocks of 4-bit ripple-carry adders with the first one width adjusted down. + - The [CarrySelectCompoundAdder.splitSelectAdderAlgorithmSingleBlock](https://intel.github.io/rohd-hcl/rohd_hcl/CarrySelectCompoundAdder/splitSelectAdderAlgorithmSingleBlock.html) algorithm generates only one sub=block with the full bitwidth of the adder. + +- `List Function(int adderFullWidth) widthGen` should be used to specify the custom adder splitting algorithm that returns a list of sub-adders width. The default one is [CarrySelectCompoundAdder.splitSelectAdderAlgorithmSingleBlock](). An example is shown below to add two inputs of signals that have 8-bits of width. @@ -131,11 +135,11 @@ final b = Logic(name: 'b', width: 8); a.put(5); b.put(5); -final rippleCarryAdder = CarrySelectCompoundAdder(a, b); -final sum = rippleCarryAdder.sum; -final sum1 = rippleCarryAdder.sum1; +final adder = CarrySelectCompoundAdder(a, b); +final sum = adder.sum; +final sum1 = adder.sum1; -final rippleCarryAdder4BitBlock = CarrySelectCompoundAdder(a, b, +final adder4BitBlock = CarrySelectCompoundAdder(a, b, widthGen: CarrySelectCompoundAdder.splitSelectAdderAlgorithm4Bit); ``` diff --git a/doc/components/multiplier.md b/doc/components/multiplier.md index f2439c983..75a71c6f5 100644 --- a/doc/components/multiplier.md +++ b/doc/components/multiplier.md @@ -1,6 +1,6 @@ -# Multiplier +# Integer Multiplier -ROHD-HCL provides an abstract `Multiplier` module which multiplies two +ROHD-HCL provides an abstract [Multiplier](https://intel.github.io/rohd-hcl/rohd_hcl/Multiplier-class.html) module which multiplies two numbers represented as two `Logic`s, potentially of different widths, treating them as either signed (twos' complement) or unsigned. It produces the product as a `Logic` with width equal to the sum of the @@ -16,7 +16,7 @@ of this abstract `Module`: - [Compression Tree Multiplier](#compression-tree-multiplier) An additional kind of abstract module provided is a -`MultiplyAccumulate` module which multiplies two numbers represented +[MultiplierAccumulate](https://intel.github.io/rohd-hcl/rohd_hcl/MultiplyAccumulate-class.html) module which multiplies two numbers represented as two `Logic`s and adds the result to a third `Logic` with width equal to the sum of the widths of the main inputs. Similar to the `Multiplier`, the signs of the operands are either fixed by a parameter, @@ -86,23 +86,24 @@ Simulator.endSimulation(); A compression tree multiplier is a digital circuit used for performing multiplication operations, using Booth encoding to produce addends, a -compression tree for reducing addends to a final pair, and a final -adder generated from a parallel prefix tree option. It is particularly -useful in applications that require high speed multiplication, such as -digital signal processing. +compression tree for reducing addends to a final pair, and a final adder +generated from a parallel prefix tree functor parameter. It is particularly +useful in applications that require high speed and varying width multiplication, +such as digital signal processing. The parameters of the -`CompressionTreeMultiplier` are: +[CompressionTreeMultiplier](https://intel.github.io/rohd-hcl/rohd_hcl/CompressionTreeMultiplier-class.html) are: - Two input terms `a` and `b` which can be different widths. - The radix used for Booth encoding (2, 4, 8, and 16 are currently supported). -- The type of `ParallelPrefix` tree used in the final `ParallelPrefixAdder` (optional). -- `ppGen` parameter: the type of `PartialProductGenerator` to use which has derived classes for different styles of sign extension. In some cases this adds an extra row to hold a sign bit. -- `signedMultiplicand` parameter: whether the multiplicand (first arg) should be treated as signed (twos' complement) or unsigned. -- `signedMultiplier` parameter: whether the multiplier (second arg) should be treated as signed (twos' complement) or unsigned. -- An optional `selectSignedMultiplicand` control signal which overrides the `signedMultiplicand` parameter allowing for runtime control of signed or unsigned operation with the same hardware. `signedMultiplicand` must be false if using this control signal. -- An optional `selectSignedMultiplier` control signal which overrides the `signedMultiplier` parameter allowing for runtime control of signed or unsigned operation with the same hardware. `signedMultiplier` must be false if using this control signal. -- An optional `clk`, as well as `enable` and `reset` that are used to add a pipestage in the `ColumnCompressor` to allow for pipelined operation. +- `seGen` parameter: the type of `PartialProductSignExtension` functor to use which has derived classes for different styles of sign extension. In some cases this adds an extra row to hold a sign bit (default `CompactRectSignExtension` does not). See [Sign Extension Options](./multiplier_components.md#sign-extension-option). +- Signed or unsigned operands: + - `signedMultiplicand` parameter: whether the multiplicand (first arg) should be treated as signed (twos' complement) or unsigned. + - `signedMultiplier` parameter: whether the multiplier (second arg) should be treated as signed (twos' complement) or unsigned. +- Alternatively, it supports runtime control of signage: + - An optional `selectSignedMultiplicand` control signal which allows for runtime control of signed or unsigned operation with the same hardware. `signedMultiplicand` must be false if using this control signal. + - An optional `selectSignedMultiplier` control signal which allows for runtime control of signed or unsigned operation with the same hardware. `signedMultiplier` must be false if using this control signal. +- An optional `clk`, as well as `enable` and `reset` that are used to add a pipestage in the `ColumnCompressor` to allow for pipelined operation, making the multiplier operate in 2 cycles. Here is an example of use of the `CompressionTreeMultiplier` with one signed input: @@ -130,22 +131,18 @@ A compression tree multiply-accumulate is similar to a compress tree multiplier, but it inserts an additional addend into the compression tree to allow for accumulation into this third input. -The parameters of the -`CompressionTreeMultiplyAccumulate` are: +The additional parameters of the +[CompressionTreeMultiplyAccumulate](https://intel.github.io/rohd-hcl/rohd_hcl/CompressionTreeMultiplyAccumulate-class.html) over the [CompressionTreeMltiplier](#compression-tree-multiplier) are: -- Two input product terms `a` and `b` which can be different widths - The accumulate input term `c` which must have width as sum of the two operand widths + 1. -- The radix used for Booth encoding (2, 4, 8, and 16 are currently supported) -- The type of `ParallelPrefix` tree used in the final `ParallelPrefixAdder` (default Kogge-Stone). -- `ppGen` parameter: the type of `PartialProductGenerator` to use which has derived classes for different styles of sign extension. In some cases this adds an extra row to hold a sign bit (default `PartialProductGeneratorCompactRectSignExtension`). -- `signedMultiplicand` parameter: whether the multiplicand (first arg) should be treated as signed (2s complement) or unsigned -- `signedMultiplier` parameter: whether the multiplier (second arg) should be treated as signed (twos' complement) or unsigned -- `signedAddend` parameter: whether the addend (third arg) should be treated as signed (twos' complement) or unsigned -- An optional `selectSignedMultiplicand` control signal which overrides the `signedMultiplicand` parameter allowing for runtime control of signed or unsigned operation with the same hardware. `signedMultiplicand` must be false if using this control signal. -- An optional `selectSignedMultiplier` control signal which overrides the `signedMultiplier` parameter allowing for runtime control of signed or unsigned operation with the same hardware. `signedMultiplier` must be false if using this control signal. -- An optional `selectSignedAddend` control signal which overrides the `signedAddend` parameter allowing for runtime control of signed or unsigned operation with the same hardware. `signedAddend` must be false if using this control signal. +- Addend signage: + - `signedAddend` parameter: whether the addend (third arg) should be treated as signed (twos' complement) or unsigned +OR + - An optional `selectSignedAddend` control signal allows for runtime control of signed or unsigned operation with the same hardware. `signedAddend` must be false if using this control signal. - An optional `clk`, as well as `enable` and `reset` that are used to add a pipestage in the `ColumnCompressor` to allow for pipelined operation. +The output width of the `CompressionTreeMultiplier` is the sum of the product term widths plus one to accomodate the additional acccumulate term. + Here is an example of using the `CompressionTreeMultiplyAccumulate` with all inputs as signed: ```dart diff --git a/doc/components/multiplier_components.md b/doc/components/multiplier_components.md index 00c49a81c..ed5f2c606 100644 --- a/doc/components/multiplier_components.md +++ b/doc/components/multiplier_components.md @@ -51,7 +51,7 @@ row slice mult A few things to note: first, that we are negating by ones' complement (so we need a -0) and second, these rows do not add up to (18: 10010). For Booth encoded rows to add up properly, they need to be in twos' complement form, and they need to be sign-extended. - Here is the matrix with a crude sign extension `brute` (the table formatting is available from our `PartialProductGenerator` component). With twos' complementation, and sign bits folded in (note the LSB of each row has a sign term from the previous row), these addends are correctly formed and add to (18: 10010). + Here is the matrix with a crude sign extension `brute` (the table formatting is available from our [PartialProductGenerator](https://intel.github.io/rohd-hcl/rohd_hcl/PartialProductGenerator-class.html) component). With twos' complementation, and sign bits folded in (note the LSB of each row has a sign term from the previous row), these addends are correctly formed and add to (18: 10010). ```text 7 6 5 4 3 2 1 0 @@ -90,7 +90,19 @@ Note that radix-4 shifts by 2 positions each row, but with only two rows and wit ## Partial Product Generator -This building block creates a set of rows of partial products from a multiplicand and a multiplier. It maintains the partial products as a list of rows, which are themselves lists of Logic as well as a row shift value for each row to represent the starting column of the row's least-significant bit. Its primary inputs are the multiplicand, multiplier, `RadixEncoder`, whether the operands are signed, and the type of `SignExtension` to use in generating the partial product rows. +The base class of `PartialProductGenerator` is [PartialProductArray](https://intel.github.io/rohd-hcl/rohd_hcl/PartialProductArray-class.html) which is simply a `List>` to represent addends and a `rowShift[row]` to represent the shifts in the partial product matrix. If customization is needed beyond sign extension options, routines are provided that allow for fixed customization of bit positions or conditional (mux based on a Logic) form in the `PartialProductArray`. + +```dart +final ppa = PartialProductArray(a,b); +ppa.setAbsolute(row, col, logic); +ppa.setAbsoluteAll(row, col, List); +ppa.muxAbsolute(row, col, condition, logic); +ppa.muxAbsoluteAll(row, col, condition, List); +``` + + The `PartialProductGenerator` adds to this the [RadixEncoder](https://intel.github.io/rohd-hcl/rohd_hcl/RadixEncoder-class.html) to encode the rows along with a matching `MultiplicandSelector` to create the actual mantissas used in each row. + +As a building block which creates a set of rows of partial products from a multiplicand and a multiplier, it maintains the partial products as a list of rows om the `PartialProductArray` base. Its primary inputs are the multiplicand, multiplier, `RadixEncoder`, and whether the operands are signed. The partial product generator produces a set of addends in shifted position to be added. The main output of the component is @@ -122,17 +134,13 @@ Our `RadixEncoder` module is general, creating selection tables for arbitrary Bo ### Sign Extension Option -The `PartialProductGenerator` class also provides for sign extension with multiple options including `SignExtension.none` which is no sign extension for help in debugging, as well as `SignExtension.compactRect` which is a compact form which works for rectangular products where the multiplicand and multiplier can be of different widths. - -The `PartialProductGenerator` creates a set of addends in its base class `PartialProductArray` which is simply a `List>` to represent addends and a `rowShift[row]` to represent the shifts in the partial product matrix. If customization is needed beyond sign extension options, routines are provided that allow for fixed customization of bit positions or conditional (mux based on a Logic) form in the `PartialProductArray`. +The `PartialProductSignExtension` defines the API for doing different kinds of sign extension on the `PartialProductArray`, from very simplistic for helping design new arithmetics to fairly standard to even compact, rectangular forms. -```dart -final ppg = PartialProductGenerator(a,b); -ppg.setAbsolute(row, col, logic); -ppg.setAbsoluteAll(row, col, List); -ppg.muxAbsolute(row, col, condition, logic); -ppg.muxAbsoluteAll(row, col, condition, List); -``` +- None: no sign extension. +- Brute: full width extension which is robust but costly. +- StopBit: A standard form which has the inverse-sign and a '1' stop bit in each row +- Compact: A form that eliminates a final sign in an otherwise empty final row. +- CompactRect: An enhanced form of compact that can handle rectangular multiplications. ### Partial Product Visualization @@ -167,7 +175,7 @@ You can also generate a Markdown form of the same matrix: Once you have a partial product matrix, you would like to add up the addends. Traditionally this is done using compression trees which instantiate 2:1 and 3:2 column compressors (or carry-save adders) to reduce the matrix to two addends. The final two addends are often added with an efficient final adder. -Our `ColumnCompressor` class uses a delay-driven approach to efficiently compress the rows of the partial product matrix. Its only argument is a `PartialProductArray` (base class of `PartialProductGenerator`), and it creates a list of `ColumnQueue`s containing the final two addends stored by column after compression. An `extractRow`routine can be used to extract the columns. `ColumnCompressor` currently has an extension `EvaluateColumnCompressor` which can be used to print out the compression progress. Here is the legend for these printouts. +Our [ColumnCompressor](https://intel.github.io/rohd-hcl/rohd_hcl/ColumnCompressor-class.html) class uses a delay-driven approach to efficiently compress the rows of the partial product matrix. Its only argument is a `PartialProductArray` (base class of `PartialProductGenerator`), and it creates a list of `ColumnQueue`s containing the final two addends stored by column after compression. An `extractRow`routine can be used to extract the columns. `ColumnCompressor` currently has an extension `EvaluateColumnCompressor` which can be used to print out the compression progress. Here is the legend for these printouts. - `ppR,C` = partial product entry at row R, column C - `sR,C` = sum term coming last from row R, column C @@ -213,7 +221,7 @@ Any adder can be used as the final adder of the final two addends produced from Here is a code snippet that shows how these components can be used to create a multiplier. -First the partial product generator is used, which has compact sign extension for rectangular products (`PartialProductGeneratorCompactRectSignExtension`) which we pass in the `RadixEncoder`, whether the operands are signed, and the kind of sign extension to use on the partial products. Note that sign extension is needed regardless of whether operands are signed or not due to Booth encoding. +First the partial product generator is used (`PartialProductGenerator`), which we pass in the `RadixEncoder`, whether the operands are signed. We operate on this generator with a compact sign extension class for rectangular products (`CompactRectSignExtension`). Note that sign extension is needed regardless of whether operands are signed or not due to Booth encoding. Next, we use the `ColumnCompressor` to compress the partial products into two final addends. @@ -222,7 +230,8 @@ Finally, we produce the product. ```dart final pp = - PartialProductGeneratorCompactRectSignExtension(a, b, RadixEncoder(radix), signedMultiplicand: true, signedMultiplier: true); + PartialProductGenerator(a, b, RadixEncoder(radix), signedMultiplicand: true, signedMultiplier: true); + CompactRectSignExtension(pp).signExtend(); final compressor = ColumnCompressor(pp)..compress(); final adder = ParallelPrefixAdder( compressor.exractRow(0), compressor.extractRow(1), BrentKung.new); diff --git a/lib/src/arithmetic/addend_compressor.dart b/lib/src/arithmetic/addend_compressor.dart index fa92e7c25..8c4d30c42 100644 --- a/lib/src/arithmetic/addend_compressor.dart +++ b/lib/src/arithmetic/addend_compressor.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // addend_compressor.dart diff --git a/lib/src/arithmetic/evaluate_partial_product.dart b/lib/src/arithmetic/evaluate_partial_product.dart index 97339bea9..bbf82bab4 100644 --- a/lib/src/arithmetic/evaluate_partial_product.dart +++ b/lib/src/arithmetic/evaluate_partial_product.dart @@ -11,7 +11,7 @@ import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; /// The following routines are useful only during testing -extension TestPartialProductSignage on PartialProductGenerator { +extension TestPartialProductSignage on PartialProductGeneratorBase { /// Return true if multiplicand is truly signed (fixed or runtime) bool isSignedMultiplicand() => (selectSignedMultiplicand == null) ? signedMultiplicand @@ -28,7 +28,7 @@ extension TestPartialProductSignage on PartialProductGenerator { /// Debug routines for printing out partial product matrix during /// simulation with live logic values -extension EvaluateLivePartialProduct on PartialProductGenerator { +extension EvaluateLivePartialProduct on PartialProductGeneratorBase { /// Accumulate the partial products and return as BigInt BigInt evaluate() { final maxW = maxWidth(); diff --git a/lib/src/arithmetic/multiplier.dart b/lib/src/arithmetic/multiplier.dart index 2a6aca753..5714d70f8 100644 --- a/lib/src/arithmetic/multiplier.dart +++ b/lib/src/arithmetic/multiplier.dart @@ -188,7 +188,7 @@ abstract class MultiplyAccumulate extends Module { } } -/// An implementation of an integer multiplier using compression trees +/// An implementation of an integer multiplier using compression trees. class CompressionTreeMultiplier extends Multiplier { /// The clk for the pipelined version of column compression. Logic? clk; @@ -207,7 +207,7 @@ class CompressionTreeMultiplier extends Multiplier { /// and an [Adder] generator functor [adderGen] for the final adder. /// /// Sign extension methodology is defined by the partial product generator - /// supplied via [ppGen]. + /// supplied via [seGen]. /// /// [a] multiplicand and [b] multiplier are the product terms and they can /// be different widths allowing for rectangular multiplication. @@ -240,19 +240,16 @@ class CompressionTreeMultiplier extends Multiplier { super.selectSignedMultiplier, Adder Function(Logic a, Logic b, {Logic? carryIn}) adderGen = NativeAdder.new, - PartialProductGenerator Function(Logic a, Logic b, RadixEncoder encoder, - {required bool signedMultiplier, - required bool signedMultiplicand, - Logic? selectSignedMultiplier, - Logic? selectSignedMultiplicand}) - ppGen = PartialProductGeneratorCompactRectSignExtension.new, + PartialProductSignExtension Function(PartialProductGeneratorBase pp, + {String name}) + seGen = CompactRectSignExtension.new, super.name = 'compression_tree_multiplier'}) { clk = (clk != null) ? addInput('clk', clk!) : null; reset = (reset != null) ? addInput('reset', reset!) : null; enable = (enable != null) ? addInput('enable', enable!) : null; final product = addOutput('product', width: a.width + b.width); - final pp = ppGen( + final pp = PartialProductGeneratorBasic( a, b, RadixEncoder(radix), @@ -262,6 +259,8 @@ class CompressionTreeMultiplier extends Multiplier { signedMultiplier: signedMultiplier, ); + seGen(pp).signExtend(); + final compressor = ColumnCompressor(clk: clk, reset: reset, enable: enable, pp) ..compress(); @@ -304,7 +303,7 @@ class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { /// always signed (default is unsigned). /// /// Sign extension methodology is defined by the partial product generator - /// supplied via [ppGen]. + /// supplied via [seGen]. /// /// Optional [selectSignedMultiplicand] allows for runtime configuration of /// signed or unsigned operation, overriding the [signedMultiplicand] static @@ -334,19 +333,16 @@ class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { super.selectSignedAddend, Adder Function(Logic a, Logic b, {Logic? carryIn}) adderGen = NativeAdder.new, - PartialProductGenerator Function(Logic a, Logic b, RadixEncoder encoder, - {required bool signedMultiplier, - required bool signedMultiplicand, - Logic? selectSignedMultiplier, - Logic? selectSignedMultiplicand}) - ppGen = PartialProductGeneratorCompactRectSignExtension.new, + PartialProductSignExtension Function(PartialProductGeneratorBase pp, + {String name}) + seGen = CompactRectSignExtension.new, super.name = 'compression_tree_mac'}) { clk = (clk != null) ? addInput('clk', clk) : null; reset = (reset != null) ? addInput('reset', reset) : null; enable = (enable != null) ? addInput('enable', enable) : null; final accumulate = addOutput('accumulate', width: a.width + b.width + 1); - final pp = ppGen( + final pp = PartialProductGeneratorBasic( a, b, RadixEncoder(radix), @@ -356,6 +352,8 @@ class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { signedMultiplier: signedMultiplier, ); + seGen(pp).signExtend(); + final lastLength = pp.partialProducts[pp.rows - 1].length + pp.rowShift[pp.rows - 1]; diff --git a/lib/src/arithmetic/parallel_prefix_operations.dart b/lib/src/arithmetic/parallel_prefix_operations.dart index 12d60f3f8..9c062eed7 100644 --- a/lib/src/arithmetic/parallel_prefix_operations.dart +++ b/lib/src/arithmetic/parallel_prefix_operations.dart @@ -61,7 +61,7 @@ class Ripple extends ParallelPrefix { /// Sklansky shaped ParallelPrefix tree class Sklansky extends ParallelPrefix { /// Sklansky constructor - Sklansky(List inps, Logic Function(Logic, Logic) op) + Sklansky(List inps, Logic Function(Logic term1, Logic term2) op) : super(inps, 'sklansky') { final iseq = []; @@ -90,7 +90,7 @@ class Sklansky extends ParallelPrefix { /// KoggeStone shaped ParallelPrefix tree class KoggeStone extends ParallelPrefix { /// KoggeStone constructor - KoggeStone(List inps, Logic Function(Logic, Logic) op) + KoggeStone(List inps, Logic Function(Logic term1, Logic term2) op) : super(inps, 'kogge_stone') { final iseq = []; @@ -117,7 +117,7 @@ class KoggeStone extends ParallelPrefix { /// BrentKung shaped ParallelPrefix tree class BrentKung extends ParallelPrefix { /// BrentKung constructor - BrentKung(List inps, Logic Function(Logic, Logic) op) + BrentKung(List inps, Logic Function(Logic term1, Logic term2) op) : super(inps, 'brent_kung') { final iseq = []; diff --git a/lib/src/arithmetic/partial_product_generator.dart b/lib/src/arithmetic/partial_product_generator.dart index 048d3b92f..c8ceeb962 100644 --- a/lib/src/arithmetic/partial_product_generator.dart +++ b/lib/src/arithmetic/partial_product_generator.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // partial_product_generator.dart @@ -25,7 +25,7 @@ class SignBit extends Logic { } /// A [PartialProductArray] is a class that holds a set of partial products -/// for manipulation by [PartialProductGenerator] and [ColumnCompressor]. +/// for manipulation by [PartialProductGeneratorBase] and [ColumnCompressor]. abstract class PartialProductArray { /// name used for PartialProductGenerators final String name; @@ -43,7 +43,7 @@ abstract class PartialProductArray { /// extension routines late final List> partialProducts; - /// rows of partial products + /// Number of rows of partial products. int get rows => partialProducts.length; /// Return the actual largest width of all rows @@ -181,34 +181,33 @@ abstract class PartialProductArray { partialProducts[row].insertAll(col - rowShift[row], list); } -/// A [PartialProductGenerator] class that generates a set of partial products. -/// Essentially a set of -/// shifted rows of [Logic] addends generated by Booth recoding and -/// manipulated by sign extension, before being compressed -abstract class PartialProductGenerator extends PartialProductArray { - /// Get the shift increment between neighboring product rows +/// A [PartialProductGeneratorBase] class that generates a set of partial +/// products. Essentially a set of shifted rows of [Logic] addends generated by +/// Booth recoding and manipulated by sign extension, before being compressed. +abstract class PartialProductGeneratorBase extends PartialProductArray { + /// Get the shift increment between neighboring product rows. int get shift => selector.shift; - /// The multiplicand term + /// The multiplicand term. Logic get multiplicand => selector.multiplicand; - /// The multiplier term + /// The multiplier term. Logic get multiplier => encoder.multiplier; - /// Encoder for the full multiply operand + /// Encoder for the full multiply operand. late final MultiplierEncoder encoder; /// Selector for the multiplicand which uses the encoder to index into - /// multiples of the multiplicand and generate partial products + /// multiples of the multiplicand and generate partial products. late final MultiplicandSelector selector; - /// [multiplicand] operand is always signed + /// [multiplicand] operand is always signed. final bool signedMultiplicand; - /// [multiplier] operand is always signed + /// [multiplier] operand is always signed. final bool signedMultiplier; - /// Used to avoid sign extending more than once + /// Used to avoid sign extending more than once. bool isSignExtended = false; /// If not null, use this signal to select between signed and unsigned @@ -219,21 +218,25 @@ abstract class PartialProductGenerator extends PartialProductArray { /// [multiplier]. final Logic? selectSignedMultiplier; - /// Construct a [PartialProductGenerator] -- the partial product matrix. + /// Construct a [PartialProductGeneratorBase] -- the partial product matrix. /// /// [signedMultiplier] generates a fixed signed encoder versus using /// [selectSignedMultiplier] which is a runtime sign selection [Logic] /// in which case [signedMultiplier] must be false. - PartialProductGenerator( + PartialProductGeneratorBase( Logic multiplicand, Logic multiplier, RadixEncoder radixEncoder, {this.signedMultiplicand = false, - this.selectSignedMultiplicand, this.signedMultiplier = false, + this.selectSignedMultiplicand, this.selectSignedMultiplier, super.name = 'ppg'}) { if (signedMultiplier && (selectSignedMultiplier != null)) { throw RohdHclException('sign reconfiguration requires signed=false'); } + if (signedMultiplicand && (selectSignedMultiplicand != null)) { + throw RohdHclException('multiplicand sign reconfiguration requires ' + 'signedMultiplicand=false'); + } encoder = MultiplierEncoder(multiplier, radixEncoder, signedMultiplier: signedMultiplier, selectSignedMultiplier: selectSignedMultiplier); @@ -250,7 +253,6 @@ abstract class PartialProductGenerator extends PartialProductArray { 'or equal to ${selector.shift + (signedMultiplier ? 1 : 0)}'); } _build(); - signExtend(); } /// Perform sign extension (defined in child classes) @@ -268,43 +270,4 @@ abstract class PartialProductGenerator extends PartialProductArray { rowShift.add(row * shift); } } - - /// Helper function for sign extension routines: - /// For signed operands, set the MSB to [sign], otherwise add this [sign] bit. - void addStopSign(List addend, SignBit sign) { - if (!signedMultiplicand) { - addend.add(sign); - } else { - addend.last = sign; - } - } - - /// Helper function for sign extension routines: - /// For signed operands, flip the MSB, otherwise add this [sign] bit. - void addStopSignFlip(List addend, SignBit sign) { - if (!signedMultiplicand) { - if (selectSignedMultiplicand == null) { - addend.add(sign); - } else { - addend.add(SignBit(mux(selectSignedMultiplicand!, ~addend.last, sign), - inverted: selectSignedMultiplicand != null)); - } - } else { - addend.last = SignBit(~addend.last, inverted: true); - } - } -} - -/// A Partial Product Generator with no sign extension -class PartialProductGeneratorNoneSignExtension extends PartialProductGenerator { - /// Construct a basic Partial Product Generator - PartialProductGeneratorNoneSignExtension( - super.multiplicand, super.multiplier, super.radixEncoder, - {super.signedMultiplicand, - super.signedMultiplier, - super.selectSignedMultiplicand, - super.selectSignedMultiplier}); - - @override - void signExtend() {} } diff --git a/lib/src/arithmetic/partial_product_sign_extend.dart b/lib/src/arithmetic/partial_product_sign_extend.dart index 33aa48ae7..caffc6eb5 100644 --- a/lib/src/arithmetic/partial_product_sign_extend.dart +++ b/lib/src/arithmetic/partial_product_sign_extend.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // partial_product_test_sign_extend.dart @@ -12,7 +12,7 @@ import 'dart:math'; import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; -/// Methods for sign extending the [PartialProductGenerator] +/// Methods for sign extending the [PartialProductGeneratorBase] enum SignExtension { /// No sign extension none, @@ -31,40 +31,160 @@ enum SignExtension { } /// Used to test different sign extension methods -typedef PPGFunction = PartialProductGenerator Function( +typedef PPGFunction = PartialProductGeneratorBase Function( Logic a, Logic b, RadixEncoder radixEncoder, {bool signedMultiplicand, Logic? selectSignedMultiplicand, bool signedMultiplier, Logic? selectSignedMultiplier}); +/// API for sign extension classes +abstract class PartialProductSignExtension { + /// name used for PartialProductSignExtension + final String name; + + /// The partial product generator we are sign extending. + final PartialProductGeneratorBase ppg; + + /// multiplicand operand is always signed. + bool get signedMultiplicand => ppg.signedMultiplicand; + + /// multiplier operand is always signed. + bool get signedMultiplier => ppg.signedMultiplier; + + /// If not null, use this signal to select between signed and unsigned + /// multiplicand. + Logic? get selectSignedMultiplicand => ppg.selectSignedMultiplicand; + + /// If not null, use this signal to select between signed and unsigned + /// multiplier. + Logic? get selectSignedMultiplier => ppg.selectSignedMultiplier; + + /// Number of rows of partial products. + int get rows => ppg.rows; + + /// The actual shift in each row. This value will be modified by the + /// sign extension routine used when folding in a sign bit from another + /// row. + List get rowShift => ppg.rowShift; + + /// Partial Products output. Generated by selector and extended by sign + /// extension routines. + List> get partialProducts => ppg.partialProducts; + + /// Used to avoid sign extending more than once. + bool get isSignExtended => ppg.isSignExtended; + set isSignExtended(bool set) { + ppg.isSignExtended = set; + } + + /// Get the shift increment between neighboring product rows. + int get shift => ppg.shift; + + /// Encoder for the full multiply operand. Used here just to get signs[]. + MultiplierEncoder get encoder => ppg.encoder; // signs getter + + // is multiplicand.width == entry.length? + // width=> multiplicand.width + shift - 1; + /// Only used to get width as above + MultiplicandSelector get selector => ppg.selector; // selector.width accessed + + /// Sign Extension class that operates on a [PartialProductGeneratorBase] + /// and sign-extends the entries. + PartialProductSignExtension(this.ppg, {this.name = 'no_sign_extension'}) { + if (signedMultiplier && (selectSignedMultiplier != null)) { + throw RohdHclException('sign reconfiguration requires signed=false'); + } + if (signedMultiplicand && (selectSignedMultiplicand != null)) { + throw RohdHclException('multiplicand sign reconfiguration requires ' + 'signedMultiplicand=false'); + } + } + + /// Execute the sign extension, overridden to specialize. + void signExtend(); + + /// Helper function for sign extension routines: + /// For signed operands, set the MSB to [sign], otherwise add this [sign] bit. + void addStopSign(List addend, SignBit sign) { + if (!signedMultiplicand) { + addend.add(sign); + } else { + addend.last = sign; + } + } + + /// Helper function for sign extension routines: + /// For signed operands, flip the MSB, otherwise add this [sign] bit. + void addStopSignFlip(List addend, SignBit sign) { + if (!signedMultiplicand) { + if (selectSignedMultiplicand == null) { + addend.add(sign); + } else { + addend.add(SignBit(mux(selectSignedMultiplicand!, ~addend.last, sign), + inverted: selectSignedMultiplicand != null)); + } + } else { + addend.last = SignBit(~addend.last, inverted: true); + } + } +} + /// Used to test different sign extension methods -PPGFunction curryPartialProductGenerator(SignExtension signExtension) => +typedef SignExtensionFunction = PartialProductSignExtension + Function(PartialProductGeneratorBase ppg, {String name}); + +/// Used to test different sign extension methods +SignExtensionFunction currySignExtensionFunction(SignExtension signExtension) => switch (signExtension) { - SignExtension.none => PartialProductGeneratorNoneSignExtension.new, - SignExtension.brute => PartialProductGeneratorBruteSignExtension.new, - SignExtension.stopBits => - PartialProductGeneratorStopBitsSignExtension.new, - SignExtension.compact => PartialProductGeneratorCompactSignExtension.new, - SignExtension.compactRect => - PartialProductGeneratorCompactRectSignExtension.new, + SignExtension.none => NoneSignExtension.new, + SignExtension.brute => BruteSignExtension.new, + SignExtension.stopBits => StopBitsSignExtension.new, + SignExtension.compact => CompactSignExtension.new, + SignExtension.compactRect => CompactRectSignExtension.new, }; -/// These other sign extensions are for assisting with testing and debugging. -/// More robust and simpler sign extensions in case -/// complex sign extension routines obscure other bugs. +/// A range of SignExtension classes to be used in building new arithmetic +/// building blocks. Start with [BruteSignExtension] when composing new +/// partial product array shapes as it should work in all situations. -/// A Partial Product Generator using Brute Sign Extension -class PartialProductGeneratorBruteSignExtension - extends PartialProductGenerator { - /// Construct a brute-force sign extending Partial Product Generator - PartialProductGeneratorBruteSignExtension( +/// A Partial Product Generator using None Sign Extension +class NoneSignExtension extends PartialProductSignExtension { + /// Construct a no sign-extension class. + NoneSignExtension(super.ppg, {super.name = 'none_sign_extension'}); + + /// Fully sign extend the PP array: useful for reference only + @override + void signExtend() {} +} + +/// A concrete base class for partial product generation +class PartialProductGeneratorBasic extends PartialProductGeneratorBase { + /// The extension routine we will be using. + late final PartialProductSignExtension extender; + + /// Construct a none sign extending Partial Product Generator + PartialProductGeneratorBasic( super.multiplicand, super.multiplier, super.radixEncoder, {super.signedMultiplicand, super.signedMultiplier, super.selectSignedMultiplicand, super.selectSignedMultiplier, - super.name = 'brute'}); + super.name = 'none'}) { + extender = NoneSignExtension(this); + signExtend(); + } + + @override + void signExtend() { + extender.signExtend(); + } +} + +/// A Brute Sign Extension class. +class BruteSignExtension extends PartialProductSignExtension { + /// Construct a brute-force sign extending Partial Product Generator + BruteSignExtension(super.ppg, {super.name = 'brute_sign_extension'}); /// Fully sign extend the PP array: useful for reference only @override @@ -101,19 +221,36 @@ class PartialProductGeneratorBruteSignExtension } } -/// A Partial Product Generator using Brute Sign Extension -class PartialProductGeneratorCompactSignExtension - extends PartialProductGenerator { - /// Construct a compact sign extending Partial Product Generator - PartialProductGeneratorCompactSignExtension( +/// A wrapper class for [BruteSignExtension] we used +/// during refactoring to be compatible with old calls. +class PartialProductGeneratorBruteSignExtension + extends PartialProductGeneratorBase { + /// The extension routine we will be using. + late final PartialProductSignExtension extender; + + /// Construct a compact rect sign extending Partial Product Generator + PartialProductGeneratorBruteSignExtension( super.multiplicand, super.multiplier, super.radixEncoder, {super.signedMultiplicand, - super.selectSignedMultiplicand, super.signedMultiplier, + super.selectSignedMultiplicand, super.selectSignedMultiplier, - super.name = 'compact'}); + super.name = 'brute'}) { + extender = BruteSignExtension(this); + signExtend(); + } + + @override + void signExtend() { + extender.signExtend(); + } +} + +/// A Compact Sign Extension class. +class CompactSignExtension extends PartialProductSignExtension { + /// Construct a compact sign extendsion class. + CompactSignExtension(super.ppg, {super.name = 'compact_sign_extension'}); - /// Sign extend the PP array using stop bits without adding a row. @override void signExtend() { // An implementation of @@ -226,17 +363,35 @@ class PartialProductGeneratorCompactSignExtension } } -/// A Partial Product Generator using Brute Sign Extension -class PartialProductGeneratorStopBitsSignExtension - extends PartialProductGenerator { - /// Construct a stop bits sign extending Partial Product Generator - PartialProductGeneratorStopBitsSignExtension( +/// A wrapper class for [CompactSignExtension] we used +/// during refactoring to be compatible with old calls. +class PartialProductGeneratorCompactSignExtension + extends PartialProductGeneratorBase { + /// The extension routine we will be using. + late final PartialProductSignExtension extender; + + /// Construct a compact sign extending Partial Product Generator + PartialProductGeneratorCompactSignExtension( super.multiplicand, super.multiplier, super.radixEncoder, {super.signedMultiplicand, - super.selectSignedMultiplicand, super.signedMultiplier, + super.selectSignedMultiplicand, super.selectSignedMultiplier, - super.name = 'stop_bits'}); + super.name = 'compact'}) { + extender = CompactSignExtension(this); + signExtend(); + } + + @override + void signExtend() { + extender.signExtend(); + } +} + +/// A StopBits Sign Extension. +class StopBitsSignExtension extends PartialProductSignExtension { + /// Construct a stop bits sign extendsion class. + StopBitsSignExtension(super.ppg, {super.name = 'stopbits_sign_extension'}); /// Sign extend the PP array using stop bits. /// If possible, fold the final carry into another row (only when rectangular @@ -314,9 +469,40 @@ class PartialProductGeneratorStopBitsSignExtension } } -/// A Partial Product Generator using Compact Rectangular Extension +// + +/// A wrapper class for [StopBitsSignExtension] we used +/// during refactoring to be compatible with old calls. +class PartialProductGeneratorStopBitsSignExtension + extends PartialProductGeneratorBase { + /// The extension routine we will be using. + late final PartialProductSignExtension extender; + + /// Construct a stop bits sign extending Partial Product Generator + PartialProductGeneratorStopBitsSignExtension( + super.multiplicand, super.multiplier, super.radixEncoder, + {super.signedMultiplicand, + super.signedMultiplier, + super.selectSignedMultiplicand, + super.selectSignedMultiplier, + super.name = 'stop_bits'}) { + extender = StopBitsSignExtension(this); + signExtend(); + } + + @override + void signExtend() { + extender.signExtend(); + } +} + +/// A wrapper class for CompactRectSignExtension we used +/// during refactoring to be compatible with old calls. class PartialProductGeneratorCompactRectSignExtension - extends PartialProductGenerator { + extends PartialProductGeneratorBase { + /// The extension routine we will be using. + late final PartialProductSignExtension extender; + /// Construct a compact rect sign extending Partial Product Generator PartialProductGeneratorCompactRectSignExtension( super.multiplicand, super.multiplier, super.radixEncoder, @@ -324,18 +510,28 @@ class PartialProductGeneratorCompactRectSignExtension super.signedMultiplier, super.selectSignedMultiplicand, super.selectSignedMultiplier, - super.name = 'compact_rect'}); + super.name = 'compact_rect'}) { + extender = CompactRectSignExtension(this); + signExtend(); + } + + @override + void signExtend() { + extender.signExtend(); + } +} +/// A Partial Product Generator using Compact Rectangular Extension +class CompactRectSignExtension extends PartialProductSignExtension { /// Sign extend the PP array using stop bits without adding a row /// This routine works with different widths of multiplicand/multiplier, /// an extension of Mohanty, B.K., Choubey designed by - /// Desmond A. Kirkpatrick + /// Desmond A. Kirkpatrick. + CompactRectSignExtension(super.ppg, + {super.name = 'compactrect_sign_extension'}); + @override void signExtend() { - if (signedMultiplicand && (selectSignedMultiplicand != null)) { - throw RohdHclException('multiplicand sign reconfiguration requires ' - 'signedMultiplicand=false'); - } if (isSignExtended) { throw RohdHclException('Partial Product array already sign-extended'); } diff --git a/lib/src/arithmetic/sign_magnitude_adder.dart b/lib/src/arithmetic/sign_magnitude_adder.dart index ea7ef9132..a21966c32 100644 --- a/lib/src/arithmetic/sign_magnitude_adder.dart +++ b/lib/src/arithmetic/sign_magnitude_adder.dart @@ -37,7 +37,7 @@ class SignMagnitudeAdder extends Adder { /// comparator. // TODO(desmonddak): this adder may need a carry-in for rounding SignMagnitudeAdder(this.aSign, super.a, this.bSign, super.b, - Adder Function(Logic, Logic, {Logic? carryIn}) adderGen, + Adder Function(Logic a, Logic b, {Logic? carryIn}) adderGen, {this.largestMagnitudeFirst = false, super.name = 'sign_magnitude_adder'}) { aSign = addInput('aSign', aSign); diff --git a/test/arithmetic/addend_compressor_test.dart b/test/arithmetic/addend_compressor_test.dart index d479e940f..56214dcb7 100644 --- a/test/arithmetic/addend_compressor_test.dart +++ b/test/arithmetic/addend_compressor_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // addend_compressor_test.dart @@ -18,7 +18,7 @@ import 'package:test/test.dart'; /// This [CompressorTestMod] module is used to test instantiation, where we can /// catch trace errors (IO not added) not found in a simple test instantiation. class CompressorTestMod extends Module { - late final PartialProductGenerator pp; + late final PartialProductGeneratorBase pp; late final ColumnCompressor compressor; @@ -36,8 +36,10 @@ class CompressorTestMod extends Module { clk = addInput('clk', iclk); } - pp = PartialProductGeneratorCompactRectSignExtension(a, b, encoder, + final pp = PartialProductGeneratorBasic(a, b, encoder, signedMultiplicand: signed, signedMultiplier: signed); + CompactRectSignExtension(pp).signExtend(); + compressor = ColumnCompressor(pp, clk: clk); compressor.compress(); final r0 = addOutput('r0', width: compressor.columns.length); @@ -111,12 +113,13 @@ void main() { selectSignedMultiplicand!.put(signed ? 1 : 0); selectSignedMultiplier!.put(signed ? 1 : 0); } - final pp = PartialProductGeneratorCompactRectSignExtension( - a, b, encoder, + final pp = PartialProductGeneratorBasic(a, b, encoder, signedMultiplicand: !useSelect & signed, signedMultiplier: !useSelect & signed, selectSignedMultiplicand: selectSignedMultiplicand, selectSignedMultiplier: selectSignedMultiplier); + CompactRectSignExtension(pp).signExtend(); + expect(pp.evaluate(), equals(bA * bB)); final compressor = ColumnCompressor(pp); expect(compressor.evaluate().$1, equals(bA * bB)); diff --git a/test/arithmetic/multiplier_encoder_test.dart b/test/arithmetic/multiplier_encoder_test.dart index 13570b83f..6ed65be82 100644 --- a/test/arithmetic/multiplier_encoder_test.dart +++ b/test/arithmetic/multiplier_encoder_test.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // multiplier_encoder_test.dart @@ -26,7 +26,7 @@ import 'package:test/test.dart'; // - Rectangular: again, need to cross a shift interval // - Sign Extension: [brute, stop, compact, compactRect] -void testPartialProductExhaustive(PartialProductGenerator pp) { +void testPartialProductExhaustive(PartialProductGeneratorBase pp) { final widthX = pp.selector.multiplicand.width; final widthY = pp.encoder.multiplier.width; @@ -72,7 +72,7 @@ void testPartialProductExhaustive(PartialProductGenerator pp) { }); } -void testPartialProductRandom(PartialProductGenerator pp, int iterations) { +void testPartialProductRandom(PartialProductGeneratorBase pp, int iterations) { final widthX = pp.selector.multiplicand.width; final widthY = pp.encoder.multiplier.width; @@ -119,7 +119,8 @@ void testPartialProductRandom(PartialProductGenerator pp, int iterations) { }); } -void testPartialProductSingle(PartialProductGenerator pp, BigInt X, BigInt Y) { +void testPartialProductSingle( + PartialProductGeneratorBase pp, BigInt X, BigInt Y) { test( 'single: ${pp.name} R${pp.selector.radix} ' 'WD=${pp.multiplicand.width} WM=${pp.multiplier.width} ' @@ -137,7 +138,7 @@ void testPartialProductSingle(PartialProductGenerator pp, BigInt X, BigInt Y) { }); } -void checkPartialProduct(PartialProductGenerator pp, BigInt iX, BigInt iY) { +void checkPartialProduct(PartialProductGeneratorBase pp, BigInt iX, BigInt iY) { final widthX = pp.selector.multiplicand.width; final widthY = pp.encoder.multiplier.width; @@ -167,11 +168,13 @@ void main() { final width = log2Ceil(radix) + (signedMultiplier ? 1 : 0); for (final signExtension in SignExtension.values.where((e) => e != SignExtension.none)) { - final ppg = curryPartialProductGenerator(signExtension); - final pp = ppg(Logic(name: 'X', width: width), - Logic(name: 'Y', width: width), RadixEncoder(radix), + final pp = PartialProductGeneratorBasic( + Logic(name: 'X', width: width), + Logic(name: 'Y', width: width), + RadixEncoder(radix), signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier); + currySignExtensionFunction(signExtension)(pp).signExtend(); testPartialProductExhaustive(pp); } } @@ -196,10 +199,11 @@ void main() { SignedBigInt.fromSignedInt(j, width, signed: signedMultiplier); a.put(X); b.put(Y); - final PartialProductGenerator pp; - pp = PartialProductGeneratorStopBitsSignExtension(a, b, encoder, + final PartialProductGeneratorBase pp; + pp = PartialProductGeneratorBasic(a, b, encoder, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier); + StopBitsSignExtension(pp).signExtend(); testPartialProductSingle(pp, X, Y); } } @@ -221,9 +225,8 @@ void main() { for (final signExtension in SignExtension.values .where((e) => e != SignExtension.none)) { final width = log2Ceil(radix) + (signMultiplier ? 1 : 0); - final ppg = curryPartialProductGenerator(signExtension); - final PartialProductGenerator pp; - pp = ppg( + final PartialProductGeneratorBase pp; + pp = PartialProductGeneratorBasic( Logic(name: 'X', width: width), Logic(name: 'Y', width: width), encoder, @@ -233,6 +236,7 @@ void main() { selectMultiplicand ? selectSignMultiplicand : null, selectSignedMultiplier: selectMultiplier ? selectSignMultiplier : null); + currySignExtensionFunction(signExtension)(pp).signExtend(); testPartialProductExhaustive(pp); } @@ -255,7 +259,7 @@ void main() { for (final selectMultiplier in [false, true]) { selectSignMultiplicand.put(selectMultiplicand ? 1 : 0); selectSignMultiplier.put(selectMultiplier ? 1 : 0); - final PartialProductGenerator pp; + final PartialProductGeneratorBase pp; pp = PartialProductGeneratorStopBitsSignExtension( Logic(name: 'X', width: width), Logic(name: 'Y', width: width), @@ -286,9 +290,13 @@ void main() { for (var width = shift; width < min(5, 2 * shift); width++) { for (final signExtension in SignExtension.values.where((e) => e != SignExtension.none)) { - final ppg = curryPartialProductGenerator(signExtension); - final pp = ppg(Logic(name: 'X', width: width), - Logic(name: 'Y', width: width), encoder); + // final ppg = curryPartialProductGenerator(signExtension); + final pp = PartialProductGeneratorBasic( + Logic(name: 'X', width: width), + Logic(name: 'Y', width: width), + encoder); + currySignExtensionFunction(signExtension)(pp).signExtend(); + testPartialProductExhaustive(pp); } } @@ -308,11 +316,14 @@ void main() { SignExtension.stopBits, SignExtension.compactRect ]) { - final ppg = curryPartialProductGenerator(signExtension); - final pp = ppg(Logic(name: 'X', width: widthX), - Logic(name: 'Y', width: widthY), encoder, + final pp = PartialProductGeneratorBasic( + Logic(name: 'X', width: widthX), + Logic(name: 'Y', width: widthY), + encoder, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier); + currySignExtensionFunction(signExtension)(pp).signExtend(); + testPartialProductExhaustive(pp); } } @@ -328,9 +339,12 @@ void main() { for (var width = shift; width < min(5, 2 * shift); width++) { for (final signExtension in SignExtension.values.where((e) => e != SignExtension.none)) { - final ppg = curryPartialProductGenerator(signExtension); - final pp = ppg(Logic(name: 'X', width: width), - Logic(name: 'Y', width: width), encoder); + final pp = PartialProductGeneratorBasic( + Logic(name: 'X', width: width), + Logic(name: 'Y', width: width), + encoder); + currySignExtensionFunction(signExtension)(pp).signExtend(); + testPartialProductExhaustive(pp); } } @@ -352,11 +366,14 @@ void main() { SignExtension.stopBits, SignExtension.compactRect ]) { - final ppg = curryPartialProductGenerator(signExtension); - final pp = ppg(Logic(name: 'X', width: widthX), - Logic(name: 'Y', width: widthY), encoder, + final pp = PartialProductGeneratorBasic( + Logic(name: 'X', width: widthX), + Logic(name: 'Y', width: widthY), + encoder, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier); + currySignExtensionFunction(signExtension)(pp).signExtend(); + testPartialProductExhaustive(pp); } } @@ -385,10 +402,9 @@ void main() { final skew = align.$3; - final pp = PartialProductGeneratorCompactRectSignExtension( - Logic(name: 'X', width: width), - Logic(name: 'Y', width: width + skew), - encoder); + final pp = PartialProductGeneratorBasic(Logic(name: 'X', width: width), + Logic(name: 'Y', width: width + skew), encoder); + CompactRectSignExtension(pp).signExtend(); testPartialProductRandom(pp, 10); } @@ -405,16 +421,17 @@ void main() { final multiplicand = Logic(width: widthX); final multiplier = Logic(width: widthY); for (final signed in [false, true]) { - final pp = PartialProductGeneratorCompactSignExtension( + final ppg = PartialProductGeneratorBasic( multiplicand, multiplier, radixEncoder, signedMultiplicand: signed, signedMultiplier: signed); + CompactSignExtension(ppg).signExtend(); for (var i = BigInt.zero; i < BigInt.from(limitX); i += BigInt.one) { for (var j = BigInt.zero; j < BigInt.from(limitY); j += BigInt.one) { final X = signed ? i.toSigned(widthX) : i.toUnsigned(widthX); final Y = signed ? j.toSigned(widthY) : j.toUnsigned(widthY); multiplicand.put(X); multiplier.put(Y); - final value = pp.evaluate(); + final value = ppg.evaluate(); expect(value, equals(X * Y), reason: '$X * $Y = $value should be ${X * Y}'); } @@ -443,9 +460,59 @@ void main() { logicX.put(X); logicY.put(Y); logicZ.put(Z); - final pp = PartialProductGeneratorCompactRectSignExtension( - logicX, logicY, encoder, + final pp = PartialProductGeneratorBasic(logicX, logicY, encoder, + signedMultiplicand: true, signedMultiplier: true); + CompactRectSignExtension(pp).signExtend(); + + final lastLength = + pp.partialProducts[pp.rows - 1].length + pp.rowShift[pp.rows - 1]; + + final sign = logicZ[logicZ.width - 1]; + // for unsigned versus signed testing + // final sign = signed ? logicZ[logicZ.width - 1] : Const(0); + final l = [for (var i = 0; i < logicZ.width; i++) logicZ[i]]; + while (l.length < lastLength) { + l.add(sign); + } + l + ..add(~sign) + ..add(Const(1)); + // print(pp.representation()); + + pp.partialProducts.insert(0, l); + pp.rowShift.insert(0, 0); + // print(pp.representation()); + + if (pp.evaluate() != product) { + stdout.write('Fail: $X * $Y: ${pp.evaluate()} vs expected $product\n'); + } + expect(pp.evaluate(), equals(product)); + }); + + test('single MAC partial product sign extension test', () async { + final encoder = RadixEncoder(16); + const widthX = 8; + const widthY = 18; + + const i = 1478; + const j = 9; + const k = 0; + + final X = BigInt.from(i).toSigned(widthX); + final Y = BigInt.from(j).toSigned(widthY); + final Z = BigInt.from(k).toSigned(widthX + widthY); + // print('X=$X Y=$Y, Z=$Z'); + final product = X * Y + Z; + + final logicX = Logic(name: 'X', width: widthX); + final logicY = Logic(name: 'Y', width: widthY); + final logicZ = Logic(name: 'Z', width: widthX + widthY); + logicX.put(X); + logicY.put(Y); + logicZ.put(Z); + final pp = PartialProductGeneratorBasic(logicX, logicY, encoder, signedMultiplicand: true, signedMultiplier: true); + CompactRectSignExtension(pp).signExtend(); final lastLength = pp.partialProducts[pp.rows - 1].length + pp.rowShift[pp.rows - 1]; diff --git a/test/arithmetic/multiplier_test.dart b/test/arithmetic/multiplier_test.dart index 722e1e38d..ad7ad75e5 100644 --- a/test/arithmetic/multiplier_test.dart +++ b/test/arithmetic/multiplier_test.dart @@ -188,29 +188,19 @@ void main() { MultiplierCallback curryCompressionTreeMultiplier(int radix, ParallelPrefix Function(List, Logic Function(Logic, Logic)) ppTree, - {PPGFunction ppGen = PartialProductGeneratorCompactRectSignExtension.new, + {SignExtensionFunction seGen = CompactRectSignExtension.new, bool signedMultiplicand = false, bool signedMultiplier = false, Logic? selectSignedMultiplicand, Logic? selectSignedMultiplier}) { - String genName(Logic a, Logic b) => ppGen( - a, - b, - RadixEncoder(radix), - signedMultiplicand: signedMultiplicand, - signedMultiplier: signedMultiplier, - selectSignedMultiplicand: - selectSignedMultiplicand != null ? Logic() : null, - selectSignedMultiplier: - selectSignedMultiplier != null ? Logic() : null, - ).name; + String genName(Logic a, Logic b) => + seGen(PartialProductGeneratorBasic(a, b, RadixEncoder(radix))).name; final signage = ' SD=${signedMultiplicand ? 1 : 0}' ' SM=${signedMultiplier ? 1 : 0}' ' SelD=${(selectSignedMultiplicand != null) ? 1 : 0}' ' SelM=${(selectSignedMultiplier != null) ? 1 : 0}'; return (a, b, {selectSignedMultiplicand, selectSignedMultiplier}) => CompressionTreeMultiplier(a, b, radix, - ppGen: ppGen, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, selectSignedMultiplicand: selectSignedMultiplicand, @@ -223,8 +213,7 @@ void main() { MultiplyAccumulateCallback curryMultiplierAsMultiplyAccumulate(int radix, {ParallelPrefix Function(List, Logic Function(Logic, Logic)) ppTree = KoggeStone.new, - PPGFunction ppGen = - PartialProductGeneratorCompactRectSignExtension.new, + SignExtensionFunction seGen = CompactRectSignExtension.new, bool signedMultiplicand = false, bool signedMultiplier = false, Logic? selectSignedMultiplicand, @@ -240,7 +229,7 @@ void main() { curryCompressionTreeMultiplier( radix, ppTree, - ppGen: ppGen, + seGen: seGen, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, selectSignedMultiplicand: selectSignedMultiplicand, @@ -251,7 +240,7 @@ void main() { int radix, { ParallelPrefix Function(List, Logic Function(Logic, Logic)) ppTree = KoggeStone.new, - PPGFunction ppGen = PartialProductGeneratorCompactRectSignExtension.new, + SignExtensionFunction seGen = CompactRectSignExtension.new, bool signedMultiplicand = false, bool signedMultiplier = false, bool signedAddend = false, @@ -259,24 +248,14 @@ void main() { Logic? selectSignedMultiplier, Logic? selectSignedAddend, }) { - String genName(Logic a, Logic b) => ppGen( - a, - b, - RadixEncoder(radix), - signedMultiplicand: signedMultiplicand, - signedMultiplier: signedMultiplier, - selectSignedMultiplicand: - selectSignedMultiplicand != null ? Logic() : null, - selectSignedMultiplier: - selectSignedMultiplier != null ? Logic() : null, - ).name; + String genName(Logic a, Logic b) => + seGen(PartialProductGeneratorBasic(a, b, RadixEncoder(radix))).name; final signage = ' SD=${signedMultiplicand ? 1 : 0}' ' SM=${signedMultiplier ? 1 : 0}' ' SelD=${(selectSignedMultiplicand != null) ? 1 : 0}' ' SelM=${(selectSignedMultiplier != null) ? 1 : 0}'; return (a, b, c) => CompressionTreeMultiplyAccumulate(a, b, c, radix, - ppGen: ppGen, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, signedAddend: signedAddend, @@ -305,9 +284,9 @@ void main() { for (final width in [3, 4]) { for (final signExtension in SignExtension.values.where((e) => e != SignExtension.none)) { - final ppg = curryPartialProductGenerator(signExtension); + final seg = currySignExtensionFunction(signExtension); testMultiplyAccumulateRandom(width, 10, - curryMultiplierAsMultiplyAccumulate(radix, ppGen: ppg)); + curryMultiplierAsMultiplyAccumulate(radix, seGen: seg)); } } } From 519f987fdcca53b21ba2c95a4f8f57c9a13b3f96 Mon Sep 17 00:00:00 2001 From: Desmond Kirkpatrick Date: Fri, 24 Jan 2025 14:32:39 -0800 Subject: [PATCH 3/3] Naming fp (#156) --- confapp/pubspec.yaml | 6 +- doc/components/floating_point.md | 3 +- doc/components/multiplier_components.md | 24 +- lib/src/arithmetic/addend_compressor.dart | 32 +- lib/src/arithmetic/adder.dart | 16 +- lib/src/arithmetic/fixed_to_float.dart | 34 +- .../floating_point/floating_point_adder.dart | 55 ++- .../floating_point_adder_round.dart | 319 ++++++++++-------- .../floating_point_adder_simple.dart | 107 +++--- .../floating_point_multiplier.dart | 40 ++- .../floating_point_multiplier_simple.dart | 162 +++++---- lib/src/arithmetic/multiplicand_selector.dart | 41 ++- lib/src/arithmetic/multiplier.dart | 157 ++++++--- lib/src/arithmetic/multiplier_encoder.dart | 69 ++-- lib/src/arithmetic/ones_complement_adder.dart | 26 +- .../parallel_prefix_operations.dart | 50 ++- .../arithmetic/partial_product_generator.dart | 22 +- .../partial_product_sign_extend.dart | 121 +++++-- .../arithmetic/signals/fixed_point_logic.dart | 3 +- .../signals/floating_point_logic.dart | 94 ++++-- .../floating_point_32_value.dart | 19 ++ .../floating_point_bf16_value.dart | 64 ++++ .../floating_point_value.dart | 28 ++ ...nfig_floating_point_multiplier_simple.dart | 25 +- lib/src/encodings/tree_one_hot_to_binary.dart | 6 +- lib/src/utils.dart | 9 +- pubspec.yaml | 4 +- test/arithmetic/addend_compressor_test.dart | 4 +- .../floating_point_adder_simple_test.dart | 59 +--- .../floating_point_multiplier_test.dart | 315 ++++++++++++++++- .../floating_point_value_test.dart | 13 + test/arithmetic/multiplier_encoder_test.dart | 46 ++- test/arithmetic/multiplier_test.dart | 130 +++++-- 33 files changed, 1499 insertions(+), 604 deletions(-) diff --git a/confapp/pubspec.yaml b/confapp/pubspec.yaml index 820ff92f2..48c57ca38 100644 --- a/confapp/pubspec.yaml +++ b/confapp/pubspec.yaml @@ -10,7 +10,7 @@ environment: dependencies: flutter: sdk: flutter - rohd: ^0.6.0 + rohd: ^0.6.1 rohd_hcl: git: url: https://github.com/intel/rohd-hcl @@ -21,10 +21,6 @@ dependencies: bloc: ^8.1.2 google_fonts: 6.1.0 -dependency_overrides: - rohd_hcl: - path: ../ - dev_dependencies: flutter_test: sdk: flutter diff --git a/doc/components/floating_point.md b/doc/components/floating_point.md index 3a97e5f75..751ceff3b 100644 --- a/doc/components/floating_point.md +++ b/doc/components/floating_point.md @@ -74,5 +74,6 @@ A very basic [FloatingPointMultiplierSimple] component is available which does n It has options to control its performance: - 'radix': used to specify the radix of the Booth encoder (default radix=4: options are [2,4,8,16])'. -- adderGen': used to specify the kind of [Adder] used for key functions like the mantissa addition. Defaults to [NativeAdder], but you can select a [ParallelPrefixAdder] of your choice. + +- 'adderGen': used to specify the kind of [Adder] used for key functions like the mantiss addition. Defaults to [NativeAdder], but you can select a [ParallelPrefixAdder] of your choice. - 'ppTree': used to specify the type of ['ParallelPrefix'](https://intel.github.io/rohd-hcl/rohd_hcl/ParallelPrefix-class.html) used in the pther critical functions like leading-one detect. diff --git a/doc/components/multiplier_components.md b/doc/components/multiplier_components.md index ed5f2c606..0d8834a5a 100644 --- a/doc/components/multiplier_components.md +++ b/doc/components/multiplier_components.md @@ -51,7 +51,7 @@ row slice mult A few things to note: first, that we are negating by ones' complement (so we need a -0) and second, these rows do not add up to (18: 10010). For Booth encoded rows to add up properly, they need to be in twos' complement form, and they need to be sign-extended. - Here is the matrix with a crude sign extension `brute` (the table formatting is available from our [PartialProductGenerator](https://intel.github.io/rohd-hcl/rohd_hcl/PartialProductGenerator-class.html) component). With twos' complementation, and sign bits folded in (note the LSB of each row has a sign term from the previous row), these addends are correctly formed and add to (18: 10010). + Here is the matrix with a crude sign extension `brute` (the table formatting is available from our [Partial Product Generator](https://intel.github.io/rohd-hcl/rohd_hcl/PartialProductGeneratorBase-class.html) component followed by [Brute Force Sign Extension]). With twos' complementation, and sign bits folded in (note the LSB of each row has a sign term from the previous row), these addends are correctly formed and add to (18: 10010). ```text 7 6 5 4 3 2 1 0 @@ -90,7 +90,7 @@ Note that radix-4 shifts by 2 positions each row, but with only two rows and wit ## Partial Product Generator -The base class of `PartialProductGenerator` is [PartialProductArray](https://intel.github.io/rohd-hcl/rohd_hcl/PartialProductArray-class.html) which is simply a `List>` to represent addends and a `rowShift[row]` to represent the shifts in the partial product matrix. If customization is needed beyond sign extension options, routines are provided that allow for fixed customization of bit positions or conditional (mux based on a Logic) form in the `PartialProductArray`. +The base class of `PartialProductGenerator` is [PartialProductArray](https://intel.github.io/rohd-hcl/rohd_hcl/PartialProductArray-class.html) which is simply a `List>` to represent addends and a `rowShift[row]` to represent the shifts in the partial product matrix. If customization is needed beyond sign extension options, routines are provided that allow for fixed customization of bit positions or conditional (mux based on a Logic) form in the `PartialProductArray`. ```dart final ppa = PartialProductArray(a,b); @@ -100,9 +100,9 @@ ppa.muxAbsolute(row, col, condition, logic); ppa.muxAbsoluteAll(row, col, condition, List); ``` - The `PartialProductGenerator` adds to this the [RadixEncoder](https://intel.github.io/rohd-hcl/rohd_hcl/RadixEncoder-class.html) to encode the rows along with a matching `MultiplicandSelector` to create the actual mantissas used in each row. + The `PartialProductGenerator` adds to this the [RadixEncoder](https://intel.github.io/rohd-hcl/rohd_hcl/RadixEncoder-class.html) to encode the rows along with a matching `MultiplicandSelector` to create the actual mantissas used in each row. -As a building block which creates a set of rows of partial products from a multiplicand and a multiplier, it maintains the partial products as a list of rows om the `PartialProductArray` base. Its primary inputs are the multiplicand, multiplier, `RadixEncoder`, and whether the operands are signed. +As a building block which creates a set of rows of partial products from a multiplicand and a multiplier, it maintains the partial products as a list of rows on the `PartialProductArray` base. Its primary inputs are the multiplicand, multiplier, `RadixEncoder`, and whether the operands are signed. The partial product generator produces a set of addends in shifted position to be added. The main output of the component is @@ -113,7 +113,7 @@ The partial product generator produces a set of addends in shifted position to b ### Radix Encoding -An argument to the `PartialProductGenerator` is the `RadixEncoder` to be used. The [`RadixEncoder`] takes a single argument which is the radix (power of 2) to be used. +An argument to the `PartialProductGenerator` is the `RadixEncoder` to be used. The [RadixEncoder](https://intel.github.io/rohd-hcl/rohd_hcl/RadixEncoder-class.html) takes a single argument which is the radix (power of 2) to be used. Instead of using the 1's in the multiplier to select shifted versions of the multiplicand to add in a partial product matrix, radix-encoding will encode multiples of the multiplicand by examining adjacent bits of the multiplier. For radix-4, for example, for a multiplier of size M, instead of M rows of partial products, M/2 rows are formed by selecting from multiples [-2, -1, 0, 1, 2] of the multiplicand. These multiples are computed from an 3 bit slices, overlapped by 1 bit, of the multiplier. Higher radixes use wider slices of the multiplier to encode fewer multiples and therefore fewer rows. @@ -136,11 +136,11 @@ Our `RadixEncoder` module is general, creating selection tables for arbitrary Bo The `PartialProductSignExtension` defines the API for doing different kinds of sign extension on the `PartialProductArray`, from very simplistic for helping design new arithmetics to fairly standard to even compact, rectangular forms. -- None: no sign extension. -- Brute: full width extension which is robust but costly. -- StopBit: A standard form which has the inverse-sign and a '1' stop bit in each row -- Compact: A form that eliminates a final sign in an otherwise empty final row. -- CompactRect: An enhanced form of compact that can handle rectangular multiplications. +- `None`: no sign extension. +- `Brute`: full width extension which is robust but costly. +- `StopBit`: A standard form which has the inverse-sign and a '1' stop-bit in each row +- `Compact`: A form that eliminates a final sign in an otherwise empty final row. +- `CompactRect`: An enhanced form of compact that can handle rectangular multiplications. ### Partial Product Visualization @@ -175,7 +175,7 @@ You can also generate a Markdown form of the same matrix: Once you have a partial product matrix, you would like to add up the addends. Traditionally this is done using compression trees which instantiate 2:1 and 3:2 column compressors (or carry-save adders) to reduce the matrix to two addends. The final two addends are often added with an efficient final adder. -Our [ColumnCompressor](https://intel.github.io/rohd-hcl/rohd_hcl/ColumnCompressor-class.html) class uses a delay-driven approach to efficiently compress the rows of the partial product matrix. Its only argument is a `PartialProductArray` (base class of `PartialProductGenerator`), and it creates a list of `ColumnQueue`s containing the final two addends stored by column after compression. An `extractRow`routine can be used to extract the columns. `ColumnCompressor` currently has an extension `EvaluateColumnCompressor` which can be used to print out the compression progress. Here is the legend for these printouts. +Our [ColumnCompressor](https://intel.github.io/rohd-hcl/rohd_hcl/ColumnCompressor-class.html) class uses a delay-driven approach to efficiently compress the rows of the partial product matrix. Its only argument is a `PartialProductArray` (base class of `PartialProductGenerator`), and it creates a list of `ColumnQueue`s containing the final two addends stored by column after compression. An `extractRow`routine can be used to extract the columns. `ColumnCompressor` currently has an extension `EvaluateColumnCompressor` which can be used to print out the compression progress. Here is the legend for these printouts. - `ppR,C` = partial product entry at row R, column C - `sR,C` = sum term coming last from row R, column C @@ -221,7 +221,7 @@ Any adder can be used as the final adder of the final two addends produced from Here is a code snippet that shows how these components can be used to create a multiplier. -First the partial product generator is used (`PartialProductGenerator`), which we pass in the `RadixEncoder`, whether the operands are signed. We operate on this generator with a compact sign extension class for rectangular products (`CompactRectSignExtension`). Note that sign extension is needed regardless of whether operands are signed or not due to Booth encoding. +First the partial product generator is used (`PartialProductGenerator`), which we pass in the `RadixEncoder`, whether the operands are signed. We operate on this generator with a compact sign extension class for rectangular products (`CompactRectSignExtension`). Note that sign extension is needed regardless of whether operands are signed or not due to Booth encoding. Next, we use the `ColumnCompressor` to compress the partial products into two final addends. diff --git a/lib/src/arithmetic/addend_compressor.dart b/lib/src/arithmetic/addend_compressor.dart index 8c4d30c42..5cca11a35 100644 --- a/lib/src/arithmetic/addend_compressor.dart +++ b/lib/src/arithmetic/addend_compressor.dart @@ -10,7 +10,7 @@ import 'package:collection/collection.dart'; import 'package:meta/meta.dart'; import 'package:rohd/rohd.dart'; -import 'package:rohd_hcl/src/arithmetic/multiplier_lib.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; /// Base class for bit-level column compressor function abstract class BitCompressor extends Module { @@ -25,7 +25,7 @@ abstract class BitCompressor extends Module { Logic get carry => output('carry'); /// Construct a column compressor - BitCompressor(Logic compressBits) { + BitCompressor(Logic compressBits, {super.name = 'bit_compressor'}) { this.compressBits = addInput( 'compressBits', compressBits, @@ -39,7 +39,7 @@ abstract class BitCompressor extends Module { /// 2-input column compressor (half-adder) class Compressor2 extends BitCompressor { /// Construct a 2-input compressor (half-adder) - Compressor2(super.compressBits) { + Compressor2(super.compressBits, {super.name = 'compressor_2'}) { sum <= compressBits.xor(); carry <= compressBits.and(); } @@ -48,7 +48,7 @@ class Compressor2 extends BitCompressor { /// 3-input column compressor (full-adder) class Compressor3 extends BitCompressor { /// Construct a 3-input column compressor (full-adder) - Compressor3(super.compressBits) { + Compressor3(super.compressBits, {super.name = 'compressor_3'}) { sum <= compressBits.xor(); carry <= mux(compressBits[0], compressBits.slice(2, 1).or(), @@ -237,19 +237,31 @@ class ColumnCompressor { BitCompressor compressor; if (depth > 3) { inputs.add(queue.removeFirst()); - compressor = - Compressor3([for (final i in inputs) i.logic].swizzle()); + compressor = Compressor3( + [for (final i in inputs) i.logic].swizzle(), + name: 'cmp3_iter${iteration}_col$col'); } else { - compressor = - Compressor2([for (final i in inputs) i.logic].swizzle()); + compressor = Compressor2( + [for (final i in inputs) i.logic].swizzle(), + name: 'cmp2_iter${iteration}_col$col'); } final t = CompressTerm( - CompressTermType.sum, compressor.sum, inputs, 0, col); + CompressTermType.sum, + compressor.sum.named('cmp_sum_iter${iteration}_c$col', + naming: Naming.mergeable), + inputs, + 0, + col); terms.add(t); columns[col].add(t); if (col < columns.length - 1) { final t = CompressTerm( - CompressTermType.carry, compressor.carry, inputs, 0, col); + CompressTermType.carry, + compressor.carry.named('cmp_carry_iter${iteration}_c$col', + naming: Naming.mergeable), + inputs, + 0, + col); columns[col + 1].add(t); terms.add(t); } diff --git a/lib/src/arithmetic/adder.dart b/lib/src/arithmetic/adder.dart index ca80e8ddc..c1cf28881 100644 --- a/lib/src/arithmetic/adder.dart +++ b/lib/src/arithmetic/adder.dart @@ -76,13 +76,21 @@ class NativeAdder extends Adder { if (a.width != b.width) { throw RohdHclException('inputs of a and b should have same width.'); } + final aExtended = + a.zeroExtend(a.width + 1).named('aExtended', naming: Naming.mergeable); + final bExtended = + b.zeroExtend(a.width + 1).named('bExtended', naming: Naming.mergeable); + final aPlusb = (aExtended + bExtended) + .named('aExtended_plus_bExtended', naming: Naming.mergeable); if (carryIn == null) { - sum <= a.zeroExtend(a.width + 1) + b.zeroExtend(b.width + 1); + sum <= aPlusb; } else { + final cinExtendend = carryIn! + .zeroExtend(a.width + 1) + .named('carryInExtended', naming: Naming.mergeable); sum <= - a.zeroExtend(a.width + 1) + - b.zeroExtend(b.width + 1) + - carryIn!.zeroExtend(a.width + 1); + (aPlusb + cinExtendend) + .named('sumWithCarryIn', naming: Naming.mergeable); } } } diff --git a/lib/src/arithmetic/fixed_to_float.dart b/lib/src/arithmetic/fixed_to_float.dart index 19682749c..16b27020d 100644 --- a/lib/src/arithmetic/fixed_to_float.dart +++ b/lib/src/arithmetic/fixed_to_float.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // fixed_to_float.dart @@ -57,8 +57,10 @@ class FixedToFloat extends Module { final absValue = Logic(name: 'absValue', width: fixed.width) ..gets(mux(_float.sign, ~(fixed - 1), fixed)); - final jBit = - ParallelPrefixPriorityEncoder(absValue.reversed).out.zeroExtend(iWidth); + final jBit = ParallelPrefixPriorityEncoder(absValue.reversed) + .out + .zeroExtend(iWidth) + .named('jBit'); // Extract mantissa final mantissa = Logic(name: 'mantissa', width: mantissaWidth); @@ -75,8 +77,8 @@ class FixedToFloat extends Module { } // Align mantissa - final absValueShifted = - Logic(width: max(absValue.width, mantissaWidth + 2)); + final absValueShifted = Logic( + width: max(absValue.width, mantissaWidth + 2), name: 'absValueShifted'); if (absValue.width < mantissaWidth + 2) { final zeros = Const(0, width: mantissaWidth + 2 - absValue.width); absValueShifted <= [absValue, zeros].swizzle() << j; @@ -89,19 +91,25 @@ class FixedToFloat extends Module { sticky <= absValueShifted.getRange(0, -mantissaWidth - 2).or(); /// Round to nearest even: mantissa | guard sticky - final roundUp = guard & (sticky | mantissa[0]); - final mantissaRounded = mux(roundUp, mantissa + 1, mantissa); + final roundUp = (guard & (sticky | mantissa[0])).named('roundUp'); + final mantissaRounded = + mux(roundUp, mantissa + 1, mantissa).named('roundedMantissa'); // Calculate biased exponent final eRaw = mux( - absValueShifted[-1], - Const(bias + fixed.width - fixed.n - 1, width: iWidth) - j, - Const(0, width: iWidth)); - final eRawRne = mux(roundUp & ~mantissaRounded.or(), eRaw + 1, eRaw); + absValueShifted[-1], + (Const(bias + fixed.width - fixed.n - 1, width: iWidth) - j) + .named('eShift'), + Const(0, width: iWidth)) + .named('eRaw'); + final eRawRne = + mux(roundUp & ~mantissaRounded.or(), eRaw + 1, eRaw).named('eRawRNE'); // Select output handling corner cases - final expoLessThanOne = eRawRne[-1] | ~eRawRne.or(); - final expoMoreThanMax = ~eRawRne[-1] & (eRawRne.gt(eMax)); + final expoLessThanOne = + (eRawRne[-1] | ~eRawRne.or()).named('expLessThanOne'); + final expoMoreThanMax = + (~eRawRne[-1] & (eRawRne.gt(eMax))).named('expMoreThanMax'); Combinational([ If.block([ Iff(~absValue.or(), [ diff --git a/lib/src/arithmetic/floating_point/floating_point_adder.dart b/lib/src/arithmetic/floating_point/floating_point_adder.dart index be070d533..6176732ec 100644 --- a/lib/src/arithmetic/floating_point/floating_point_adder.dart +++ b/lib/src/arithmetic/floating_point/floating_point_adder.dart @@ -11,6 +11,8 @@ import 'package:meta/meta.dart'; import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; +// TODO(desmonddak): add variable width output as we did with fpmultiply + /// An abstract API for floating point adders. abstract class FloatingPointAdder extends Module { /// Width of the output exponent field. @@ -43,9 +45,10 @@ abstract class FloatingPointAdder extends Module { late final FloatingPoint b; /// getter for the computed [FloatingPoint] output. - late final FloatingPoint sum = - FloatingPoint(exponentWidth: exponentWidth, mantissaWidth: mantissaWidth) - ..gets(output('sum')); + + late final FloatingPoint sum = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth, name: 'sum') + ..gets(output('sum')); /// Add two floating point numbers [a] and [b], returning result in [sum]. /// - [clk], [reset], [enable] are optional inputs to control a pipestage @@ -62,23 +65,49 @@ abstract class FloatingPointAdder extends Module { b.mantissa.width != mantissaWidth) { throw RohdHclException('FloatingPoint widths must match'); } - this.clk = (clk != null) ? addInput('clk', clk) : clk; - this.enable = (enable != null) ? addInput('enable', enable) : enable; - this.reset = (reset != null) ? addInput('clk', reset) : reset; + this.clk = (clk != null) ? addInput('clk', clk) : null; + this.reset = (reset != null) ? addInput('reset', reset) : null; + this.enable = (enable != null) ? addInput('enable', enable) : null; + this.a = a.clone(name: 'a')..gets(addInput('a', a, width: a.width)); + this.b = b.clone(name: 'b')..gets(addInput('b', b, width: b.width)); - this.a = a.clone()..gets(addInput('a', a, width: a.width)); - this.b = b.clone()..gets(addInput('b', b, width: b.width)); addOutput('sum', width: exponentWidth + mantissaWidth + 1); } /// Swapping two FloatingPoint structures based on a conditional @protected (FloatingPoint, FloatingPoint) swap( - Logic swap, (FloatingPoint, FloatingPoint) toSwap) => - ( - toSwap.$1.clone()..gets(mux(swap, toSwap.$2, toSwap.$1)), - toSwap.$2.clone()..gets(mux(swap, toSwap.$1, toSwap.$2)) - ); + Logic swap, (FloatingPoint, FloatingPoint) toSwap) { + final in1 = toSwap.$1.named('swapIn_${toSwap.$1.name}'); + final in2 = toSwap.$2.named('swapIn_${toSwap.$2.name}'); + + final out1 = mux(swap, in2, in1).named('swapOut_larger'); + final out2 = mux(swap, in1, in2).named('swapOut_smaller'); + final first = a.clone(name: 'larger')..gets(out1); + final second = a.clone(name: 'smaller')..gets(out2); + return (first, second); + } + + /// Sort two FloatingPointNumbers and swap + @protected + (FloatingPoint larger, FloatingPoint smaller) sortFp( + (FloatingPoint, FloatingPoint) toSort) { + final ae = toSort.$1.exponent; + final be = toSort.$2.exponent; + final am = toSort.$1.mantissa; + final bm = toSort.$2.mantissa; + final doSwap = (ae.lt(be) | + (ae.eq(be) & am.lt(bm)) | + ((ae.eq(be) & am.eq(bm)) & toSort.$1.sign)) + .named('doSwap'); + + final swapped = swap(doSwap, toSort); + + final larger = swapped.$1.clone(name: 'larger')..gets(swapped.$1); + final smaller = swapped.$2.clone(name: 'smaller')..gets(swapped.$2); + + return (larger, smaller); + } /// Pipelining helper that uses the context for signals clk/enable/reset Logic localFlop(Logic input) => diff --git a/lib/src/arithmetic/floating_point/floating_point_adder_round.dart b/lib/src/arithmetic/floating_point/floating_point_adder_round.dart index 00de61b18..5998cc1ad 100644 --- a/lib/src/arithmetic/floating_point/floating_point_adder_round.dart +++ b/lib/src/arithmetic/floating_point/floating_point_adder_round.dart @@ -11,6 +11,11 @@ import 'dart:math'; import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; +// TODO(desmonddak): factor rounding into a utility by merging +// the near and far bits and creating a LGRS algorithm on that word. + +// TODO(desmondak): investigate how to implement other forms of rounding. + /// An adder module for variable FloatingPoint type with rounding. // This is a Seidel/Even adder, dual-path implementation. class FloatingPointAdderRound extends FloatingPointAdder { @@ -35,53 +40,46 @@ class FloatingPointAdderRound extends FloatingPointAdder { exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); output('sum') <= outputSum; - // Ensure that the larger number is wired as 'a' - final ae = this.a.exponent; - final be = this.b.exponent; - final am = this.a.mantissa; - final bm = this.b.mantissa; - final doSwap = ae.lt(be) | - (ae.eq(be) & am.lt(bm)) | - ((ae.eq(be) & am.eq(bm)) & this.a.sign); - - final FloatingPoint a; - final FloatingPoint b; - (a, b) = swap(doSwap, (super.a, super.b)); - // Seidel: S.EFF = effectiveSubtraction - final effectiveSubtraction = a.sign ^ b.sign ^ (subtract ?? Const(0)); - final isNaN = a.isNaN | - b.isNaN | - (a.isInfinity & b.isInfinity & effectiveSubtraction); - final isInf = a.isInfinity | b.isInfinity; - - final exponentSubtractor = OnesComplementAdder(a.exponent, b.exponent, + final effectiveSubtraction = + (a.sign ^ b.sign ^ (subtract ?? Const(0))).named('effSubtraction'); + final isNaN = (a.isNaN | + b.isNaN | + (a.isInfinity & b.isInfinity & effectiveSubtraction)) + .named('isNaN'); + final isInf = (a.isInfinity | b.isInfinity).named('isInf'); + + final exponentSubtractor = OnesComplementAdder( + super.a.exponent, super.b.exponent, subtract: true, adderGen: adderGen, name: 'exponent_sub'); - final signDelta = exponentSubtractor.sign; + final signDelta = exponentSubtractor.sign.named('signDelta'); - final delta = exponentSubtractor.sum; + final delta = exponentSubtractor.sum.named('expDelta'); // Seidel: (sl, el, fl) = larger; (ss, es, fs) = smaller final (larger, smaller) = swap(signDelta, (a, b)); final fl = mux( - larger.isNormal, - [larger.isNormal, larger.mantissa].swizzle(), - [larger.mantissa, Const(0)].swizzle()); + larger.isNormal, + [larger.isNormal, larger.mantissa].swizzle(), + [larger.mantissa, Const(0)].swizzle(), + ).named('fullLarger'); final fs = mux( - smaller.isNormal, - [smaller.isNormal, smaller.mantissa].swizzle(), - [smaller.mantissa, Const(0)].swizzle()); + smaller.isNormal, + [smaller.isNormal, smaller.mantissa].swizzle(), + [smaller.mantissa, Const(0)].swizzle(), + ).named('fullSmaller'); // Seidel: flp larger preshift, normally in [2,4) final sigWidth = fl.width + 1; final largeShift = mux(effectiveSubtraction, fl.zeroExtend(sigWidth) << 1, - fl.zeroExtend(sigWidth)); + fl.zeroExtend(sigWidth)) + .named('largeShift'); final smallShift = mux(effectiveSubtraction, fs.zeroExtend(sigWidth) << 1, - fs.zeroExtend(sigWidth)); - - final zeroExp = Const(0, width: exponentWidth); + fs.zeroExtend(sigWidth)) + .named('smallShift'); + final zeroExp = a.zeroExponent; final largeOperand = largeShift; // // R Datapath: Far exponents or addition @@ -89,116 +87,151 @@ class FloatingPointAdderRound extends FloatingPointAdder { final extendWidthRPath = min(mantissaWidth + 3, pow(2, exponentWidth).toInt() - 3); - final smallerFullRPath = - [smallShift, Const(0, width: extendWidthRPath)].swizzle(); + final smallerFullRPath = [smallShift, Const(0, width: extendWidthRPath)] + .swizzle() + .named('smallerFullRpath'); - final smallerAlignRPath = smallerFullRPath >>> exponentSubtractor.sum; - final smallerOperandRPath = smallerAlignRPath.slice( - smallerAlignRPath.width - 1, - smallerAlignRPath.width - largeOperand.width); + final smallerAlignRPath = (smallerFullRPath >>> exponentSubtractor.sum) + .named('smallerAlignedRpath'); + final smallerOperandRPath = smallerAlignRPath + .slice(smallerAlignRPath.width - 1, + smallerAlignRPath.width - largeOperand.width) + .named('smallerOperandRpath'); /// R Pipestage here: - final aIsNormalLatched = localFlop(a.isNormal); - final bIsNormalLatched = localFlop(b.isNormal); - final effectiveSubtractionLatched = localFlop(effectiveSubtraction); - final largeOperandLatched = localFlop(largeOperand); - final smallerOperandRPathLatched = localFlop(smallerOperandRPath); - final smallerAlignRPathLatched = localFlop(smallerAlignRPath); - final largerExpLatched = localFlop(larger.exponent); - final deltaLatched = localFlop(delta); - final isInfLatched = localFlop(isInf); - final isNaNLatched = localFlop(isNaN); - - final carryRPath = Logic(); + final aIsNormalFlopped = localFlop(a.isNormal); + final bIsNormalFlopped = localFlop(b.isNormal); + final effectiveSubtractionFlopped = localFlop(effectiveSubtraction); + final largeOperandFlopped = localFlop(largeOperand); + final smallerOperandRPathFlopped = localFlop(smallerOperandRPath); + final smallerAlignRPathFlopped = localFlop(smallerAlignRPath); + final largerExpFlopped = localFlop(larger.exponent); + final deltaFlopped = localFlop(delta); + final isInfFlopped = localFlop(isInf); + final isNaNFlopped = localFlop(isNaN); + + final carryRPath = Logic(name: 'carryRpath'); final significandAdderRPath = OnesComplementAdder( - largeOperandLatched, smallerOperandRPathLatched, - subtractIn: effectiveSubtractionLatched, + largeOperandFlopped, smallerOperandRPathFlopped, + subtractIn: effectiveSubtractionFlopped, carryOut: carryRPath, adderGen: adderGen, name: 'rpath_significand_adder'); - final lowBitsRPath = - smallerAlignRPathLatched.slice(extendWidthRPath - 1, 0); - final lowAdderRPath = OnesComplementAdder( - carryRPath.zeroExtend(extendWidthRPath), - mux(effectiveSubtractionLatched, ~lowBitsRPath, lowBitsRPath), - adderGen: adderGen, - name: 'rpath_lowadder'); - - final preStickyRPath = - lowAdderRPath.sum.slice(lowAdderRPath.sum.width - 4, 0).or(); - final stickyBitRPath = lowAdderRPath.sum[-3] | preStickyRPath; + final lowBitsRPath = smallerAlignRPathFlopped + .slice(extendWidthRPath - 1, 0) + .named('lowbitsRpath'); + + final lowAdderRPathSum = OnesComplementAdder( + carryRPath.zeroExtend(extendWidthRPath), + mux(effectiveSubtractionFlopped, ~lowBitsRPath, lowBitsRPath), + adderGen: adderGen, + name: 'rpath_lowadder') + .sum + .named('lowAdderSumRpath'); + + final preStickyRPath = lowAdderRPathSum + .slice(lowAdderRPathSum.width - 4, 0) + .or() + .named('preStickyRpath'); + final stickyBitRPath = + (lowAdderRPathSum[-3] | preStickyRPath).named('stickyBitRpath'); final earlyGRSRPath = [ - lowAdderRPath.sum - .slice(lowAdderRPath.sum.width - 2, lowAdderRPath.sum.width - 3), + lowAdderRPathSum.slice( + lowAdderRPathSum.width - 2, lowAdderRPathSum.width - 3), preStickyRPath - ].swizzle(); + ].swizzle().named('earlyGRSRpath'); - final sumRPath = significandAdderRPath.sum.slice(mantissaWidth + 1, 0); - final sumP1RPath = - (significandAdderRPath.sum + 1).slice(mantissaWidth + 1, 0); + final sumRPath = + significandAdderRPath.sum.slice(mantissaWidth + 1, 0).named('sumRpath'); + // TODO(desmonddak): we should use a compound adder here + final sumP1RPath = (significandAdderRPath.sum + 1) + .named('sumPlusOneRpath') + .slice(mantissaWidth + 1, 0); final sumLeadZeroRPath = - ~sumRPath[-1] & (aIsNormalLatched | bIsNormalLatched); + (~sumRPath[-1] & (aIsNormalFlopped | bIsNormalFlopped)) + .named('sumlead0Rpath'); final sumP1LeadZeroRPath = - ~sumP1RPath[-1] & (aIsNormalLatched | bIsNormalLatched); + (~sumP1RPath[-1] & (aIsNormalFlopped | bIsNormalFlopped)) + .named('sumP1lead0Rpath'); - final selectRPath = lowAdderRPath.sum[-1]; - final shiftGRSRPath = [earlyGRSRPath[2], stickyBitRPath].swizzle(); + final selectRPath = lowAdderRPathSum[-1].named('selectRpath'); + final shiftGRSRPath = + [earlyGRSRPath[2], stickyBitRPath].swizzle().named('shiftGRSRpath'); final mergedSumRPath = mux( - sumLeadZeroRPath, - [sumRPath, earlyGRSRPath].swizzle().slice(sumRPath.width + 1, 0), - [sumRPath, shiftGRSRPath].swizzle()); + sumLeadZeroRPath, + [sumRPath, earlyGRSRPath] + .swizzle() + .named('sumEarlyGRSRpath') + .slice(sumRPath.width + 1, 0), + [sumRPath, shiftGRSRPath].swizzle()) + .named('mergedSumRpath'); final mergedSumP1RPath = mux( - sumP1LeadZeroRPath, - [sumP1RPath, earlyGRSRPath].swizzle().slice(sumRPath.width + 1, 0), - [sumP1RPath, shiftGRSRPath].swizzle()); - - final finalSumLGRSRPath = - mux(selectRPath, mergedSumP1RPath, mergedSumRPath); + sumP1LeadZeroRPath, + [sumP1RPath, earlyGRSRPath] + .swizzle() + .named('sump1EarlyGRSRPath') + .slice(sumRPath.width + 1, 0), + [sumP1RPath, shiftGRSRPath].swizzle()) + .named('mergedSumP1RPath'); + + final finalSumLGRSRPath = mux(selectRPath, mergedSumP1RPath, mergedSumRPath) + .named('finalSumLGRSRpath'); // RNE: guard & (lsb | round | sticky) - final rndRPath = finalSumLGRSRPath[2] & - (finalSumLGRSRPath[3] | finalSumLGRSRPath[1] | finalSumLGRSRPath[0]); + final rndRPath = (finalSumLGRSRPath[2] & + (finalSumLGRSRPath[3] | + finalSumLGRSRPath[1] | + finalSumLGRSRPath[0])) + .named('rndRpath'); // Rounding from 1111 to 0000. final incExpRPath = - rndRPath & sumLeadZeroRPath.eq(Const(1)) & sumP1LeadZeroRPath.eq(0); + (rndRPath & sumLeadZeroRPath.eq(Const(1)) & sumP1LeadZeroRPath.eq(0)) + .named('incExpRrpath'); - final firstZeroRPath = mux(selectRPath, ~sumP1RPath[-1], ~sumRPath[-1]); + final firstZeroRPath = mux(selectRPath, ~sumP1RPath[-1], ~sumRPath[-1]) + .named('firstZero_rpath'); - final expDecr = ParallelPrefixDecr(largerExpLatched, - ppGen: ppTree, name: 'exp_decrement'); - final expIncr = ParallelPrefixIncr(largerExpLatched, - ppGen: ppTree, name: 'exp_increment'); + final expDecr = ParallelPrefixDecr(largerExpFlopped, + ppGen: ppTree, name: 'expDecrement'); + final expIncr = ParallelPrefixIncr(largerExpFlopped, + ppGen: ppTree, name: 'expIncrement'); final exponentRPath = Logic(width: exponentWidth); Combinational([ If.block([ // Subtract 1 from exponent - Iff(~incExpRPath & effectiveSubtractionLatched & firstZeroRPath, + Iff(~incExpRPath & effectiveSubtractionFlopped & firstZeroRPath, [exponentRPath < expDecr.out]), // Add 1 to exponent ElseIf( - ~effectiveSubtractionLatched & + ~effectiveSubtractionFlopped & (incExpRPath & firstZeroRPath | ~incExpRPath & ~firstZeroRPath), [exponentRPath < expIncr.out]), // Add 2 to exponent - ElseIf(incExpRPath & effectiveSubtractionLatched & ~firstZeroRPath, - [exponentRPath < largerExpLatched << 1]), - Else([exponentRPath < largerExpLatched]) + ElseIf(incExpRPath & effectiveSubtractionFlopped & ~firstZeroRPath, + [exponentRPath < largerExpFlopped << 1]), + Else([exponentRPath < largerExpFlopped]) ]) ]); - final sumMantissaRPath = mux(selectRPath, sumP1RPath, sumRPath) + - rndRPath.zeroExtend(sumRPath.width); - final mantissaRPath = sumMantissaRPath << - mux(selectRPath, sumP1LeadZeroRPath, sumLeadZeroRPath); + final sumMantissaRPath = + mux(selectRPath, sumP1RPath, sumRPath).named('selectSumMantissa_rpath'); + // TODO(desmonddak): the '+' operator fails to pick up names directly + final sumMantissaRPathRnd = (sumMantissaRPath + + rndRPath.zeroExtend(sumRPath.width).named('rndExtend_rpath')) + .named('sumMantissaRndRpath'); + final mantissaRPath = (sumMantissaRPathRnd << + mux(selectRPath, sumP1LeadZeroRPath, sumLeadZeroRPath)) + .named('mantissaRpath'); // // N Datapath here: close exponents, subtraction // - final smallOperandNPath = smallShift >>> (a.exponent[0] ^ b.exponent[0]); + final smallOperandNPath = + (smallShift >>> (a.exponent[0] ^ b.exponent[0])).named('smallOperand'); final significandSubtractorNPath = OnesComplementAdder( largeOperand, smallOperandNPath, @@ -206,10 +239,11 @@ class FloatingPointAdderRound extends FloatingPointAdder { adderGen: adderGen, name: 'npath_significand_sub'); - final significandNPath = - significandSubtractorNPath.sum.slice(smallOperandNPath.width - 1, 0); + final significandNPath = significandSubtractorNPath.sum + .slice(smallOperandNPath.width - 1, 0) + .named('significandNpath'); - final validLeadOneNPath = Logic(); + final validLeadOneNPath = Logic(name: 'validLead1Npath'); final leadOneNPathPre = ParallelPrefixPriorityEncoder( significandNPath.reversed, ppGen: ppTree, @@ -217,65 +251,76 @@ class FloatingPointAdderRound extends FloatingPointAdder { name: 'npath_leadingOne') .out; // Limit leadOne to exponent range and match widths - final leadOneNPath = (leadOneNPathPre.width > exponentWidth) - ? mux( - leadOneNPathPre - .gte(a.inf().exponent.zeroExtend(leadOneNPathPre.width)), - a.inf().exponent, - leadOneNPathPre.getRange(0, exponentWidth)) - : leadOneNPathPre.zeroExtend(exponentWidth); + final leadOneNPath = ((leadOneNPathPre.width > exponentWidth) + ? mux( + leadOneNPathPre + .gte(a.inf().exponent.zeroExtend(leadOneNPathPre.width)), + a.inf().exponent, + leadOneNPathPre.getRange(0, exponentWidth)) + : leadOneNPathPre.zeroExtend(exponentWidth)) + .named('leadOneNpath'); // N pipestage here: - final significandNPathLatched = localFlop(significandNPath); - final significandSubtractorNPathSignLatched = + final significandNPathFlopped = localFlop(significandNPath); + final significandSubtractorNPathSignFlopped = localFlop(significandSubtractorNPath.sign); - final leadOneNPathLatched = localFlop(leadOneNPath); - final validLeadOneNPathLatched = localFlop(validLeadOneNPath); - final largerSignLatched = localFlop(larger.sign); - final smallerSignLatched = localFlop(smaller.sign); + final leadOneNPathFlopped = localFlop(leadOneNPath); + final validLeadOneNPathFlopped = localFlop(validLeadOneNPath); + final largerSignFlopped = localFlop(larger.sign); + final smallerSignFlopped = localFlop(smaller.sign); final expCalcNPath = OnesComplementAdder( - largerExpLatched, leadOneNPathLatched.zeroExtend(exponentWidth), - subtractIn: effectiveSubtractionLatched, + largerExpFlopped, leadOneNPathFlopped.zeroExtend(exponentWidth), + subtractIn: effectiveSubtractionFlopped, adderGen: adderGen, name: 'npath_expcalc'); - final preExpNPath = expCalcNPath.sum.slice(exponentWidth - 1, 0); + final preExpNPath = + expCalcNPath.sum.slice(exponentWidth - 1, 0).named('preExpNpath'); final posExpNPath = - preExpNPath.or() & ~expCalcNPath.sign & validLeadOneNPathLatched; + (preExpNPath.or() & ~expCalcNPath.sign & validLeadOneNPathFlopped) + .named('posExpNpath'); - final exponentNPath = mux(posExpNPath, preExpNPath, zeroExp); + final exponentNPath = + mux(posExpNPath, preExpNPath, zeroExp).named('exponentNpath'); - final preMinShiftNPath = ~leadOneNPathLatched.or() | ~largerExpLatched.or(); + final preMinShiftNPath = + (~leadOneNPathFlopped.or() | ~largerExpFlopped.or()) + .named('preMinShiftNpath'); final minShiftNPath = - mux(posExpNPath | preMinShiftNPath, leadOneNPathLatched, expDecr.out); - final notSubnormalNPath = aIsNormalLatched | bIsNormalLatched; + mux(posExpNPath | preMinShiftNPath, leadOneNPathFlopped, expDecr.out) + .named('minShiftNpath'); + final notSubnormalNPath = aIsNormalFlopped | bIsNormalFlopped; - final shiftedSignificandNPath = - (significandNPathLatched << minShiftNPath).slice(mantissaWidth, 1); + final shiftedSignificandNPath = (significandNPathFlopped << minShiftNPath) + .named('shiftedSignificandNpath') + .slice(mantissaWidth, 1); final finalSignificandNPath = mux( - notSubnormalNPath, - shiftedSignificandNPath, - significandNPathLatched.slice(significandNPathLatched.width - 1, 2)); + notSubnormalNPath, + shiftedSignificandNPath, + significandNPathFlopped.slice(significandNPathFlopped.width - 1, 2)) + .named('finalSignificandNpath'); - final signNPath = mux(significandSubtractorNPathSignLatched, - smallerSignLatched, largerSignLatched); + final signNPath = mux(significandSubtractorNPathSignFlopped, + smallerSignFlopped, largerSignFlopped) + .named('signNpath'); - final isR = deltaLatched.gte(Const(2, width: delta.width)) | - ~effectiveSubtractionLatched; + final isR = (deltaFlopped.gte(Const(2, width: delta.width)) | + ~effectiveSubtractionFlopped) + .named('isR'); Combinational([ - If(isNaNLatched, then: [ + If(isNaNFlopped, then: [ outputSum < outputSum.nan, ], orElse: [ - If(isInfLatched, then: [ - outputSum < outputSum.inf(sign: largerSignLatched), + If(isInfFlopped, then: [ + outputSum < outputSum.inf(sign: largerSignFlopped), ], orElse: [ If(isR, then: [ - outputSum.sign < largerSignLatched, + outputSum.sign < largerSignFlopped, outputSum.exponent < exponentRPath, outputSum.mantissa < mantissaRPath.slice(mantissaRPath.width - 2, 1), diff --git a/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart b/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart index b91ae5eb2..d679380b1 100644 --- a/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart +++ b/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart @@ -30,65 +30,80 @@ class FloatingPointAdderSimple extends FloatingPointAdder { super.name = 'floatingpoint_adder_simple'}) : super() { final outputSum = FloatingPoint( - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + name: 'sum'); output('sum') <= outputSum; - // Ensure that the larger number is wired as 'a' - final ae = this.a.exponent; - final be = this.b.exponent; - final am = this.a.mantissa; - final bm = this.b.mantissa; - final doSwap = ae.lt(be) | - (ae.eq(be) & am.lt(bm)) | - ((ae.eq(be) & am.eq(bm)) & super.a.sign); - final FloatingPoint a; - final FloatingPoint b; - (a, b) = swap(doSwap, (super.a, super.b)); + final (larger, smaller) = sortFp((super.a, super.b)); - final isInf = a.isInfinity | b.isInfinity; - final isNaN = - a.isNaN | b.isNaN | (a.isInfinity & b.isInfinity & (a.sign ^ b.sign)); + final isInf = (larger.isInfinity | smaller.isInfinity).named('isInf'); + final isNaN = (larger.isNaN | + smaller.isNaN | + (larger.isInfinity & + smaller.isInfinity & + (larger.sign ^ smaller.sign))) + .named('isNaN'); // Align and add mantissas - final expDiff = a.exponent - b.exponent; + final expDiff = (larger.exponent - smaller.exponent).named('expDiff'); final aMantissa = mux( - a.isNormal, - [Const(1), a.mantissa, Const(0, width: mantissaWidth + 1)].swizzle(), - [a.mantissa, Const(0, width: mantissaWidth + 2)].swizzle()); + larger.isNormal, + [Const(1), larger.mantissa, Const(0, width: mantissaWidth + 1)] + .swizzle(), + [larger.mantissa, Const(0, width: mantissaWidth + 2)].swizzle()); final bMantissa = mux( - b.isNormal, - [Const(1), b.mantissa, Const(0, width: mantissaWidth + 1)].swizzle(), - [b.mantissa, Const(0, width: mantissaWidth + 2)].swizzle()); + smaller.isNormal, + [Const(1), smaller.mantissa, Const(0, width: mantissaWidth + 1)] + .swizzle(), + [smaller.mantissa, Const(0, width: mantissaWidth + 2)].swizzle()); final adder = SignMagnitudeAdder( - a.sign, aMantissa, b.sign, bMantissa >>> expDiff, adderGen); + larger.sign, aMantissa, smaller.sign, bMantissa >>> expDiff, adderGen); - final intSum = adder.sum.slice(adder.sum.width - 1, 0); + final intSum = adder.sum.slice(adder.sum.width - 1, 0).named('intSum'); - final aSignLatched = localFlop(a.sign); - final aExpLatched = localFlop(a.exponent); + final aSignLatched = localFlop(larger.sign); + final aExpLatched = localFlop(larger.exponent); final sumLatched = localFlop(intSum); final isInfLatched = localFlop(isInf); final isNaNLatched = localFlop(isNaN); - final mantissa = - sumLatched.reversed.getRange(0, min(intSum.width, intSum.width)); - final leadOneValid = Logic(); + final mantissa = sumLatched.reversed + .getRange(0, min(intSum.width, intSum.width)) + .named('mantissa'); + final leadOneValid = Logic(name: 'leadOneValid'); final leadOnePre = ParallelPrefixPriorityEncoder(mantissa, ppGen: ppTree, valid: leadOneValid) - .out; + .out + .named('leadOnePre'); // Limit leadOne to exponent range and match widths final infExponent = outputSum.inf(sign: aSignLatched).exponent; - final leadOne = (leadOnePre.width > exponentWidth) - ? mux(leadOnePre.gte(infExponent.zeroExtend(leadOnePre.width)), - infExponent, leadOnePre.getRange(0, exponentWidth)) - : leadOnePre.zeroExtend(exponentWidth); + final leadOne = ((leadOnePre.width > exponentWidth) + ? mux(leadOnePre.gte(infExponent.zeroExtend(leadOnePre.width)), + infExponent, leadOnePre.getRange(0, exponentWidth)) + : leadOnePre.zeroExtend(exponentWidth)) + .named('leadOne'); - final leadOneDominates = leadOne.gt(aExpLatched) | ~leadOneValid; - final outExp = - mux(leadOneDominates, a.zeroExponent, aExpLatched - leadOne + 1); + final leadOneDominates = + (leadOne.gt(aExpLatched) | ~leadOneValid).named('leadOneDominates'); + final normalExp = (aExpLatched - leadOne + 1).named('normalExponent'); + final outExp = mux(leadOneDominates, larger.zeroExponent, normalExp) + .named('outExponent'); - final realIsInf = isInfLatched | outExp.eq(infExponent); + final realIsInf = + (isInfLatched | outExp.eq(infExponent)).named('realIsInf'); + + final shiftMantissabyExp = + (sumLatched << (aExpLatched + 1).named('expPlus1')) + .named('shiftMantissaByExp', naming: Naming.mergeable) + .getRange(intSum.width - mantissaWidth, intSum.width) + .named('shiftMantissaByExpSliced'); + final shiftMantissabyLeadOne = + (sumLatched << (leadOne + 1).named('leadOnePlus1')) + .named('sumShiftLeadOnePlus1') + .getRange(intSum.width - mantissaWidth, intSum.width) + .named('shiftMantissaLeadPlus1Sliced', naming: Naming.mergeable); Combinational([ If.block([ @@ -96,25 +111,17 @@ class FloatingPointAdderSimple extends FloatingPointAdder { outputSum < outputSum.nan, ]), ElseIf(realIsInf, [ - // ROHD 0.6.0 trace error if we use the following outputSum < outputSum.inf(sign: aSignLatched), - // outputSum.sign < aSignLatched, - // outputSum.exponent < infExponent, - // outputSum.mantissa < Const(0, width: mantissaWidth, fill: true), ]), ElseIf(leadOneDominates, [ outputSum.sign < aSignLatched, - outputSum.exponent < a.zeroExponent, - outputSum.mantissa < - (sumLatched << aExpLatched + 1) - .getRange(intSum.width - mantissaWidth, intSum.width), + outputSum.exponent < larger.zeroExponent, + outputSum.mantissa < shiftMantissabyExp, ]), Else([ outputSum.sign < aSignLatched, - outputSum.exponent < aExpLatched - leadOne + 1, - outputSum.mantissa < - (sumLatched << leadOne + 1) - .getRange(intSum.width - mantissaWidth, intSum.width), + outputSum.exponent < normalExp, + outputSum.mantissa < shiftMantissabyLeadOne, ]) ]) ]); diff --git a/lib/src/arithmetic/floating_point/floating_point_multiplier.dart b/lib/src/arithmetic/floating_point/floating_point_multiplier.dart index 2a3f07408..4b7f0da20 100644 --- a/lib/src/arithmetic/floating_point/floating_point_multiplier.dart +++ b/lib/src/arithmetic/floating_point/floating_point_multiplier.dart @@ -14,10 +14,10 @@ import 'package:rohd_hcl/rohd_hcl.dart'; /// An abstract API for floating-point multipliers. abstract class FloatingPointMultiplier extends Module { /// Width of the output exponent field. - final int exponentWidth; + late final int exponentWidth; /// Width of the output mantissa field. - final int mantissaWidth; + late final int mantissaWidth; /// The [clk] : if a non-null clock signal is passed in, a pipestage is added /// to the adder to help optimize frequency. @@ -47,9 +47,15 @@ abstract class FloatingPointMultiplier extends Module { FloatingPoint(exponentWidth: exponentWidth, mantissaWidth: mantissaWidth) ..gets(output('product')); + /// The internal FloatingPoint logic to set + late final FloatingPoint internalProduct; + /// Multiply two floating point numbers [a] and [b], returning result in /// [product]. /// + /// If you specify the optional [outProduct], the multiplier + /// will output into the specified output allowing for a wider output. + /// /// - [clk], [reset], [enable] are optional inputs to control a pipestage /// (only inserted if [clk] is provided). /// - [ppGen] is the type of [ParallelPrefix] used in internal adder @@ -58,23 +64,37 @@ abstract class FloatingPointMultiplier extends Module { {Logic? clk, Logic? reset, Logic? enable, + FloatingPoint? outProduct, // ignore: avoid_unused_constructor_parameters ParallelPrefix Function(List, Logic Function(Logic, Logic)) ppGen = KoggeStone.new, - super.name = 'floating_point_multiplier'}) - : exponentWidth = a.exponent.width, - mantissaWidth = a.mantissa.width { - if (b.exponent.width != exponentWidth || - b.mantissa.width != mantissaWidth) { + super.name = 'floating_point_multiplier'}) { + if (b.exponent.width != a.exponent.width || + b.mantissa.width != a.mantissa.width) { throw RohdHclException('FloatingPoint widths must match'); } + exponentWidth = + (outProduct == null) ? a.exponent.width : outProduct.exponent.width; + mantissaWidth = + (outProduct == null) ? a.mantissa.width : outProduct.mantissa.width; + + internalProduct = FloatingPoint( + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + name: 'product'); + addOutput('product', width: exponentWidth + mantissaWidth + 1); + output('product') <= internalProduct; + + if (outProduct != null) { + outProduct <= output('product'); + } + this.clk = (clk != null) ? addInput('clk', clk) : clk; this.enable = (enable != null) ? addInput('enable', enable) : enable; this.reset = (reset != null) ? addInput('clk', reset) : reset; - this.a = a.clone()..gets(addInput('a', a, width: a.width)); - this.b = b.clone()..gets(addInput('b', b, width: b.width)); - addOutput('product', width: a.exponent.width + a.mantissa.width + 1); + this.a = a.clone(name: 'a')..gets(addInput('a', a, width: a.width)); + this.b = b.clone(name: 'b')..gets(addInput('b', b, width: b.width)); } /// Pipelining helper that uses the context for signals clk/enable/reset diff --git a/lib/src/arithmetic/floating_point/floating_point_multiplier_simple.dart b/lib/src/arithmetic/floating_point/floating_point_multiplier_simple.dart index ba5a7c722..edeb6ab25 100644 --- a/lib/src/arithmetic/floating_point/floating_point_multiplier_simple.dart +++ b/lib/src/arithmetic/floating_point/floating_point_multiplier_simple.dart @@ -9,102 +9,140 @@ import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; -import 'package:rohd_hcl/src/arithmetic/partial_product_sign_extend.dart'; /// A multiplier module for FloatingPoint logic. class FloatingPointMultiplierSimple extends FloatingPointMultiplier { /// Multiply two FloatingPoint numbers [a] and [b], returning result /// in [product] FloatingPoint. - /// - [radix] is the Booth encoder radix used with options [2,4,8,16] - /// ((default=4). - /// - [adderGen] is an adder generator to be used in the primary adder - /// functions. - /// - [ppTree] is an parallel prefix tree generator to be used in internal - /// functions. + /// - [multGen] is a multiplier generator to be used in the mantissa + /// multiplication. + /// - [ppTree] is an parallel prefix tree generator to be used in the + /// leading one detection ([ParallelPrefixPriorityEncoder]). + /// + /// The multiplier currently does not support a [product] with narrower + /// exponent or mantissa fields and will throw an exception. FloatingPointMultiplierSimple(super.a, super.b, {super.clk, super.reset, super.enable, - int radix = 4, - Adder Function(Logic a, Logic b, {Logic? carryIn}) adderGen = - NativeAdder.new, + super.outProduct, + Multiplier Function(Logic a, Logic b, + {Logic? clk, Logic? reset, Logic? enable, String name}) + multGen = NativeMultiplier.new, ParallelPrefix Function( List inps, Logic Function(Logic term1, Logic term2) op) ppTree = KoggeStone.new, super.name}) { - final product = FloatingPoint( - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - output('product') <= product; - final a = this.a; - final b = this.b; - + if (exponentWidth < a.exponent.width) { + throw RohdHclException('product exponent width must be >= ' + ' input exponent width'); + } + if (mantissaWidth < a.mantissa.width) { + throw RohdHclException('product mantissa width must be >= ' + ' input mantissa width'); + } final aMantissa = mux(a.isNormal, [a.isNormal, a.mantissa].swizzle(), - [a.mantissa, Const(0)].swizzle()); + [a.mantissa, Const(0)].swizzle()) + .named('aMantissa'); final bMantissa = mux(b.isNormal, [b.isNormal, b.mantissa].swizzle(), - [b.mantissa, Const(0)].swizzle()); - - final productExp = a.exponent.zeroExtend(exponentWidth + 2) + - b.exponent.zeroExtend(exponentWidth + 2) - - a.bias.zeroExtend(exponentWidth + 2); - - final pp = PartialProductGeneratorCompactRectSignExtension( - aMantissa, bMantissa, RadixEncoder(radix)); - final compressor = - ColumnCompressor(pp, clk: clk, reset: reset, enable: enable) - ..compress(); - final adder = adderGen(compressor.extractRow(0), compressor.extractRow(1)); - // Input mantissas have implicit lead: product mantissa width is (mw+1)*2) - final mantissa = adder.sum.getRange(0, (mantissaWidth + 1) * 2); - - final isInf = a.isInfinity | b.isInfinity; - final isNaN = a.isNaN | - b.isNaN | - ((a.isInfinity | b.isInfinity) & (a.isZero | b.isZero)); + [b.mantissa, Const(0)].swizzle()) + .named('bMantissa'); + + // TODO(desmonddak): do this calculation using the maximum exponent width + // Then adapt to the product exponent width. + final expCalcWidth = exponentWidth + 2; + final addBias = + (a.bias.zeroExtend(expCalcWidth) + b.bias.zeroExtend(expCalcWidth)) + .named('addBias'); + final deltaBias = + (product.bias.zeroExtend(expCalcWidth) - addBias).named('rebias'); + final addExp = (a.exponent.zeroExtend(expCalcWidth) + + b.exponent.zeroExtend(expCalcWidth)) + .named('addExp'); + final productExp = (addExp + deltaBias).named('productExp'); + + final mantissaMult = multGen(aMantissa, bMantissa, + clk: clk, reset: reset, enable: enable, name: 'mantissa_mult'); + + final mantissa = mantissaMult.product + .getRange(0, (a.mantissa.width + 1) * 2) + .named('mantissa'); + + // TODO(desmonddak): This is where we need to either truncate or round to + // the product mantissa width. Today it simply is expanded only, but + // upon narrowing, it will need to truncate for simple multiplication. + + final isInf = (a.isInfinity | b.isInfinity).named('isInf'); + final isNaN = (a.isNaN | + b.isNaN | + ((a.isInfinity | b.isInfinity) & (a.isZero | b.isZero))) + .named('isNaN'); final productExpLatch = localFlop(productExp); - final aSignLatch = localFlop(a.sign); - final bSignLatch = localFlop(b.sign); + final aSignLatch = + localFlop(a.sign).named('a_sign', naming: Naming.renameable); + final bSignLatch = + localFlop(b.sign).named('b_sign', naming: Naming.renameable); final isInfLatch = localFlop(isInf); final isNaNLatch = localFlop(isNaN); - final leadingOnePos = ParallelPrefixPriorityEncoder(mantissa.reversed, + final leadingOnePosPre = ParallelPrefixPriorityEncoder(mantissa.reversed, ppGen: ppTree, name: 'leading_one_encoder') .out - .zeroExtend(exponentWidth + 2); + .named('leadingOneRaw') + .zeroExtend(exponentWidth + 2) + .named('leadingOneRawExtended', naming: Naming.mergeable); + + final leadingOnePos = mux( + leadingOnePosPre.gt(mantissa.width), + Const(product.bias.value.toInt() + 1, + width: leadingOnePosPre.width), + leadingOnePosPre) + .named('leadingOnePosition'); + + final remainingExp = + ((productExpLatch - leadingOnePos).named('productExpMinusLeadOne') + 1) + .named('remainingExp'); + + final internalOverflow = (~remainingExp[-1] & + remainingExp.gte(Const(1, width: exponentWidth, fill: true) + .zeroExtend(exponentWidth + 2))) + .named('internalOverflow'); - final shifter = SignedShifter( - mantissa, - mux(productExpLatch[-1] | productExpLatch.lt(leadingOnePos), - productExpLatch, leadingOnePos), - name: 'mantissa_shifter'); + final overFlow = (isInfLatch | internalOverflow).named('overflow'); - final remainingExp = productExpLatch - leadingOnePos + 1; + final fullMantissa = (mantissaWidth + 1 > mantissa.width) + ? [ + mantissa, + Const(0, width: mantissaWidth + 1 - mantissa.width, fill: true) + ].swizzle().named('extendMantissa') + : mantissa.named('fullMantissa'); - final overFlow = isInfLatch | - (~remainingExp[-1] & - remainingExp.abs().gte(Const(1, width: exponentWidth, fill: true) - .zeroExtend(exponentWidth + 2))); + final fullShift = SignedShifter( + fullMantissa, + mux(productExpLatch[-1] | productExpLatch.lt(leadingOnePos), + productExpLatch, leadingOnePos), + name: 'full_mantissa_shifter') + .shifted + .named('shiftMantissa'); + final finalMantissa = fullShift + .getRange(fullShift.width - mantissaWidth - 1, fullShift.width - 1) + .named('finalMantissa'); Combinational([ If(isNaNLatch, then: [ - product < product.nan, + internalProduct < product.nan, ], orElse: [ If(overFlow, then: [ - // TODO(desmonddak): use this line after trace issue is resolved - // product < product.inf(inSign: aSignLatch ^ bSignLatch), - product.sign < aSignLatch ^ bSignLatch, - product.exponent < product.nan.exponent, - product.mantissa < Const(0, width: mantissaWidth, fill: true), + internalProduct < product.inf(sign: aSignLatch ^ bSignLatch), ], orElse: [ - product.sign < aSignLatch ^ bSignLatch, + internalProduct.sign < aSignLatch ^ bSignLatch, If(remainingExp[-1], then: [ - product.exponent < Const(0, width: exponentWidth) + internalProduct.exponent < Const(0, width: exponentWidth) ], orElse: [ - product.exponent < remainingExp.getRange(0, exponentWidth), + internalProduct.exponent < remainingExp.getRange(0, exponentWidth), ]), - // Remove the leading one for implicit representation - product.mantissa < - shifter.shifted.getRange(-mantissaWidth - 1, mantissa.width - 1) + internalProduct.mantissa < finalMantissa ]) ]) ]); diff --git a/lib/src/arithmetic/multiplicand_selector.dart b/lib/src/arithmetic/multiplicand_selector.dart index 75b5c15dd..59bd72e9c 100644 --- a/lib/src/arithmetic/multiplicand_selector.dart +++ b/lib/src/arithmetic/multiplicand_selector.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // multiplicand_selector.dart @@ -27,6 +27,9 @@ class MultiplicandSelector { /// Place to store [multiples] of the [multiplicand] (e.g. *1, *2, *-1, *-2..) late LogicArray multiples; + /// Multiples sliced into columns for select to access + late final multiplesSlice = []; + /// Build a [MultiplicandSelector] generationg required [multiples] of /// [multiplicand] to [select] using a [RadixEncoder] argument. /// @@ -47,7 +50,7 @@ class MultiplicandSelector { } final width = multiplicand.width + shift; final numMultiples = radix ~/ 2; - multiples = LogicArray([numMultiples], width); + multiples = LogicArray([numMultiples], width, name: 'multiples'); final Logic extendedMultiplicand; if (selectSignedMultiplicand == null) { extendedMultiplicand = signedMultiplicand @@ -77,18 +80,36 @@ class MultiplicandSelector { _ => throw RohdHclException('Radix is beyond 16') }; } + for (var c = 0; c < width; c++) { + multiplesSlice.add(getMultiples(c)); + } + } + + /// Compute the multiples of the multiplicand at current bit position + Logic getMultiples(int col) { + final columnMultiples = [ + for (var i = 0; i < multiples.elements.length; i++) + multiples.elements[i][col] + ].swizzle().named('multiples_c$col', naming: Naming.mergeable); + return columnMultiples.reversed; } /// Retrieve the multiples of the multiplicand at current bit position - Logic getMultiples(int col) => [ - for (var i = 0; i < multiples.elements.length; i++) - multiples.elements[i][col] - ].swizzle().reversed; + Logic fetchMultiples(int col) => multiplesSlice[col]; - Logic _select(Logic multiples, RadixEncode encode) => - (encode.multiples & multiples).or() ^ encode.sign; + // _select attempts to name signals that RadixEncode cannot due to trace + Logic _select(Logic multiples, RadixEncode encode) { + final eMultiples = encode.multiples + .named('encoded_multiple_r${encode.row}', naming: Naming.mergeable); + final eSign = encode.sign + .named('encode_sign_r${encode.row}', naming: Naming.mergeable); + return (eMultiples & multiples).or() ^ eSign; + } /// Select the partial product term from the multiples using a RadixEncode - Logic select(int col, RadixEncode encode) => - _select(getMultiples(col), encode); + Logic select(int col, RadixEncode encode) { + final mults = fetchMultiples(col) + .named('select_r${encode.row}_c$col', naming: Naming.mergeable); + return _select(mults, encode); + } } diff --git a/lib/src/arithmetic/multiplier.dart b/lib/src/arithmetic/multiplier.dart index 5714d70f8..118deb615 100644 --- a/lib/src/arithmetic/multiplier.dart +++ b/lib/src/arithmetic/multiplier.dart @@ -16,6 +16,18 @@ import 'package:rohd_hcl/src/arithmetic/partial_product_sign_extend.dart'; /// An abstract class for all multiplier implementations. abstract class Multiplier extends Module { + /// The clk for pipelining the multiplication. + @protected + Logic? clk; + + /// Optional reset for configurable pipestaging. + @protected + Logic? reset; + + /// Optional enable for configurable pipestaging. + @protected + Logic? enable; + /// The multiplicand input [a]. @protected Logic get a => input('a'); @@ -65,12 +77,28 @@ abstract class Multiplier extends Module { /// Optional [selectSignedMultiplier] allows for runtime configuration of /// signed or unsigned operation, overriding the [signedMultiplier] static /// configuration. + /// If [clk] is not null then a set of flops are used to make the multiply + /// a 2-cycle latency operation. [reset] and [enable] are optional + /// inputs to control these flops when [clk] is provided. Multiplier(Logic a, Logic b, - {this.signedMultiplicand = false, + {Logic? clk, + Logic? reset, + Logic? enable, + this.signedMultiplicand = false, this.signedMultiplier = false, Logic? selectSignedMultiplicand, Logic? selectSignedMultiplier, - super.name}) { + super.name = 'multiplier'}) { + if (signedMultiplicand && (selectSignedMultiplicand != null)) { + throw RohdHclException('multiplicand sign reconfiguration requires ' + 'signedMultiplicand=false'); + } + if (signedMultiplier && (selectSignedMultiplier != null)) { + throw RohdHclException('sign reconfiguration requires signed=false'); + } + this.clk = (clk != null) ? addInput('clk', clk) : null; + this.reset = (reset != null) ? addInput('reset', reset) : null; + this.enable = (enable != null) ? addInput('enable', enable) : null; a = addInput('a', a, width: a.width); b = addInput('b', b, width: b.width); @@ -92,8 +120,83 @@ abstract class Multiplier extends Module { } } +/// A class which wraps the native '*' operator so that it can be passed +/// into other modules as a parameter for using the native operation. +class NativeMultiplier extends Multiplier { + /// The multiplication results of the multiplier. + @override + Logic get product => output('product'); + + /// The width of input [a] and [b] must be the same. + NativeMultiplier(super.a, super.b, + {super.clk, + super.reset, + super.enable, + super.signedMultiplicand = false, + super.signedMultiplier = false, + super.selectSignedMultiplicand, + super.selectSignedMultiplier, + super.name = 'native_multiplier'}) { + if (a.width != b.width) { + throw RohdHclException('inputs of a and b should have same width.'); + } + final pW = a.width + b.width; + final product = addOutput('product', width: pW); + + final Logic extendedMultiplicand; + final Logic extendedMultiplier; + if (selectSignedMultiplicand == null) { + extendedMultiplicand = + signedMultiplicand ? a.signExtend(pW) : a.zeroExtend(pW); + } else { + final len = a.width; + final sign = a[len - 1]; + final extension = [ + for (var i = len; i < pW; i++) + mux(selectSignedMultiplicand!, sign, Const(0)) + ]; + extendedMultiplicand = (a.elements + extension).rswizzle(); + } + if (selectSignedMultiplier == null) { + extendedMultiplier = + (signedMultiplier ? b.signExtend(pW) : b.zeroExtend(pW)) + .named('extended_multiplier', naming: Naming.mergeable); + } else { + final len = b.width; + final sign = b[len - 1]; + final extension = [ + for (var i = len; i < pW; i++) + mux(selectSignedMultiplier!, sign, Const(0)) + ]; + extendedMultiplier = (b.elements + extension) + .rswizzle() + .named('extended_multiplier', naming: Naming.mergeable); + } + + final internalProduct = + (extendedMultiplicand * extendedMultiplier).named('internalProduct'); + product <= condFlop(clk, reset: reset, en: enable, internalProduct); + } +} + +// TODO(desmonddak): add a multiply generator option to MAC +// TODO(desmonddak): add a variable width output as we did with fp multiply +// as well as a variable width accumulate input + /// An abstract class for all multiply accumulate implementations. abstract class MultiplyAccumulate extends Module { + /// The clk for pipelining the multiplication. + @protected + Logic? clk; + + /// Optional reset for configurable pipestaging. + @protected + Logic? reset; + + /// Optional enable for configurable pipestaging. + @protected + Logic? enable; + /// The input to the multiplier pin [a]. @protected Logic get a => input('a'); @@ -155,13 +258,19 @@ abstract class MultiplyAccumulate extends Module { /// signed or unsigned operation, overriding the [signedAddend] static /// configuration. MultiplyAccumulate(Logic a, Logic b, Logic c, - {this.signedMultiplicand = false, + {Logic? clk, + Logic? reset, + Logic? enable, + this.signedMultiplicand = false, this.signedMultiplier = false, this.signedAddend = false, Logic? selectSignedMultiplicand, Logic? selectSignedMultiplier, Logic? selectSignedAddend, super.name}) { + this.clk = (clk != null) ? addInput('clk', clk) : null; + this.reset = (reset != null) ? addInput('reset', reset) : null; + this.enable = (enable != null) ? addInput('enable', enable) : null; a = addInput('a', a, width: a.width); b = addInput('b', b, width: b.width); c = addInput('c', c, width: c.width); @@ -190,15 +299,6 @@ abstract class MultiplyAccumulate extends Module { /// An implementation of an integer multiplier using compression trees. class CompressionTreeMultiplier extends Multiplier { - /// The clk for the pipelined version of column compression. - Logic? clk; - - /// Optional reset for configurable pipestage - Logic? reset; - - /// Optional enable for configurable pipestage. - Logic? enable; - /// The final product of the multiplier module. @override Logic get product => output('product'); @@ -231,9 +331,9 @@ class CompressionTreeMultiplier extends Multiplier { /// inputs to control these flops when [clk] is provided. If [clk] is null, /// the [ColumnCompressor] is built as a combinational tree of compressors. CompressionTreeMultiplier(super.a, super.b, int radix, - {this.clk, - this.reset, - this.enable, + {super.clk, + super.reset, + super.enable, super.signedMultiplicand = false, super.signedMultiplier = false, super.selectSignedMultiplicand, @@ -244,12 +344,9 @@ class CompressionTreeMultiplier extends Multiplier { {String name}) seGen = CompactRectSignExtension.new, super.name = 'compression_tree_multiplier'}) { - clk = (clk != null) ? addInput('clk', clk!) : null; - reset = (reset != null) ? addInput('reset', reset!) : null; - enable = (enable != null) ? addInput('enable', enable!) : null; - +// Should be done in base TODO(desmonddak): final product = addOutput('product', width: a.width + b.width); - final pp = PartialProductGeneratorBasic( + final pp = PartialProductGenerator( a, b, RadixEncoder(radix), @@ -271,18 +368,6 @@ class CompressionTreeMultiplier extends Multiplier { /// An implementation of an integer multiply-accumulate using compression trees class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { - /// The clk for the pipelined version of column compression. - @protected - Logic? get clk => tryInput('clk'); - - /// Optional reset for configurable pipestage - @protected - Logic? get reset => tryInput('reset'); - - /// Optional enable for configurable pipestage. - @protected - Logic? get enable => tryInput('enable'); - /// The final product of the multiplier module. @override Logic get accumulate => output('accumulate'); @@ -317,7 +402,7 @@ class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { /// signed or unsigned operation, overriding the [signedAddend] static /// configuration. /// - /// If[clk] is not null then a set of flops are used to latch the output + /// If [clk] is not null then a set of flops are used to latch the output /// after compression. [reset] and [enable] are optional /// inputs to control these flops when [clk] is provided. If [clk] is null, /// the [ColumnCompressor] is built as a combinational tree of compressors. @@ -337,12 +422,8 @@ class CompressionTreeMultiplyAccumulate extends MultiplyAccumulate { {String name}) seGen = CompactRectSignExtension.new, super.name = 'compression_tree_mac'}) { - clk = (clk != null) ? addInput('clk', clk) : null; - reset = (reset != null) ? addInput('reset', reset) : null; - enable = (enable != null) ? addInput('enable', enable) : null; - final accumulate = addOutput('accumulate', width: a.width + b.width + 1); - final pp = PartialProductGeneratorBasic( + final pp = PartialProductGenerator( a, b, RadixEncoder(radix), diff --git a/lib/src/arithmetic/multiplier_encoder.dart b/lib/src/arithmetic/multiplier_encoder.dart index e96317b4e..5d2d994ea 100644 --- a/lib/src/arithmetic/multiplier_encoder.dart +++ b/lib/src/arithmetic/multiplier_encoder.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // muliplier_encoder.dart @@ -18,20 +18,29 @@ class RadixEncode extends LogicStructure { /// Which multiples need to be selected final Logic multiples; - /// 'sign' of multiple + /// 'sign' of multiple. final Logic sign; + /// The [row] that is encoded by this RadixEncode (encoding an + /// overlapping segment of the multiplier). + late final int row; + /// Structure for holding Radix Encoding - RadixEncode({required int numMultiples}) + RadixEncode(int row, {required int numMultiples}) : this._( - Logic(width: numMultiples, name: 'multiples'), Logic(name: 'sign')); - - RadixEncode._(this.multiples, this.sign, {String? name}) + Logic( + width: numMultiples, + name: 'multiples', + naming: Naming.mergeable), + Logic(name: 'sign'), + row); + + RadixEncode._(this.multiples, this.sign, this.row, {String? name}) : super([multiples, sign], name: name ?? 'RadixLogic'); @override RadixEncode clone({String? name}) => - RadixEncode(numMultiples: multiples.width); + RadixEncode(row, numMultiples: multiples.width); } /// Base interface for radix radixEncoder @@ -47,7 +56,7 @@ class RadixEncoder { } /// Encode a multiplier slice into the Booth encoded value - RadixEncode encode(Logic multiplierSlice) { + RadixEncode encode(Logic multiplierSlice, int row) { if (multiplierSlice.width != log2Ceil(radix) + 1) { throw RohdHclException('multiplier slice width ${multiplierSlice.width}' 'must be same length as log(radix)+1=${log2Ceil(radix) + 1}'); @@ -56,7 +65,8 @@ class RadixEncoder { final inputXor = Logic(width: width); inputXor <= (multiplierSlice ^ (multiplierSlice >>> 1)) - .slice(multiplierSlice.width - 1, 0); + .slice(multiplierSlice.width - 1, 0) + .named('${multiplierSlice.name}_xor', naming: Naming.mergeable); final multiples = []; for (var i = 2; i < radix + 1; i += 2) { @@ -73,11 +83,16 @@ class RadixEncoder { for (var j = 0; j < width - 1; j++) if (multiplesDisagree[j].isZero) if (senseMultiples[j].isZero) ~inputXor[j] else inputXor[j] - ].swizzle().and()); + ].swizzle().and().named('multiple${i}_of_${multiplierSlice.name}', + naming: Naming.mergeable)); } - return RadixEncode._(multiples.rswizzle(), - multiples.rswizzle().or() & multiplierSlice[multiplierSlice.width - 1]); + final multiplesR = multiples + .rswizzle() + .named('multiples_reversed_r$row', naming: Naming.mergeable); + + return RadixEncode._(multiplesR, + multiplesR.or() & multiplierSlice[multiplierSlice.width - 1], row); } } @@ -96,6 +111,8 @@ class MultiplierEncoder { /// Store the [RadixEncoder] used. late final RadixEncoder _encoder; + late final _encodings = []; + /// Generate the Booth encoding of an input [multiplier] using /// [radixEncoder]. /// @@ -122,20 +139,26 @@ class MultiplierEncoder { // slices overlap by 1 and start at -1a if (selectSignedMultiplier == null) { _extendedMultiplier = (signedMultiplier - ? multiplier.signExtend(rows * (log2Ceil(radixEncoder.radix))) - : multiplier.zeroExtend(rows * (log2Ceil(radixEncoder.radix)))); + ? multiplier.signExtend(rows * (log2Ceil(radixEncoder.radix))) + : multiplier.zeroExtend(rows * (log2Ceil(radixEncoder.radix)))) + .named('extended_multiplier', naming: Naming.mergeable); } else { final len = multiplier.width; final sign = multiplier[len - 1]; final extension = [ - for (var i = len - 1; i < (rows * (log2Ceil(radixEncoder.radix))); i++) + for (var i = len; i < (rows * (log2Ceil(radixEncoder.radix))); i++) mux(selectSignedMultiplier, sign, Const(0)) ]; - _extendedMultiplier = (multiplier.elements + extension).rswizzle(); + _extendedMultiplier = (multiplier.elements + extension) + .rswizzle() + .named('extended_multiplier', naming: Naming.mergeable); + } + for (var i = 0; i < rows; i++) { + _encodings.add(getEncoding(i)); } } - /// Retrieve the Booth encoding for the row + /// Compute the Booth encoding for the row RadixEncode getEncoding(int row) { if (row >= rows) { throw RohdHclException('row $row is not < number of encoding rows $rows'); @@ -149,7 +172,15 @@ class MultiplierEncoder { _extendedMultiplier.slice(base + log2Ceil(_encoder.radix) - 1, base), Const(0) ].swizzle() - ]; - return _encoder.encode(multiplierSlice.first); + ].first.named('mult_slice_r$row', naming: Naming.mergeable); + return _encoder.encode(multiplierSlice, row); + } + + /// Retrieve the Booth encoding for the row + RadixEncode fetchEncoding(int row) { + if (row >= rows) { + throw RohdHclException('row $row is not < number of encoding rows $rows'); + } + return _encodings[row]; } } diff --git a/lib/src/arithmetic/ones_complement_adder.dart b/lib/src/arithmetic/ones_complement_adder.dart index 09b918798..657aca7ad 100644 --- a/lib/src/arithmetic/ones_complement_adder.dart +++ b/lib/src/arithmetic/ones_complement_adder.dart @@ -53,31 +53,35 @@ class OnesComplementAdder extends Adder { " configuration, or a boolean parameter 'subtract' for " 'generation time configuration, but not both.'); } - final doSubtract = - subtractIn ?? (subtract != null ? Const(subtract) : Const(0)); - final ax = a.zeroExtend(a.width); - final bx = b.zeroExtend(b.width); + final doSubtract = + (subtractIn ?? (subtract != null ? Const(subtract) : Const(0))) + .named('dosubtract', naming: Naming.mergeable); - final adder = - adderGen(ax, mux(doSubtract, ~bx, bx), carryIn: carryIn ?? Const(0)); + final adderSum = + adderGen(a, mux(doSubtract, ~b, b), carryIn: carryIn ?? Const(0)) + .sum + .named('adderSum', naming: Naming.mergeable); if (this.carryOut != null) { - this.carryOut! <= adder.sum[-1]; + this.carryOut! <= adderSum[-1]; } - final endAround = adder.sum[-1]; - final magnitude = adder.sum.slice(a.width - 1, 0); + final endAround = adderSum[-1].named('endaround'); + final magnitude = adderSum.slice(a.width - 1, 0).named('magnitude'); + + final incrementer = ParallelPrefixIncr(magnitude); + final magnitudep1 = incrementer.out.named('magnitude_plus1'); sum <= mux( doSubtract, mux( endAround, - [if (this.carryOut != null) magnitude else magnitude + 1] + [if (this.carryOut != null) magnitude else magnitudep1] .first, ~magnitude) .zeroExtend(sum.width), - adder.sum); + adderSum); _sign <= mux(doSubtract, ~endAround, Const(0)); } } diff --git a/lib/src/arithmetic/parallel_prefix_operations.dart b/lib/src/arithmetic/parallel_prefix_operations.dart index 9c062eed7..f1cec58a7 100644 --- a/lib/src/arithmetic/parallel_prefix_operations.dart +++ b/lib/src/arithmetic/parallel_prefix_operations.dart @@ -103,13 +103,14 @@ class KoggeStone extends ParallelPrefix { while (skip < inps.length) { for (var i = inps.length - 1; i >= skip; --i) { - iseq[i] = op(iseq[i - skip], iseq[i]); + iseq[i] = op(iseq[i - skip], iseq[i]) + .named('ks_skip${skip}_i$i', naming: Naming.mergeable); } skip *= 2; } iseq.forEachIndexed((i, el) { - _oseq[i] <= el; + _oseq[i] <= el.named('o_$i', naming: Naming.mergeable); }); } } @@ -130,7 +131,8 @@ class BrentKung extends ParallelPrefix { var skip = 2; while (skip <= inps.length) { for (var i = skip - 1; i < inps.length; i += skip) { - iseq[i] = op(iseq[i - skip ~/ 2], iseq[i]); + iseq[i] = op(iseq[i - skip ~/ 2], iseq[i]) + .named('reduce_$i', naming: Naming.mergeable); } skip *= 2; } @@ -139,18 +141,20 @@ class BrentKung extends ParallelPrefix { skip = largestPow2LessThan(inps.length); while (skip > 2) { for (var i = 3 * (skip ~/ 2) - 1; i < inps.length; i += skip) { - iseq[i] = op(iseq[i - skip ~/ 2], iseq[i]); + iseq[i] = op(iseq[i - skip ~/ 2], iseq[i]) + .named('prefix_$i', naming: Naming.mergeable); } skip ~/= 2; } // Final row for (var i = 2; i < inps.length; i += 2) { - iseq[i] = op(iseq[i - 1], iseq[i]); + iseq[i] = + op(iseq[i - 1], iseq[i]).named('final_$i', naming: Naming.mergeable); } iseq.forEachIndexed((i, el) { - _oseq[i] <= el; + _oseq[i] <= el.named('o_$i', naming: Naming.mergeable); }); } } @@ -186,7 +190,8 @@ class ParallelPrefixPriorityFinder extends Module { super.name = 'parallel_prefix_finder'}) { inp = addInput('inp', inp, width: inp.width); final u = ParallelPrefixOrScan(inp, ppGen: ppGen); - addOutput('out', width: inp.width) <= (u.out & ~(u.out << Const(1))); + addOutput('out', width: inp.width) <= + (u.out & ~(u.out << Const(1))).named('pos', naming: Naming.mergeable); } } @@ -223,11 +228,16 @@ class ParallelPrefixPriorityEncoder extends Module { valid <= this.valid!; } final u = ParallelPrefixPriorityFinder(inp, ppGen: ppGen); - final pos = OneHotToBinary(u.out).binary.zeroExtend(sz); + final pos = OneHotToBinary(u.out) + .binary + .zeroExtend(sz) + .named('pos', naming: Naming.mergeable); if (this.valid != null) { this.valid! <= pos.or() | inp[0]; } - out <= mux(pos.or() | inp[0], pos, Const(inp.width + 1, width: sz)); + out <= + mux(pos.or() | inp[0], pos, Const(inp.width + 1, width: sz)) + .named('encoded_pos', naming: Naming.mergeable); } } @@ -247,17 +257,21 @@ class ParallelPrefixAdder extends Adder { l.insert( 0, [(a[0] & b[0]) | (a[0] & cin) | (b[0] & cin), a[0] | b[0] | cin] - .swizzle()); + .swizzle() + .named('pg_base', naming: Naming.mergeable)); final u = ppGen( - l, (lhs, rhs) => [rhs[1] | rhs[0] & lhs[1], rhs[0] & lhs[0]].swizzle()); + l, + (lhs, rhs) => [rhs[1] | rhs[0] & lhs[1], rhs[0] & lhs[0]] + .swizzle() + .named('pg', naming: Naming.mergeable)); sum <= [ u.val[a.width - 1][1], List.generate( a.width, - (i) => (i == 0) - ? a[i] ^ b[i] ^ cin - : a[i] ^ b[i] ^ u.val[i - 1][1]).rswizzle() + (i) => + ((i == 0) ? a[i] ^ b[i] ^ cin : a[i] ^ b[i] ^ u.val[i - 1][1]) + .named('t_$i')).rswizzle() ].swizzle(); } } @@ -277,7 +291,9 @@ class ParallelPrefixIncr extends Module { final u = ppGen(inp.elements, (lhs, rhs) => rhs & lhs); addOutput('out', width: inp.width) <= (List.generate( - inp.width, (i) => ((i == 0) ? ~inp[i] : inp[i] ^ u.val[i - 1])) + inp.width, + (i) => + ((i == 0) ? ~inp[i] : inp[i] ^ u.val[i - 1]).named('o_$i')) .rswizzle()); } } @@ -297,7 +313,9 @@ class ParallelPrefixDecr extends Module { final u = ppGen((~inp).elements, (lhs, rhs) => rhs & lhs); addOutput('out', width: inp.width) <= (List.generate( - inp.width, (i) => ((i == 0) ? ~inp[i] : inp[i] ^ u.val[i - 1])) + inp.width, + (i) => + ((i == 0) ? ~inp[i] : inp[i] ^ u.val[i - 1]).named('o_$i')) .rswizzle()); } } diff --git a/lib/src/arithmetic/partial_product_generator.dart b/lib/src/arithmetic/partial_product_generator.dart index c8ceeb962..634e2b8e6 100644 --- a/lib/src/arithmetic/partial_product_generator.dart +++ b/lib/src/arithmetic/partial_product_generator.dart @@ -19,7 +19,8 @@ class SignBit extends Logic { bool inverted = false; /// Construct a sign bit to store - SignBit(Logic inl, {this.inverted = false}) : super(name: inl.name) { + SignBit(Logic inl, {this.inverted = false}) + : super(name: '${inl.name}_signbit', naming: Naming.mergeable) { this <= inl; } } @@ -230,19 +231,19 @@ abstract class PartialProductGeneratorBase extends PartialProductArray { this.selectSignedMultiplicand, this.selectSignedMultiplier, super.name = 'ppg'}) { - if (signedMultiplier && (selectSignedMultiplier != null)) { - throw RohdHclException('sign reconfiguration requires signed=false'); - } if (signedMultiplicand && (selectSignedMultiplicand != null)) { throw RohdHclException('multiplicand sign reconfiguration requires ' 'signedMultiplicand=false'); } - encoder = MultiplierEncoder(multiplier, radixEncoder, - signedMultiplier: signedMultiplier, - selectSignedMultiplier: selectSignedMultiplier); + if (signedMultiplier && (selectSignedMultiplier != null)) { + throw RohdHclException('sign reconfiguration requires signed=false'); + } selector = MultiplicandSelector(radixEncoder.radix, multiplicand, signedMultiplicand: signedMultiplicand, selectSignedMultiplicand: selectSignedMultiplicand); + encoder = MultiplierEncoder(multiplier, radixEncoder, + signedMultiplier: signedMultiplier, + selectSignedMultiplier: selectSignedMultiplier); if (multiplicand.width < selector.shift) { throw RohdHclException('multiplicand width must be greater than ' @@ -256,6 +257,8 @@ abstract class PartialProductGeneratorBase extends PartialProductArray { } /// Perform sign extension (defined in child classes) + @Deprecated('Replace this call with a construction of a ' + '[PartialProductSignExtension] class') @protected void signExtend(); @@ -264,7 +267,10 @@ abstract class PartialProductGeneratorBase extends PartialProductArray { partialProducts = >[]; for (var row = 0; row < encoder.rows; row++) { partialProducts.add(List.generate( - selector.width, (i) => selector.select(i, encoder.getEncoding(row)))); + selector.width, + (i) => selector + .select(i, encoder.fetchEncoding(row)) + .named('pp_r${row}_c$i', naming: Naming.mergeable))); } for (var row = 0; row < rows; row++) { rowShift.add(row * shift); diff --git a/lib/src/arithmetic/partial_product_sign_extend.dart b/lib/src/arithmetic/partial_product_sign_extend.dart index caffc6eb5..03dd678a9 100644 --- a/lib/src/arithmetic/partial_product_sign_extend.dart +++ b/lib/src/arithmetic/partial_product_sign_extend.dart @@ -159,12 +159,12 @@ class NoneSignExtension extends PartialProductSignExtension { } /// A concrete base class for partial product generation -class PartialProductGeneratorBasic extends PartialProductGeneratorBase { +class PartialProductGenerator extends PartialProductGeneratorBase { /// The extension routine we will be using. late final PartialProductSignExtension extender; /// Construct a none sign extending Partial Product Generator - PartialProductGeneratorBasic( + PartialProductGenerator( super.multiplicand, super.multiplier, super.radixEncoder, {super.signedMultiplicand, super.signedMultiplier, @@ -197,7 +197,13 @@ class BruteSignExtension extends PartialProductSignExtension { throw RohdHclException('Partial Product array already sign-extended'); } isSignExtended = true; - final signs = [for (var r = 0; r < rows; r++) encoder.getEncoding(r).sign]; + final signs = [ + for (var r = 0; r < rows; r++) + encoder + .fetchEncoding(r) + .sign + .named('sign_r$r', naming: Naming.mergeable) + ]; for (var row = 0; row < rows; row++) { final addend = partialProducts[row]; final Logic sign; @@ -223,12 +229,14 @@ class BruteSignExtension extends PartialProductSignExtension { /// A wrapper class for [BruteSignExtension] we used /// during refactoring to be compatible with old calls. +@Deprecated('Use BruteSignExtension class after PartialProductGeneratorBasic') class PartialProductGeneratorBruteSignExtension extends PartialProductGeneratorBase { /// The extension routine we will be using. late final PartialProductSignExtension extender; /// Construct a compact rect sign extending Partial Product Generator + @Deprecated('Use BruteSignExtension class after PartialProductGeneratorBasic') PartialProductGeneratorBruteSignExtension( super.multiplicand, super.multiplier, super.radixEncoder, {super.signedMultiplicand, @@ -274,7 +282,13 @@ class CompactSignExtension extends PartialProductSignExtension { final lastRowSignPos = shift * lastRow; final alignRow0Sign = firstRowQStart - lastRowSignPos; - final signs = [for (var r = 0; r < rows; r++) encoder.getEncoding(r).sign]; + final signs = [ + for (var r = 0; r < rows; r++) + encoder + .fetchEncoding(r) + .sign + .named('sign_r$r', naming: Naming.mergeable) + ]; final propagate = List.generate(rows, (i) => List.filled(0, Logic(), growable: true)); @@ -282,7 +296,8 @@ class CompactSignExtension extends PartialProductSignExtension { for (var row = 0; row < rows; row++) { propagate[row].add(signs[row]); for (var col = 0; col < 2 * (shift - 1); col++) { - propagate[row].add(partialProducts[row][col]); + propagate[row].add(partialProducts[row][col] + .named('propagate_r${row}_c$col', naming: Naming.mergeable)); } // Last row has extend sign propagation to Q start if (row == lastRow) { @@ -292,44 +307,52 @@ class CompactSignExtension extends PartialProductSignExtension { } } for (var col = 1; col < propagate[row].length; col++) { - propagate[row][col] = propagate[row][col] & propagate[row][col - 1]; + propagate[row][col] = (propagate[row][col] & propagate[row][col - 1]) + .named('propagate_r${row}_c$col', naming: Naming.mergeable); } } final m = List.generate(rows, (i) => List.filled(0, Logic(), growable: true)); for (var row = 0; row < rows; row++) { for (var c = 0; c < shift - 1; c++) { - m[row].add(partialProducts[row][c] ^ propagate[row][c]); + m[row].add((partialProducts[row][c] ^ propagate[row][c]) + .named('m_r${row}_c$c', naming: Naming.mergeable)); } m[row].addAll(List.filled(shift - 1, Logic())); } while (m[lastRow].length < alignRow0Sign) { m[lastRow].add(Logic()); } - + // TODO(desmonddak): this seems unused when looking at Verilog output for (var i = shift - 1; i < m[lastRow].length; i++) { - m[lastRow][i] = lastAddend[i] ^ - (i < alignRow0Sign ? propagate[lastRow][i] : Const(0)); + m[lastRow][i] = (lastAddend[i] ^ + (i < alignRow0Sign ? propagate[lastRow][i] : Const(0))) + .named('m_lastr_c$i', naming: Naming.mergeable); } final remainders = List.filled(rows, Logic()); for (var row = 0; row < lastRow; row++) { - remainders[row] = propagate[row][shift - 1]; + remainders[row] = propagate[row][shift - 1] + .named('remainders_r$row', naming: Naming.mergeable); } - remainders[lastRow] <= propagate[lastRow][max(alignRow0Sign, 0)]; + remainders[lastRow] <= + propagate[lastRow][max(alignRow0Sign, 0)] + .named('remainders_lastrow', naming: Naming.mergeable); // Compute Sign extension for row==0 - final Logic firstSign; + final firstSign = Logic(name: 'firstsign', naming: Naming.mergeable); if (selectSignedMultiplicand == null) { - firstSign = - signedMultiplicand ? SignBit(firstAddend.last) : SignBit(signs[0]); + firstSign <= + (signedMultiplicand ? SignBit(firstAddend.last) : SignBit(signs[0])); } else { - firstSign = + firstSign <= SignBit(mux(selectSignedMultiplicand!, firstAddend.last, signs[0])); } final q = [ - firstSign ^ remainders[lastRow], - ~(firstSign & ~remainders[lastRow]), + (firstSign ^ remainders[lastRow]) + .named('qfirst', naming: Naming.mergeable), + (~(firstSign & ~remainders[lastRow])) + .named('q_last', naming: Naming.mergeable), ]; q.insertAll(1, List.filled(shift - 1, ~q[1])); @@ -365,12 +388,15 @@ class CompactSignExtension extends PartialProductSignExtension { /// A wrapper class for [CompactSignExtension] we used /// during refactoring to be compatible with old calls. +@Deprecated('Use CompactSignExtension class after PartialProductGeneratorBasic') class PartialProductGeneratorCompactSignExtension extends PartialProductGeneratorBase { /// The extension routine we will be using. late final PartialProductSignExtension extender; /// Construct a compact sign extending Partial Product Generator + @Deprecated( + 'Use CompactSignExtension class after PartialProductGeneratorBasic') PartialProductGeneratorCompactSignExtension( super.multiplicand, super.multiplier, super.radixEncoder, {super.signedMultiplicand, @@ -418,7 +444,13 @@ class StopBitsSignExtension extends PartialProductSignExtension { ? (finalCarryRelPos / shift).floor() : 0; - final signs = [for (var r = 0; r < rows; r++) encoder.getEncoding(r).sign]; + final signs = [ + for (var r = 0; r < rows; r++) + encoder + .fetchEncoding(r) + .sign + .named('sign_r$r', naming: Naming.mergeable) + ]; for (var row = 0; row < rows; row++) { final addend = partialProducts[row]; @@ -469,16 +501,17 @@ class StopBitsSignExtension extends PartialProductSignExtension { } } -// - -/// A wrapper class for [StopBitsSignExtension] we used -/// during refactoring to be compatible with old calls. +/// Stop-bits based sign extension +@Deprecated( + 'Use StopBitsSignExtension class after PartialProductGeneratorBasic') class PartialProductGeneratorStopBitsSignExtension extends PartialProductGeneratorBase { /// The extension routine we will be using. late final PartialProductSignExtension extender; /// Construct a stop bits sign extending Partial Product Generator + @Deprecated( + 'Use StopBitsSignExtension class after PartialProductGeneratorBasic') PartialProductGeneratorStopBitsSignExtension( super.multiplicand, super.multiplier, super.radixEncoder, {super.signedMultiplicand, @@ -496,14 +529,18 @@ class PartialProductGeneratorStopBitsSignExtension } } -/// A wrapper class for CompactRectSignExtension we used +/// A wrapper class for [CompactRectSignExtension] we used /// during refactoring to be compatible with old calls. +@Deprecated( + 'Use CompactRectSignExtension class after PartialProductGeneratorBasic') class PartialProductGeneratorCompactRectSignExtension extends PartialProductGeneratorBase { /// The extension routine we will be using. late final PartialProductSignExtension extender; /// Construct a compact rect sign extending Partial Product Generator + @Deprecated( + 'Use CompactRectSignExtension class after PartialProductGeneratorBasic') PartialProductGeneratorCompactRectSignExtension( super.multiplicand, super.multiplier, super.radixEncoder, {super.signedMultiplicand, @@ -546,7 +583,13 @@ class CompactRectSignExtension extends PartialProductSignExtension { final align = firstRowQStart - lastRowSignPos; - final signs = [for (var r = 0; r < rows; r++) encoder.getEncoding(r).sign]; + final signs = [ + for (var r = 0; r < rows; r++) + encoder + .fetchEncoding(r) + .sign + .named('sign_r$r', naming: Naming.mergeable) + ]; // Compute propgation info for folding sign bits into main rows final propagate = @@ -555,7 +598,8 @@ class CompactRectSignExtension extends PartialProductSignExtension { for (var row = 0; row < rows; row++) { propagate[row].add(SignBit(signs[row])); for (var col = 0; col < 2 * (shift - 1); col++) { - propagate[row].add(partialProducts[row][col]); + propagate[row].add(partialProducts[row][col] + .named('propagate_r${row}_c$col', naming: Naming.mergeable)); } // Last row has extend sign propagation to Q start if (row == lastRow) { @@ -566,7 +610,8 @@ class CompactRectSignExtension extends PartialProductSignExtension { } // Now compute the propagation logic for (var col = 1; col < propagate[row].length; col++) { - propagate[row][col] = propagate[row][col] & propagate[row][col - 1]; + propagate[row][col] = (propagate[row][col] & propagate[row][col - 1]) + .named('propagate_r${row}_c$col', naming: Naming.mergeable); } } @@ -575,7 +620,8 @@ class CompactRectSignExtension extends PartialProductSignExtension { List.generate(rows, (i) => List.filled(0, Logic(), growable: true)); for (var row = 0; row < rows; row++) { for (var c = 0; c < shift - 1; c++) { - m[row].add(partialProducts[row][c] ^ propagate[row][c]); + m[row].add((partialProducts[row][c] ^ propagate[row][c]) + .named('m_r${row}_c$c', naming: Naming.mergeable)); } m[row].addAll(List.filled(shift - 1, Logic())); } @@ -585,14 +631,17 @@ class CompactRectSignExtension extends PartialProductSignExtension { } for (var i = shift - 1; i < m[lastRow].length; i++) { m[lastRow][i] = - lastAddend[i] ^ (i < align ? propagate[lastRow][i] : Const(0)); + (lastAddend[i] ^ (i < align ? propagate[lastRow][i] : Const(0))) + .named('m_lastrow_$i', naming: Naming.mergeable); } final remainders = List.filled(rows, Logic()); for (var row = 0; row < lastRow; row++) { - remainders[row] = propagate[row][shift - 1]; + remainders[row] = propagate[row][shift - 1] + .named('remainder_r$row', naming: Naming.mergeable); } - remainders[lastRow] = propagate[lastRow][align > 0 ? align : 0]; + remainders[lastRow] = propagate[lastRow][align > 0 ? align : 0] + .named('remainder_lastrow', naming: Naming.mergeable); // Merge 'm' into the LSBs of each addend for (var row = 0; row < rows; row++) { @@ -618,13 +667,13 @@ class CompactRectSignExtension extends PartialProductSignExtension { // Insert the lastRow sign: Either in firstRow's Q if there is a // collision or in another row if it lands beyond the Q sign extension - final Logic firstSign; + final firstSign = Logic(name: 'firstsign', naming: Naming.mergeable); if (selectSignedMultiplicand == null) { - firstSign = - signedMultiplicand ? SignBit(firstAddend.last) : SignBit(signs[0]); + firstSign <= + (signedMultiplicand ? SignBit(firstAddend.last) : SignBit(signs[0])); } else { - firstSign = - SignBit(mux(selectSignedMultiplicand!, firstAddend.last, signs[0])); + firstSign <= + (SignBit(mux(selectSignedMultiplicand!, firstAddend.last, signs[0]))); } final lastSign = SignBit(remainders[lastRow]); // Compute Sign extension MSBs for firstRow diff --git a/lib/src/arithmetic/signals/fixed_point_logic.dart b/lib/src/arithmetic/signals/fixed_point_logic.dart index 9c24e1126..cf503b8cb 100644 --- a/lib/src/arithmetic/signals/fixed_point_logic.dart +++ b/lib/src/arithmetic/signals/fixed_point_logic.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // fixed_point_logic.dart @@ -43,6 +43,7 @@ class FixedPoint extends Logic { } /// Clone for I/O ports. + @override FixedPoint clone({String? name}) => FixedPoint(signed: signed, m: m, n: n); /// Cast logic to fixed point diff --git a/lib/src/arithmetic/signals/floating_point_logic.dart b/lib/src/arithmetic/signals/floating_point_logic.dart index ebe33ae94..20675f3fb 100644 --- a/lib/src/arithmetic/signals/floating_point_logic.dart +++ b/lib/src/arithmetic/signals/floating_point_logic.dart @@ -24,21 +24,39 @@ class FloatingPoint extends LogicStructure { /// [sign] bit with '1' representing a negative number final Logic sign; + /// Utility to keep track of the Logic structure name by attaching it + /// to the Logic signal name in the output Verilog. + static String _nameJoin(String? structName, String signalName) { + if (structName == null) { + return signalName; + } + return '${structName}_$signalName'; + } + /// [FloatingPoint] Constructor for a variable size binary /// floating point number - FloatingPoint({required int exponentWidth, required int mantissaWidth}) + FloatingPoint( + {required int exponentWidth, required int mantissaWidth, String? name}) : this._( - Logic(name: 'sign'), - Logic(width: exponentWidth, name: 'exponent'), - Logic(width: mantissaWidth, name: 'mantissa')); - - FloatingPoint._(this.sign, this.exponent, this.mantissa, {String? name}) - : super([mantissa, exponent, sign], name: name ?? 'FloatingPoint'); + Logic(name: _nameJoin(name, 'sign'), naming: Naming.mergeable), + Logic( + width: exponentWidth, + name: _nameJoin(name, 'exponent'), + naming: Naming.mergeable), + Logic( + width: mantissaWidth, + name: _nameJoin(name, 'mantissa'), + naming: Naming.mergeable), + name: name); + + FloatingPoint._(this.sign, this.exponent, this.mantissa, {super.name}) + : super([mantissa, exponent, sign]); @override FloatingPoint clone({String? name}) => FloatingPoint( exponentWidth: exponent.width, mantissaWidth: mantissa.width, + name: name, ); /// Return the [FloatingPointValue] @@ -47,42 +65,45 @@ class FloatingPoint extends LogicStructure { /// Return a Logic true if this FloatingPoint contains a normal number, /// defined as having mantissa in the range [1,2) - late final Logic isNormal = Logic(name: 'isNormal', naming: Naming.mergeable) - ..gets(exponent.neq(LogicValue.zero.zeroExtend(exponent.width))); + late final Logic isNormal = exponent + .neq(LogicValue.zero.zeroExtend(exponent.width)) + .named(_nameJoin('isNormal', name), naming: Naming.mergeable); /// Return a Logic true if this FloatingPoint is Not a Number (NaN) /// by having its exponent field set to the NaN value (typically all /// ones) and a non-zero mantissa. - late final isNaN = Logic(name: 'isNaN', naming: Naming.mergeable) - ..gets(exponent.eq(floatingPointValue.nan.exponent) & mantissa.or()); + late final isNaN = exponent.eq(floatingPointValue.nan.exponent) & + mantissa.or().named( + _nameJoin('isNaN', name), + naming: Naming.mergeable, + ); /// Return a Logic true if this FloatingPoint is an infinity /// by having its exponent field set to the NaN value (typically all /// ones) and a zero mantissa. - late final isInfinity = Logic(name: 'isInfinity', naming: Naming.mergeable) - ..gets(exponent.eq(floatingPointValue.infinity.exponent) & ~mantissa.or()); + late final isInfinity = + (exponent.eq(floatingPointValue.infinity.exponent) & ~mantissa.or()) + .named(_nameJoin('isInfinity', name), naming: Naming.mergeable); /// Return a Logic true if this FloatingPoint is an zero /// by having its exponent field set to the NaN value (typically all /// ones) and a zero mantissa. - late final isZero = Logic(name: 'isZero', naming: Naming.mergeable) - ..gets(exponent.eq(floatingPointValue.zero.exponent) & ~mantissa.or()); + late final isZero = + (exponent.eq(floatingPointValue.zero.exponent) & ~mantissa.or()) + .named(_nameJoin('isZero', name), naming: Naming.mergeable); /// Return the zero exponent representation for this type of FloatingPoint - late final zeroExponent = Logic( - name: 'zeroExponent', naming: Naming.mergeable, width: exponent.width) - ..gets(Const(LogicValue.zero, width: exponent.width)); + late final zeroExponent = Const(LogicValue.zero, width: exponent.width) + .named(_nameJoin('zeroExponent', name), naming: Naming.mergeable); /// Return the one exponent representation for this type of FloatingPoint - late final oneExponent = Logic( - name: 'oneExponent', naming: Naming.mergeable, width: exponent.width) - ..gets(Const(LogicValue.one, width: exponent.width)); + late final oneExponent = Const(LogicValue.one, width: exponent.width) + .named(_nameJoin('oneExponent', name), naming: Naming.mergeable); /// Return the exponent Logic value representing the true zero exponent /// 2^0 = 1 often termed [bias] or the offset of the stored exponent. - late final bias = - Logic(name: 'bias', naming: Naming.mergeable, width: exponent.width) - ..gets(Const((1 << exponent.width - 1) - 1, width: exponent.width)); + late final bias = Const((1 << exponent.width - 1) - 1, width: exponent.width) + .named(_nameJoin('bias', name), naming: Naming.mergeable); /// Construct a FloatingPoint that represents infinity for this FP type. FloatingPoint inf({Logic? sign, bool negative = false}) => FloatingPoint.inf( @@ -129,53 +150,66 @@ class FloatingPoint extends LogicStructure { /// Single floating point representation class FloatingPoint32 extends FloatingPoint { /// Construct a 32-bit (single-precision) floating point number - FloatingPoint32() + FloatingPoint32({super.name}) : super( exponentWidth: FloatingPoint32Value.exponentWidth, mantissaWidth: FloatingPoint32Value.mantissaWidth); + + @override + FloatingPoint32 clone({String? name}) => FloatingPoint32(name: name); } /// Double floating point representation class FloatingPoint64 extends FloatingPoint { /// Construct a 64-bit (double-precision) floating point number - FloatingPoint64() + FloatingPoint64({super.name}) : super( exponentWidth: FloatingPoint64Value.exponentWidth, mantissaWidth: FloatingPoint64Value.mantissaWidth); + @override + FloatingPoint64 clone({String? name}) => FloatingPoint64(name: name); } /// Eight-bit floating point representation for deep learning: E4M3 class FloatingPoint8E4M3 extends FloatingPoint { /// Construct an 8-bit floating point number - FloatingPoint8E4M3() + FloatingPoint8E4M3({super.name}) : super( mantissaWidth: FloatingPoint8E4M3Value.mantissaWidth, exponentWidth: FloatingPoint8E4M3Value.exponentWidth); + @override + FloatingPoint8E4M3 clone({String? name}) => FloatingPoint8E4M3(name: name); } /// Eight-bit floating point representation for deep learning: E5M2 class FloatingPoint8E5M2 extends FloatingPoint { /// Construct an 8-bit floating point number - FloatingPoint8E5M2() + FloatingPoint8E5M2({super.name}) : super( mantissaWidth: FloatingPoint8E5M2Value.mantissaWidth, exponentWidth: FloatingPoint8E5M2Value.exponentWidth); + @override + FloatingPoint8E5M2 clone({String? name}) => FloatingPoint8E5M2(name: name); } /// Sixteen-bit BF16 floating point representation class FloatingPointBF16 extends FloatingPoint { /// Construct a BF16 16-bit floating point number - FloatingPointBF16() + FloatingPointBF16({super.name}) : super( mantissaWidth: FloatingPointBF16Value.mantissaWidth, exponentWidth: FloatingPointBF16Value.exponentWidth); + @override + FloatingPointBF16 clone({String? name}) => FloatingPointBF16(name: name); } /// Sixteen-bit floating point representation class FloatingPoint16 extends FloatingPoint { /// Construct a 16-bit floating point number - FloatingPoint16() + FloatingPoint16({super.name}) : super( mantissaWidth: FloatingPoint16Value.mantissaWidth, exponentWidth: FloatingPoint16Value.exponentWidth); + @override + FloatingPoint16 clone({String? name}) => FloatingPoint16(name: name); } diff --git a/lib/src/arithmetic/values/floating_point_values/floating_point_32_value.dart b/lib/src/arithmetic/values/floating_point_values/floating_point_32_value.dart index fe9b23c12..38c8f7c87 100644 --- a/lib/src/arithmetic/values/floating_point_values/floating_point_32_value.dart +++ b/lib/src/arithmetic/values/floating_point_values/floating_point_32_value.dart @@ -9,6 +9,7 @@ // Max Korbel // Desmond A Kirkpatrick +import 'dart:math'; import 'dart:typed_data'; import 'package:meta/meta.dart'; import 'package:rohd/rohd.dart'; @@ -70,6 +71,15 @@ class FloatingPoint32Value extends FloatingPointValue { : super.ofInts( exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + /// Generate a random [FloatingPoint32Value], supplying random seed [rv]. + factory FloatingPoint32Value.random(Random rv, {bool normal = false}) { + final randFloat = FloatingPointValue.random(rv, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + normal: normal); + return FloatingPoint32Value.ofLogicValue(randFloat.value); + } + /// Numeric conversion of a [FloatingPoint32Value] from a host double factory FloatingPoint32Value.ofDouble(double inDouble) { final byteData = ByteData(4)..setFloat32(0, inDouble); @@ -84,6 +94,15 @@ class FloatingPoint32Value extends FloatingPointValue { mantissa: accum.slice(mantissaWidth - 1, 0)); } + /// Convert a floating point number into a [FloatingPoint32Value] + /// representation. This form performs NO ROUNDING. + factory FloatingPoint32Value.ofDoubleUnrounded(double inDouble) { + final fpv = FloatingPointValue.ofDoubleUnrounded(inDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + + return FloatingPoint32Value.ofLogicValue(fpv.value); + } + /// Construct a [FloatingPoint32Value] from a Logic word factory FloatingPoint32Value.ofLogicValue(LogicValue val) => FloatingPointValue.buildOfLogicValue( diff --git a/lib/src/arithmetic/values/floating_point_values/floating_point_bf16_value.dart b/lib/src/arithmetic/values/floating_point_values/floating_point_bf16_value.dart index a8e83a999..b5dc3b71b 100644 --- a/lib/src/arithmetic/values/floating_point_values/floating_point_bf16_value.dart +++ b/lib/src/arithmetic/values/floating_point_values/floating_point_bf16_value.dart @@ -9,6 +9,8 @@ // Max Korbel // Desmond A Kirkpatrick +import 'dart:math'; + import 'package:meta/meta.dart'; import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; @@ -69,6 +71,15 @@ class FloatingPointBF16Value extends FloatingPointValue { : super.ofInts( exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + /// Generate a random [FloatingPointBF16Value], supplying random seed [rv]. + factory FloatingPointBF16Value.random(Random rv, {bool normal = false}) { + final randFloat = FloatingPointValue.random(rv, + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth, + normal: normal); + return FloatingPointBF16Value.ofLogicValue(randFloat.value); + } + /// Numeric conversion of a [FloatingPointBF16Value] from a host double factory FloatingPointBF16Value.ofDouble(double inDouble) { final fpv = FloatingPointValue.ofDouble(inDouble, @@ -77,6 +88,59 @@ class FloatingPointBF16Value extends FloatingPointValue { return FloatingPointBF16Value.ofLogicValue(fpv.value); } + /// Convert a floating point number into a [FloatingPointBF16Value] + /// representation. This form performs NO ROUNDING. + factory FloatingPointBF16Value.ofDoubleUnrounded(double inDouble) { + final fpv = FloatingPointValue.ofDoubleUnrounded(inDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + + return FloatingPointBF16Value.ofLogicValue(fpv.value); + } + + // TODO(desmonddak): We need to add all these operators for subclasses unless + // We figure out a way to use templates to do them. Currently just BF16 + + /// Multiply operation for [FloatingPointBF16Value] + @override + FloatingPointBF16Value operator *( + covariant FloatingPointBF16Value multiplicand) { + final fpv = super * multiplicand; + return FloatingPointBF16Value.ofLogicValue(fpv.value); + } + + /// Addition operation for [FloatingPointBF16Value] + @override + FloatingPointBF16Value operator +(covariant FloatingPointBF16Value addend) { + final fpv = super + addend; + return FloatingPointBF16Value.ofLogicValue(fpv.value); + } + + /// Divide operation for [FloatingPointBF16Value] + @override + FloatingPointBF16Value operator /(covariant FloatingPointBF16Value divisor) { + final fpv = super / divisor; + return FloatingPointBF16Value.ofLogicValue(fpv.value); + } + + /// Subtract operation for [FloatingPointBF16Value] + @override + FloatingPointBF16Value operator -(covariant FloatingPointBF16Value subend) { + final fpv = super - subend; + return FloatingPointBF16Value.ofLogicValue(fpv.value); + } + + /// Negate operation for [FloatingPointBF16Value] + @override + FloatingPointBF16Value negate() => FloatingPointBF16Value( + sign: sign.isZero ? LogicValue.one : LogicValue.zero, + exponent: exponent, + mantissa: mantissa); + + /// Absolute value operation for [FloatingPointBF16Value] + @override + FloatingPointBF16Value abs() => FloatingPointBF16Value( + sign: LogicValue.zero, exponent: exponent, mantissa: mantissa); + /// Construct a [FloatingPointBF16Value] from a Logic word factory FloatingPointBF16Value.ofLogicValue(LogicValue val) => FloatingPointValue.buildOfLogicValue( diff --git a/lib/src/arithmetic/values/floating_point_values/floating_point_value.dart b/lib/src/arithmetic/values/floating_point_values/floating_point_value.dart index 1e1c1052f..19877e37c 100644 --- a/lib/src/arithmetic/values/floating_point_values/floating_point_value.dart +++ b/lib/src/arithmetic/values/floating_point_values/floating_point_value.dart @@ -825,4 +825,32 @@ class FloatingPointValue implements Comparable { /// Absolute value operation for [FloatingPointValue] FloatingPointValue abs() => FloatingPointValue( sign: LogicValue.zero, exponent: exponent, mantissa: mantissa); + + /// Return true if the other [FloatingPointValue] is within a rounding + /// error of this value. + bool withinRounding(FloatingPointValue other) { + if (this != other) { + final diff = (abs() - other.abs()).abs(); + if (diff.compareTo(ulp()) == 1) { + return false; + } + } + return true; + } + + /// Compute the unit in the last place for the given [FloatingPointValue] + FloatingPointValue ulp() { + if (exponent.toInt() > mantissa.width) { + final newExponent = + LogicValue.ofInt(exponent.toInt() - mantissa.width, exponent.width); + return FloatingPointValue.ofBinaryStrings( + sign.bitString, newExponent.bitString, '0' * (mantissa.width)); + } else { + // TODO(desmonddak): need to handle exponent < mantissa width by + // shifting the 1 for ULP incrementally, not just putting it at + // the end. + return FloatingPointValue.ofBinaryStrings( + sign.bitString, exponent.bitString, '${'0' * (mantissa.width - 1)}1'); + } + } } diff --git a/lib/src/component_config/components/config_floating_point_multiplier_simple.dart b/lib/src/component_config/components/config_floating_point_multiplier_simple.dart index 63620ca6c..b8db9e50c 100644 --- a/lib/src/component_config/components/config_floating_point_multiplier_simple.dart +++ b/lib/src/component_config/components/config_floating_point_multiplier_simple.dart @@ -38,10 +38,29 @@ class FloatingPointMultiplierSimpleConfigurator extends Configurator { BrentKung: BrentKung.new }; + /// Map from Type to Function for Mantissa Multiplier + static Map< + Type, + Multiplier Function(Logic term1, Logic term2, + {Logic? clk, + Logic? reset, + Logic? enable, + String name})> multGeneratorMap = { + NativeMultiplier: NativeMultiplier.new, + CompressionTreeMultiplier: (term1, term2, + {Logic? clk, Logic? reset, Logic? enable, String? name}) => + CompressionTreeMultiplier(term1, term2, 4, name: name!) + // TODO(desmonddak): put tree type, adder type, and radix options here + }; + /// Controls the type of [Adder] used for internal adders. final adderTypeKnob = ChoiceConfigKnob(adderGeneratorMap.keys.toList(), value: NativeAdder); + /// Controls the type of [Multiplier] used for mantissa multiplication. + final multTypeKnob = + ChoiceConfigKnob(multGeneratorMap.keys.toList(), value: NativeMultiplier); + /// Controls the type of [ParallelPrefix] tree used in the internal functions. final prefixTreeKnob = ChoiceConfigKnob(treeGeneratorMap.keys.toList(), value: KoggeStone); @@ -65,13 +84,13 @@ class FloatingPointMultiplierSimpleConfigurator extends Configurator { FloatingPoint( exponentWidth: exponentWidthKnob.value, mantissaWidth: mantissaWidthKnob.value), - adderGen: adderGeneratorMap[adderTypeKnob.value]!, + multGen: multGeneratorMap[multTypeKnob.value]!, ppTree: treeGeneratorMap[prefixTreeKnob.value]!); @override late final Map> knobs = UnmodifiableMapView({ - 'Adder type': adderTypeKnob, - 'Prefix tree type': prefixTreeKnob, + // 'Adder type': adderTypeKnob, + // 'Prefix tree type': prefixTreeKnob, 'Exponent width': exponentWidthKnob, 'Mantissa width': mantissaWidthKnob, 'Pipelined': pipelinedKnob, diff --git a/lib/src/encodings/tree_one_hot_to_binary.dart b/lib/src/encodings/tree_one_hot_to_binary.dart index d6df0ecb2..5a36252c3 100644 --- a/lib/src/encodings/tree_one_hot_to_binary.dart +++ b/lib/src/encodings/tree_one_hot_to_binary.dart @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 Intel Corporation +// Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // // tree_one_hot_to_binary.dart @@ -41,8 +41,8 @@ class _NodeOneHotToBinary extends Module { } else { final mid = 1 << (log2Ceil(wid) - 1); addOutput('binary', width: log2Ceil(mid + 1)); - final hi = onehot.getRange(mid).zeroExtend(mid); - final lo = onehot.getRange(0, mid).zeroExtend(mid); + final hi = onehot.getRange(mid).zeroExtend(mid).named('hi'); + final lo = onehot.getRange(0, mid).zeroExtend(mid).named('lo'); final recurse = lo | hi; final response = _NodeOneHotToBinary(recurse).binary; binary <= [hi.or(), response].swizzle(); diff --git a/lib/src/utils.dart b/lib/src/utils.dart index 585cae441..b4a52a144 100644 --- a/lib/src/utils.dart +++ b/lib/src/utils.dart @@ -90,7 +90,8 @@ Logic condFlop( (clk == null) ? d : flop(clk, d, - en: en, - reset: reset, - resetValue: resetValue, - asyncReset: asyncReset); + en: en, + reset: reset, + resetValue: resetValue, + asyncReset: asyncReset) + .named('${d.name}_flopped'); diff --git a/pubspec.yaml b/pubspec.yaml index 8367e615b..4a6f6ca25 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -12,9 +12,11 @@ environment: dependencies: collection: ^1.18.0 meta: ^1.9.1 - rohd: ^0.6.0 + rohd: ^0.6.1 rohd_vf: ^0.6.0 dev_dependencies: logging: ^1.0.1 test: ^1.25.0 + + diff --git a/test/arithmetic/addend_compressor_test.dart b/test/arithmetic/addend_compressor_test.dart index 56214dcb7..20704cb58 100644 --- a/test/arithmetic/addend_compressor_test.dart +++ b/test/arithmetic/addend_compressor_test.dart @@ -36,7 +36,7 @@ class CompressorTestMod extends Module { clk = addInput('clk', iclk); } - final pp = PartialProductGeneratorBasic(a, b, encoder, + final pp = PartialProductGenerator(a, b, encoder, signedMultiplicand: signed, signedMultiplier: signed); CompactRectSignExtension(pp).signExtend(); @@ -113,7 +113,7 @@ void main() { selectSignedMultiplicand!.put(signed ? 1 : 0); selectSignedMultiplier!.put(signed ? 1 : 0); } - final pp = PartialProductGeneratorBasic(a, b, encoder, + final pp = PartialProductGenerator(a, b, encoder, signedMultiplicand: !useSelect & signed, signedMultiplier: !useSelect & signed, selectSignedMultiplicand: selectSignedMultiplicand, diff --git a/test/arithmetic/floating_point/floating_point_adder_simple_test.dart b/test/arithmetic/floating_point/floating_point_adder_simple_test.dart index 4ce667333..1aeede8f6 100644 --- a/test/arithmetic/floating_point/floating_point_adder_simple_test.dart +++ b/test/arithmetic/floating_point/floating_point_adder_simple_test.dart @@ -41,23 +41,16 @@ void main() { final expectedDouble = fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); - final expectedRound = FloatingPointValue.ofDouble(expectedDouble, - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( expectedDouble, exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - - if ((computed.mantissa != expectedNoRound.mantissa) & - (computed.mantissa != expectedRound.mantissa)) { - expect(computed, equals(expectedRound), reason: ''' + expect(computed.withinRounding(expectedNoRound), true, reason: ''' $fv1 (${fv1.toDouble()})\t+ $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expectedNoRound (${expectedNoRound.toDouble()})\texpected '''); - } } }); @@ -91,25 +84,17 @@ void main() { fp2.put(fv2.value); final computed = adder.sum.floatingPointValue; - - final expectedRound = FloatingPointValue.ofDouble( - fv1.toDouble() + fv2.toDouble(), - exponentWidth: exponentWidth, - mantissaWidth: mantissaWidth); - final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( fv1.toDouble() + fv2.toDouble(), exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - if ((computed != expectedNoRound) & (computed != expectedRound)) { - expect(computed, equals(expectedRound), reason: ''' + expect(computed.withinRounding(expectedNoRound), true, reason: ''' $fv1 (${fv1.toDouble()})\t+ $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expectedNoRound (${expectedNoRound.toDouble()})\texpected '''); - } } } } @@ -172,24 +157,18 @@ void main() { final expectedDouble = fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); - final expectedRound = FloatingPointValue.ofDouble(expectedDouble, - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( expectedDouble, exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - final expected = expectedNoRound; final computed = adder.sum.floatingPointValue; - if ((computed != expectedNoRound) && (computed != expectedRound)) { - expect(computed, equals(expected), reason: ''' + expect(computed.withinRounding(expectedNoRound), true, reason: ''' $fv1 (${fv1.toDouble()})\t+ $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expectedNoRound (${expectedNoRound.toDouble()})\texpected '''); - } } }); test('FP: simple adder narrow singleton test', () { @@ -260,24 +239,17 @@ void main() { final computed = adder.sum.floatingPointValue; - final expectedRound = FloatingPointValue.ofDouble( - fv1.toDouble() + fv2.toDouble(), - exponentWidth: exponentWidth, - mantissaWidth: mantissaWidth); - final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( fv1.toDouble() + fv2.toDouble(), exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - if ((computed != expectedNoRound) & (computed != expectedRound)) { - expect(computed, equals(expectedNoRound), reason: ''' + expect(computed.withinRounding(expectedNoRound), true, reason: ''' $fv1 (${fv1.toDouble()})\t+ $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expectedNoRound (${expectedNoRound.toDouble()})\texpected '''); - } } await Simulator.endSimulation(); }); @@ -308,24 +280,17 @@ void main() { final computed = adder.sum.floatingPointValue; - final expectedRound = FloatingPointValue.ofDouble( - fv1.toDouble() + fv2.toDouble(), - exponentWidth: exponentWidth, - mantissaWidth: mantissaWidth); - final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( fv1.toDouble() + fv2.toDouble(), exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - if ((computed != expectedNoRound) & (computed != expectedRound)) { - expect(computed, equals(expectedNoRound), reason: ''' + expect(computed.withinRounding(expectedNoRound), true, reason: ''' $fv1 (${fv1.toDouble()})\t+ $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expectedNoRound (${expectedNoRound.toDouble()})\texpected '''); - } } }); @@ -353,27 +318,21 @@ void main() { final computed = adder.sum.floatingPointValue; - final expectedRound = FloatingPointValue.ofDouble( - fv1.toDouble() + fv2.toDouble(), - exponentWidth: exponentWidth, - mantissaWidth: mantissaWidth); - final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( fv1.toDouble() + fv2.toDouble(), exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - if ((computed != expectedNoRound) & (computed != expectedRound)) { - expect(computed, equals(expectedNoRound), reason: ''' + expect(computed.withinRounding(expectedNoRound), true, reason: ''' $fv1 (${fv1.toDouble()})\t+ $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expectedNoRound (${expectedNoRound.toDouble()})\texpected '''); - } } }); - test('FP: simple adder general singleton test', () { + + test('FP: simple adder general singleton test', () async { FloatingPointValue ofString(String s) => FloatingPointValue.ofSpacedBinaryString(s); @@ -387,6 +346,8 @@ void main() { fp1.put(fv1); fp2.put(fv2); final adder = FloatingPointAdderSimple(fp1, fp2); + await adder.build(); + final exponentWidth = adder.sum.exponent.width; final mantissaWidth = adder.sum.mantissa.width; diff --git a/test/arithmetic/floating_point/floating_point_multiplier_test.dart b/test/arithmetic/floating_point/floating_point_multiplier_test.dart index 15716c9ac..22acf17cf 100644 --- a/test/arithmetic/floating_point/floating_point_multiplier_test.dart +++ b/test/arithmetic/floating_point/floating_point_multiplier_test.dart @@ -49,7 +49,7 @@ void main() { final computed = multiply.product.floatingPointValue; expect(computed, equals(expected), reason: ''' - $fv1 (${fv1.toDouble()})\t+ + $fv1 (${fv1.toDouble()})\t* $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expected (${expected.toDouble()})\texpected @@ -104,7 +104,7 @@ void main() { final computed = multiply.product.floatingPointValue; expect(computed, equals(expected), reason: ''' - $fv1 (${fv1.toDouble()})\t+ + $fv1 (${fv1.toDouble()})\t* $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expected (${expected.toDouble()})\texpected @@ -148,7 +148,7 @@ void main() { final computed = multiply.product.floatingPointValue; expect(computed, equals(expected), reason: ''' - $fv1 (${fv1.toDouble()})\t+ + $fv1 (${fv1.toDouble()})\t* $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expected (${expected.toDouble()})\texpected @@ -171,6 +171,47 @@ void main() { fp1.put(0); fp2.put(0); final multiplier = FloatingPointMultiplierSimple(fp1, fp2); + + final value = Random(51); + var cnt = 1000; + while (cnt > 0) { + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(fv1); + fp2.put(fv2); + + final expected = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() * fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + final computed = multiplier.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t* + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + cnt--; + } + }); + + test('FP: simple multiplier full random with compression tree mult', + () async { + const exponentWidth = 4; + const mantissaWidth = 4; + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + final multiplier = FloatingPointMultiplierSimple(fp1, fp2, + multGen: (a, b, {clk, reset, enable, name = 'multiplier'}) => + CompressionTreeMultiplier(a, b, 4, name: name)); final value = Random(51); var cnt = 1000; @@ -198,7 +239,7 @@ void main() { } }); - test('FP: simple multiplier singleton', () { + test('FP: simple multiplier singleton', () async { const exponentWidth = 4; const mantissaWidth = 4; final fp1 = FloatingPoint( @@ -207,7 +248,7 @@ void main() { final fp2 = FloatingPoint( exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - final fv2 = FloatingPointValue.ofBinaryStrings('0', '1100', '0000'); + final fv2 = FloatingPointValue.ofBinaryStrings('1', '1100', '0000'); final doubleProduct = fv1.toDouble() * fv2.toDouble(); final expected = FloatingPointValue.ofDoubleUnrounded(doubleProduct, @@ -217,17 +258,271 @@ void main() { fp2.put(fv2.value); final multiply = FloatingPointMultiplierSimple(fp1, fp2); + await multiply.build(); final computed = multiply.product.floatingPointValue; expect(computed, equals(expected), reason: ''' - $fv1 (${fv1.toDouble()})\t+ + $fv1 (${fv1.toDouble()})\t* + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + }); + test('FP: simple multiplier specify wider output', () async { + const exponentWidth = 4; + const mantissaWidth = 4; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv1 = FloatingPointValue.ofBinaryStrings('1', '1000', '0011'); + + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.ofBinaryStrings('1', '0001', '0001'); + + final doubleProduct = fv1.toDouble() * fv2.toDouble(); + + final fpout = + FloatingPoint(exponentWidth: 4, mantissaWidth: mantissaWidth * 5); + + final expected = FloatingPointValue.ofDoubleUnrounded(doubleProduct, + exponentWidth: fpout.exponent.width, + mantissaWidth: fpout.mantissa.width); + + fp1.put(fv1.value); + fp2.put(fv2.value); + fpout.put(0); + + final multiply = + FloatingPointMultiplierSimple(fp1, fp2, outProduct: fpout); + await multiply.build(); + final computed = multiply.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t* $fv2 (${fv2.toDouble()})\t= $computed (${computed.toDouble()})\tcomputed $expected (${expected.toDouble()})\texpected '''); }); + test('FP: simple multiplier bug wider output', () async { + const exponentWidth = 8; + const mantissaWidth = 7; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv1 = + FloatingPointValue.ofBinaryStrings('0', '00000110', '1010000'); + + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = + FloatingPointValue.ofBinaryStrings('1', '01100010', '1110000'); + + final doubleProduct = fv1.toDouble() * fv2.toDouble(); + + final fpout = FloatingPoint(exponentWidth: 8, mantissaWidth: 14); + + final expected = FloatingPointValue.ofDoubleUnrounded(doubleProduct, + exponentWidth: fpout.exponent.width, + mantissaWidth: fpout.mantissa.width); + + fp1.put(fv1.value); + fp2.put(fv2.value); + fpout.put(0); + + final multiply = + FloatingPointMultiplierSimple(fp1, fp2, outProduct: fpout); + await multiply.build(); + final computed = multiply.product.floatingPointValue; + + expect(computed.withinRounding(expected), true, reason: ''' + $fv1 (${fv1.toDouble()})\t* + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + }); + + test('FP: simple multiplier bf16 to fp32', () { + final a = FloatingPointBF16(); + final b = FloatingPointBF16(); + + final out = FloatingPoint32(); + a.put(FloatingPointBF16Value.ofDouble(1.2)); + b.put(FloatingPointBF16Value.ofDouble(2.1)); + + final dut = FloatingPointMultiplierSimple(a, b, outProduct: out); + + final result = dut.product; + + expect( + result.floatingPointValue, + FloatingPoint32Value.ofDouble(a.floatingPointValue.toDouble() * + b.floatingPointValue.toDouble())); + }); + + test('FP: simple multiplier wide random', () async { + const exponentWidth = 8; + const mantissaWidth = 7; + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + + const expOutWidth = 8; + const mantOutWidth = 23; + final fpout = FloatingPoint( + exponentWidth: expOutWidth, mantissaWidth: mantOutWidth); + // ignore: cascade_invocations + fpout.put(0); + final multiplier = + FloatingPointMultiplierSimple(fp1, fp2, outProduct: fpout); + + final value = Random(51); + var cnt = 100; + while (cnt > 0) { + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(fv1); + fp2.put(fv2); + + final expected = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() * fv2.toDouble(), + exponentWidth: fpout.exponent.width, + mantissaWidth: fpout.mantissa.width); + final computed = multiplier.product.floatingPointValue; + // print('c=$computed e=$expected'); + expect(computed.withinRounding(expected), true, reason: ''' + $fv1 (${fv1.toDouble()})\t* + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + cnt--; + } + }); + + test('FP: simple multiplier sweep wide random', () async { + const exponentWidth = 3; + const mantissaWidth = 3; + + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(0); + fp2.put(0); + + for (var expOutWidth = 3; expOutWidth < 5; expOutWidth++) { + for (var mantOutWidth = 3; mantOutWidth < 16; mantOutWidth += 4) { + final fpout = FloatingPoint( + exponentWidth: expOutWidth, mantissaWidth: mantOutWidth); + // ignore: cascade_invocations + fpout.put(0); + final multiplier = + FloatingPointMultiplierSimple(fp1, fp2, outProduct: fpout); + + final value = Random(51); + var cnt = 100; + while (cnt > 0) { + final fv1 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.random(value, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + fp1.put(fv1); + fp2.put(fv2); + + final expected = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() * fv2.toDouble(), + exponentWidth: fpout.exponent.width, + mantissaWidth: fpout.mantissa.width); + final computed = multiplier.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t* + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + cnt--; + } + } + } + }); + + test('FP: simple multiplier singleton pipelined', () async { + final clk = SimpleClockGenerator(10).clk; + + const exponentWidth = 4; + const mantissaWidth = 4; + final fp1 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv1 = FloatingPointValue.ofBinaryStrings('0', '0111', '0000'); + + final fp2 = FloatingPoint( + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + final fv2 = FloatingPointValue.ofBinaryStrings('0', '1101', '0101'); + + final expected = FloatingPointValue.ofDoubleUnrounded( + fv1.toDouble() * fv2.toDouble(), + exponentWidth: exponentWidth, + mantissaWidth: mantissaWidth); + + fp1.put(fv1.value); + fp2.put(fv2.value); + + final multiply = FloatingPointMultiplierSimple(fp1, fp2, clk: clk); + + unawaited(Simulator.run()); + await clk.nextNegedge; + fp1.put(0); + fp2.put(0); + final computed = multiply.product.floatingPointValue; + + expect(computed, equals(expected), reason: ''' + $fv1 (${fv1.toDouble()})\t* + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expected (${expected.toDouble()})\texpected +'''); + await Simulator.endSimulation(); + }); + test('FP: simple multiplier fp32: random', () { + final fp1 = FloatingPoint32(); + final fp2 = FloatingPoint32(); + fp1.put(0); + fp2.put(0); + final dut = FloatingPointMultiplierSimple(fp1, fp2); + final value = Random(513); + for (var i = 0; i < 50; i++) { + final fv1 = FloatingPoint32Value.random(value); + final fv2 = FloatingPoint32Value.random(value); + fp1.put(fv1); + fp2.put(fv2); + final computed = dut.product.floatingPointValue; + + final expectedDouble = fp1.floatingPointValue.toDouble() * + fp2.floatingPointValue.toDouble(); + final expectedNoRound = + FloatingPoint32Value.ofDoubleUnrounded(expectedDouble); + + // If the error is due to a rounding error, then ignore + if (!computed.withinRounding(expectedNoRound)) { + expect(computed, equals(expectedNoRound), reason: ''' + $fv1 (${fv1.toDouble()})\t* + $fv2 (${fv2.toDouble()})\t= + $computed (${computed.toDouble()})\tcomputed + $expectedNoRound (${expectedNoRound.toDouble()})\texpected +'''); + } + } + }); }); - test('FP: simple multiplier singleton pipelined', () async { + test('FP: simple multiplier singleton pipelined compression-tree', () async { final clk = SimpleClockGenerator(10).clk; const exponentWidth = 4; @@ -248,7 +543,11 @@ void main() { fp1.put(fv1.value); fp2.put(fv2.value); - final multiply = FloatingPointMultiplierSimple(fp1, fp2, clk: clk); + final multiply = FloatingPointMultiplierSimple(fp1, fp2, + clk: clk, + multGen: (a, b, {clk, reset, enable, name = 'multiplier'}) => + CompressionTreeMultiplier(a, b, 4, + clk: clk, reset: reset, enable: enable, name: name)); unawaited(Simulator.run()); await clk.nextNegedge; diff --git a/test/arithmetic/floating_point/floating_point_value_test.dart b/test/arithmetic/floating_point/floating_point_value_test.dart index 9dd624aad..b6670c0eb 100644 --- a/test/arithmetic/floating_point/floating_point_value_test.dart +++ b/test/arithmetic/floating_point/floating_point_value_test.dart @@ -409,4 +409,17 @@ void main() { } } }); + test('FPV: rounding check', () async { + final fpv1 = FloatingPoint32Value.ofDouble(1); + final fpv2 = FloatingPoint32Value.ofDouble(0.5); + final fpv3 = FloatingPoint32Value.ofDoubleUnrounded( + FloatingPoint32Value.getFloatingPointConstant( + FloatingPointConstants.smallestPositiveSubnormal) + .toDouble() + + fpv1.toDouble()); + + expect(fpv1.withinRounding(fpv2), false); + expect(fpv1.withinRounding(fpv1), true); + expect(fpv1.withinRounding(fpv3), true); + }); } diff --git a/test/arithmetic/multiplier_encoder_test.dart b/test/arithmetic/multiplier_encoder_test.dart index 6ed65be82..d7fde0f07 100644 --- a/test/arithmetic/multiplier_encoder_test.dart +++ b/test/arithmetic/multiplier_encoder_test.dart @@ -168,10 +168,8 @@ void main() { final width = log2Ceil(radix) + (signedMultiplier ? 1 : 0); for (final signExtension in SignExtension.values.where((e) => e != SignExtension.none)) { - final pp = PartialProductGeneratorBasic( - Logic(name: 'X', width: width), - Logic(name: 'Y', width: width), - RadixEncoder(radix), + final pp = PartialProductGenerator(Logic(name: 'X', width: width), + Logic(name: 'Y', width: width), RadixEncoder(radix), signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier); currySignExtensionFunction(signExtension)(pp).signExtend(); @@ -200,7 +198,7 @@ void main() { a.put(X); b.put(Y); final PartialProductGeneratorBase pp; - pp = PartialProductGeneratorBasic(a, b, encoder, + pp = PartialProductGenerator(a, b, encoder, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier); StopBitsSignExtension(pp).signExtend(); @@ -226,10 +224,8 @@ void main() { .where((e) => e != SignExtension.none)) { final width = log2Ceil(radix) + (signMultiplier ? 1 : 0); final PartialProductGeneratorBase pp; - pp = PartialProductGeneratorBasic( - Logic(name: 'X', width: width), - Logic(name: 'Y', width: width), - encoder, + pp = PartialProductGenerator(Logic(name: 'X', width: width), + Logic(name: 'Y', width: width), encoder, signedMultiplicand: signMultiplicand, signedMultiplier: signMultiplier, selectSignedMultiplicand: @@ -260,10 +256,8 @@ void main() { selectSignMultiplicand.put(selectMultiplicand ? 1 : 0); selectSignMultiplier.put(selectMultiplier ? 1 : 0); final PartialProductGeneratorBase pp; - pp = PartialProductGeneratorStopBitsSignExtension( - Logic(name: 'X', width: width), - Logic(name: 'Y', width: width), - encoder, + pp = PartialProductGenerator(Logic(name: 'X', width: width), + Logic(name: 'Y', width: width), encoder, signedMultiplicand: !selectMultiplicand, signedMultiplier: !selectMultiplier, selectSignedMultiplicand: @@ -271,6 +265,8 @@ void main() { selectSignedMultiplier: selectMultiplier ? selectSignMultiplier : null); + CompactRectSignExtension(pp).signExtend(); + const i = 6; const j = -6; final X = SignedBigInt.fromSignedInt(i, width, @@ -291,10 +287,8 @@ void main() { for (final signExtension in SignExtension.values.where((e) => e != SignExtension.none)) { // final ppg = curryPartialProductGenerator(signExtension); - final pp = PartialProductGeneratorBasic( - Logic(name: 'X', width: width), - Logic(name: 'Y', width: width), - encoder); + final pp = PartialProductGenerator(Logic(name: 'X', width: width), + Logic(name: 'Y', width: width), encoder); currySignExtensionFunction(signExtension)(pp).signExtend(); testPartialProductExhaustive(pp); @@ -316,7 +310,7 @@ void main() { SignExtension.stopBits, SignExtension.compactRect ]) { - final pp = PartialProductGeneratorBasic( + final pp = PartialProductGenerator( Logic(name: 'X', width: widthX), Logic(name: 'Y', width: widthY), encoder, @@ -339,10 +333,8 @@ void main() { for (var width = shift; width < min(5, 2 * shift); width++) { for (final signExtension in SignExtension.values.where((e) => e != SignExtension.none)) { - final pp = PartialProductGeneratorBasic( - Logic(name: 'X', width: width), - Logic(name: 'Y', width: width), - encoder); + final pp = PartialProductGenerator(Logic(name: 'X', width: width), + Logic(name: 'Y', width: width), encoder); currySignExtensionFunction(signExtension)(pp).signExtend(); testPartialProductExhaustive(pp); @@ -366,7 +358,7 @@ void main() { SignExtension.stopBits, SignExtension.compactRect ]) { - final pp = PartialProductGeneratorBasic( + final pp = PartialProductGenerator( Logic(name: 'X', width: widthX), Logic(name: 'Y', width: widthY), encoder, @@ -402,7 +394,7 @@ void main() { final skew = align.$3; - final pp = PartialProductGeneratorBasic(Logic(name: 'X', width: width), + final pp = PartialProductGenerator(Logic(name: 'X', width: width), Logic(name: 'Y', width: width + skew), encoder); CompactRectSignExtension(pp).signExtend(); @@ -421,7 +413,7 @@ void main() { final multiplicand = Logic(width: widthX); final multiplier = Logic(width: widthY); for (final signed in [false, true]) { - final ppg = PartialProductGeneratorBasic( + final ppg = PartialProductGenerator( multiplicand, multiplier, radixEncoder, signedMultiplicand: signed, signedMultiplier: signed); CompactSignExtension(ppg).signExtend(); @@ -460,7 +452,7 @@ void main() { logicX.put(X); logicY.put(Y); logicZ.put(Z); - final pp = PartialProductGeneratorBasic(logicX, logicY, encoder, + final pp = PartialProductGenerator(logicX, logicY, encoder, signedMultiplicand: true, signedMultiplier: true); CompactRectSignExtension(pp).signExtend(); @@ -510,7 +502,7 @@ void main() { logicX.put(X); logicY.put(Y); logicZ.put(Z); - final pp = PartialProductGeneratorBasic(logicX, logicY, encoder, + final pp = PartialProductGenerator(logicX, logicY, encoder, signedMultiplicand: true, signedMultiplier: true); CompactRectSignExtension(pp).signExtend(); diff --git a/test/arithmetic/multiplier_test.dart b/test/arithmetic/multiplier_test.dart index ad7ad75e5..909e73c45 100644 --- a/test/arithmetic/multiplier_test.dart +++ b/test/arithmetic/multiplier_test.dart @@ -67,6 +67,7 @@ class SimpleMultiplier extends Multiplier { : super(a, b) { addOutput('product', width: a.width + b.width); final mult = CompressionTreeMultiplier(a, b, 4, + adderGen: ParallelPrefixAdder.new, selectSignedMultiplicand: selSignedMultiplicand, selectSignedMultiplier: selSignedMultiplier); product <= mult.product; @@ -187,14 +188,16 @@ void main() { }); MultiplierCallback curryCompressionTreeMultiplier(int radix, - ParallelPrefix Function(List, Logic Function(Logic, Logic)) ppTree, {SignExtensionFunction seGen = CompactRectSignExtension.new, + Adder Function(Logic a, Logic b, {Logic? carryIn, String name}) adderGen = + NativeAdder.new, bool signedMultiplicand = false, bool signedMultiplier = false, Logic? selectSignedMultiplicand, Logic? selectSignedMultiplier}) { + String adderName(Logic a, Logic b) => adderGen(a, b).name; String genName(Logic a, Logic b) => - seGen(PartialProductGeneratorBasic(a, b, RadixEncoder(radix))).name; + seGen(PartialProductGenerator(a, b, RadixEncoder(radix))).name; final signage = ' SD=${signedMultiplicand ? 1 : 0}' ' SM=${signedMultiplier ? 1 : 0}' ' SelD=${(selectSignedMultiplicand != null) ? 1 : 0}' @@ -205,15 +208,17 @@ void main() { signedMultiplier: signedMultiplier, selectSignedMultiplicand: selectSignedMultiplicand, selectSignedMultiplier: selectSignedMultiplier, + seGen: seGen, + adderGen: adderGen, name: 'Compression Tree Multiplier: ' - '${ppTree([Logic()], (a, b) => Logic()).name}' + '${adderName(a, b)}' '$signage R${radix}_E${genName(a, b)}'); } MultiplyAccumulateCallback curryMultiplierAsMultiplyAccumulate(int radix, - {ParallelPrefix Function(List, Logic Function(Logic, Logic)) - ppTree = KoggeStone.new, - SignExtensionFunction seGen = CompactRectSignExtension.new, + {SignExtensionFunction seGen = CompactRectSignExtension.new, + Adder Function(Logic a, Logic b, {Logic? carryIn, String name}) + adderGen = NativeAdder.new, bool signedMultiplicand = false, bool signedMultiplier = false, Logic? selectSignedMultiplicand, @@ -228,7 +233,7 @@ void main() { selectSignedMultiplier: selectSignedMultiplier, curryCompressionTreeMultiplier( radix, - ppTree, + adderGen: adderGen, seGen: seGen, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, @@ -238,8 +243,8 @@ void main() { MultiplyAccumulateCallback curryMultiplyAccumulate( int radix, { - ParallelPrefix Function(List, Logic Function(Logic, Logic)) ppTree = - KoggeStone.new, + Adder Function(Logic a, Logic b, {Logic? carryIn, String name}) adderGen = + NativeAdder.new, SignExtensionFunction seGen = CompactRectSignExtension.new, bool signedMultiplicand = false, bool signedMultiplier = false, @@ -249,31 +254,106 @@ void main() { Logic? selectSignedAddend, }) { String genName(Logic a, Logic b) => - seGen(PartialProductGeneratorBasic(a, b, RadixEncoder(radix))).name; + seGen(PartialProductGenerator(a, b, RadixEncoder(radix))).name; final signage = ' SD=${signedMultiplicand ? 1 : 0}' ' SM=${signedMultiplier ? 1 : 0}' ' SelD=${(selectSignedMultiplicand != null) ? 1 : 0}' ' SelM=${(selectSignedMultiplier != null) ? 1 : 0}'; return (a, b, c) => CompressionTreeMultiplyAccumulate(a, b, c, radix, + adderGen: adderGen, + seGen: seGen, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, signedAddend: signedAddend, selectSignedMultiplicand: selectSignedMultiplicand, selectSignedMultiplier: selectSignedMultiplier, selectSignedAddend: selectSignedAddend, - name: 'Compression Tree MAC: ${ppTree.call([ - Logic() - ], (a, b) => Logic()).name}' + name: 'Compression Tree MAC: ' ' $signage R$radix E${genName(a, b)}'); } + test('Native multiplier sweep with signage test', () async { + const width = 5; + final a = Logic(width: width); + final b = Logic(width: width); + + for (final selectSignedMultiplicand in [null, Const(0), Const(1)]) { + for (final signedMultiplicand + in (selectSignedMultiplicand == null) ? [false, true] : [false]) { + for (final selectSignedMultiplier in [null, Const(0), Const(1)]) { + for (final signedMultiplier + in (selectSignedMultiplier == null) ? [false, true] : [false]) { + final mod = NativeMultiplier(a, b, + signedMultiplicand: signedMultiplicand, + signedMultiplier: signedMultiplier); + for (var i = 0; i < pow(2, width); i++) { + for (var j = 0; j < pow(2, width); j++) { + final ai = signedMultiplicand + ? BigInt.from(i).toSigned(width) + : BigInt.from(i).toUnsigned(width); + final bi = signedMultiplier + ? BigInt.from(j).toSigned(width) + : BigInt.from(j).toUnsigned(width); + a.put(ai); + b.put(bi); + final expected = ai * bi; + final product = mod.isSignedResult() + ? mod.product.value.toBigInt().toSigned(width * 2) + : mod.product.value.toBigInt(); + expect(product, equals(expected)); + } + } + } + } + } + } + }); + +// TODO(desmonddak): must set variables in the enclosing +// module, so we can't really curry +// unless the enclosing module reads them off +// the passed in multiplier. + group('Native multiplier check', () { + for (final selectSignedMultiplicand in [null, Const(0), Const(1)]) { + // for (final selectSignedMultiplicand in [null]) { + for (final signedMultiplicand + in (selectSignedMultiplicand == null) ? [false, true] : [false]) { + for (final selectSignedMultiplier in [null, Const(0), Const(1)]) { + // for (final selectSignedMultiplier in [null]) { + for (final signedMultiplier + in (selectSignedMultiplier == null) ? [false, true] : [false]) { + testMultiplyAccumulateExhaustive( + 5, + (a, b, c) => MultiplyOnly( + a, + b, + c, + signedMultiplier: signedMultiplier, + signedMultiplicand: signedMultiplicand, + selectSignedMultiplicand: selectSignedMultiplicand, + selectSignedMultiplier: selectSignedMultiplier, + (a, b, + {selectSignedMultiplicand, + selectSignedMultiplier}) => + NativeMultiplier(a, b, + signedMultiplicand: signedMultiplicand, + signedMultiplier: signedMultiplier, + selectSignedMultiplicand: selectSignedMultiplicand, + selectSignedMultiplier: selectSignedMultiplier))); + } + } + } + } + }); group('Compression Tree Multiplier: curried random radix/ptree/width', () { for (final radix in [2, 4]) { for (final width in [3, 4]) { for (final ppTree in [KoggeStone.new, BrentKung.new, Sklansky.new]) { + Adder adderFn(Logic a, Logic b, {Logic? carryIn, String? name}) => + ParallelPrefixAdder(a, b, carryIn: carryIn, ppGen: ppTree); testMultiplyAccumulateRandom(width, 10, - curryMultiplierAsMultiplyAccumulate(radix, ppTree: ppTree)); + curryMultiplierAsMultiplyAccumulate(radix, adderGen: adderFn)); } } } @@ -285,8 +365,11 @@ void main() { for (final signExtension in SignExtension.values.where((e) => e != SignExtension.none)) { final seg = currySignExtensionFunction(signExtension); - testMultiplyAccumulateRandom(width, 10, - curryMultiplierAsMultiplyAccumulate(radix, seGen: seg)); + testMultiplyAccumulateRandom( + width, + 10, + curryMultiplierAsMultiplyAccumulate(radix, + adderGen: ParallelPrefixAdder.new, seGen: seg)); } } } @@ -305,6 +388,7 @@ void main() { width, 10, curryMultiplierAsMultiplyAccumulate(radix, + adderGen: ParallelPrefixAdder.new, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, selectSignedMultiplicand: selectSignedMultiplicand, @@ -333,6 +417,7 @@ void main() { width, 10, curryMultiplyAccumulate(radix, + adderGen: ParallelPrefixAdder.new, signedMultiplicand: signedMultiplicand, signedMultiplier: signedMultiplier, signedAddend: signedAddend, @@ -360,6 +445,7 @@ void main() { final bB = BigInt.from(-10).toSigned(width); final mod = CompressionTreeMultiplier(a, b, 4, clk: clk, + adderGen: ParallelPrefixAdder.new, selectSignedMultiplicand: signedSelect, selectSignedMultiplier: signedSelect); unawaited(Simulator.run()); @@ -392,6 +478,7 @@ void main() { final mod = CompressionTreeMultiplyAccumulate(a, b, c, 4, clk: clk, + adderGen: ParallelPrefixAdder.new, selectSignedMultiplicand: signedSelect, selectSignedMultiplier: signedSelect, selectSignedAddend: signedSelect); @@ -418,8 +505,8 @@ void main() { final b = Logic(name: 'b', width: width); const av = 12; const bv = 13; - for (final signed in [true, false]) { - for (final useSignedLogic in [true, false]) { + for (final signed in [true]) { + for (final useSignedLogic in [true]) { final bA = SignedBigInt.fromSignedInt(av, width, signed: signed); final bB = SignedBigInt.fromSignedInt(bv, width, signed: signed); @@ -436,11 +523,12 @@ void main() { b.put(bB); final mod = CompressionTreeMultiplier(a, b, 4, + adderGen: ParallelPrefixAdder.new, + seGen: StopBitsSignExtension.new, signedMultiplier: !useSignedLogic && signed, selectSignedMultiplicand: signedSelect, selectSignedMultiplier: signedSelect); await mod.build(); - mod.generateSynth(); final golden = bA * bB; final result = mod.isSignedResult() ? mod.product.value.toBigInt().toSigned(mod.product.width) @@ -594,9 +682,9 @@ void main() { a.put(6); b.put(3); - final ppG0 = PartialProductGeneratorCompactRectSignExtension( - a, b, RadixEncoder(4), + final ppG0 = PartialProductGenerator(a, b, RadixEncoder(4), signedMultiplicand: true, signedMultiplier: true); + CompactRectSignExtension(ppG0).signExtend(); final bit_0_5 = ppG0.getAbsolute(0, 5); expect(bit_0_5.value, equals(LogicValue.one));