Fixed lossy #172

Merged
merged 9 commits into from
Feb 11, 2025
10 changes: 6 additions & 4 deletions doc/components/fixed_point.md
@@ -1,6 +1,6 @@
# Fixed-Point Arithmetic

Fixed-point binary representation of numbers is useful in several applications, including digital signal processing and embedded systems. As a first step towards enabling fixed-point components, we created a new value system [FixedPointValue](https://intel.github.io/rohd-hcl/rohd_hcl/FixedPointValue-class.html) similar to [LogicValue](https://intel.github.io/rohd/rohd/LogicValue-class.html).

## FixedPointValue

@@ -12,12 +12,14 @@ The [FixedPoint](https://intel.github.io/rohd-hcl/rohd_hcl/FixedPoint-class.html

## FixedToFloat

This component converts a fixed-point signal to a floating point signal specified by exponent and mantissa width. The output is rounded to nearest even when applicable and set to infinity if the input exceed the representable range.
This component converts a fixed-point signal to a floating-point signal specified by exponent and mantissa width. The output is rounded to the nearest even (RNE) when applicable and set to infinity if the input exceeds the representable range.
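
As an illustration of round-to-nearest-even, the following is a minimal Dart sketch on plain non-negative integers. It is not the `FixedToFloat` hardware implementation; the function name `roundNearestEven` and its `dropBits` parameter are hypothetical and exist only for this example. It uses the usual guard and sticky bits: round up when the dropped part is more than half, or exactly half with an odd kept value.

```dart
/// Drops the low [dropBits] bits of a non-negative [value], rounding to the
/// nearest result with ties going to the even neighbor (RNE).
/// Hypothetical helper for illustration only.
int roundNearestEven(int value, int dropBits) {
  if (dropBits <= 0) return value; // nothing to drop
  final kept = value >> dropBits; // bits that survive
  final guard = (value >> (dropBits - 1)) & 1; // first dropped bit
  // Sticky is true when any bit below the guard bit is set.
  final sticky = (value & ((1 << (dropBits - 1)) - 1)) != 0;
  final roundUp = guard == 1 && (sticky || (kept & 1) == 1);
  return roundUp ? kept + 1 : kept;
}

void main() {
  print(roundNearestEven(0xB, 2)); // 11/4 = 2.75 -> 3
  print(roundNearestEven(0xA, 2)); // 10/4 = 2.5, tie -> 2 (even)
  print(roundNearestEven(0xE, 2)); // 14/4 = 3.5, tie -> 4 (even)
}
```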

## FloatToFixed

This component converts a floating-point signal to a signed fixed-point signal. Infinities and NaN's are not supported. The integer and fraction widths are auto-calculated to achieve lossles conversion.
This component converts a floating-point signal to a signed fixed-point signal. Infinities and NaNs are not supported. By default, the integer and fraction widths are calculated automatically to achieve lossless conversion.

If the integer and fraction widths `m` and `n` are supplied, the conversion fits the floating-point value into the requested fixed-point format and may be lossy. For testing, [FixedPointValue] has a `canStore` method to determine in advance whether a given double fits. For execution, [FloatToFixed] can perform overflow detection when the `checkOverflow` option is set.
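
To make "lossy" concrete, here is a small Dart sketch that quantizes a double onto a signed fixed-point grid with `m` integer bits and `n` fraction bits. It is a sketch under stated assumptions, not the library's code: `quantize` and `fitsSigned` are hypothetical names, truncation toward zero is assumed, and the range check uses the plain two's-complement bounds of a `1 + m + n` bit value, which may differ from the bit-length test that `canStore` performs.

```dart
import 'dart:math';

/// Scales [val] by 2^n and truncates toward zero (hypothetical helper).
int quantize(double val, int n) => (val * pow(2, n)).truncate();

/// Returns true when the scaled value lies in the two's-complement range
/// [-2^(m+n), 2^(m+n) - 1] of a signed value with 1 + m + n bits.
/// Assumption for illustration; not necessarily identical to canStore.
bool fitsSigned(double val, int m, int n) {
  if (!val.isFinite) return false; // infinities and NaN never fit
  final scaled = quantize(val, n);
  final limit = 1 << (m + n); // 2^(m+n)
  return scaled >= -limit && scaled < limit;
}

void main() {
  // 0.3 is not on the 1/16 grid, so precision is lost.
  print(quantize(0.3, 4) / 16); // 0.25
  // 0.75 is on the grid, so the conversion is exact.
  print(quantize(0.75, 4) / 16); // 0.75
  print(fitsSigned(7.5, 3, 4)); // true
  print(fitsSigned(8.0, 3, 4)); // false: needs a fourth integer bit
}
```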

## Float8ToFixed

This component converts an 8-bit floating-point (FP8) representation ([FloatingPoint8E4M3Value](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint8E4M3Value-class.html) or [FloatingPoint8E5M2Value](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint8E5M2Value-class.html)) to a signed fixed-point representation. This component offers using the same hardware for both FP8 formats. Therefore, both input and output are of type [Logic](https://intel.github.io/rohd/rohd/Logic-class.html) and can be cast from/to floating point/fixed point by the producer/consumer based on the selected `mode`. Infinities and NaN's are not supported. The output width is 33bits to accomodate [FloatingPoint8E5M2Value](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint8E5M2Value-class.html) without loss.
This component converts an 8-bit floating-point (FP8) representation ([FloatingPoint8E4M3Value](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint8E4M3Value-class.html) or [FloatingPoint8E5M2Value](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint8E5M2Value-class.html)) to a signed fixed-point representation. This component allows the same hardware to be used for both FP8 formats. Therefore, both input and output are of type [Logic](https://intel.github.io/rohd/rohd/Logic-class.html) and can be cast from/to floating point/fixed point by the producer/consumer based on the selected `mode`. Infinities and NaNs are not supported. The output width is 33 bits to accommodate [FloatingPoint8E5M2Value](https://intel.github.io/rohd-hcl/rohd_hcl/FloatingPoint8E5M2Value-class.html) without loss.
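
The 33-bit figure can be reproduced with a little arithmetic. The Dart sketch below mirrors the lossless-width formulas that appear in the `FloatToFixed` change in this pull request (`bias + 1` integer bits, or `bias + 2` for E4M3, and `bias + mantissaWidth - 1` fraction bits). The helper names are hypothetical, and the bias is assumed to be the usual `2^(exponentWidth - 1) - 1`.

```dart
/// Assumed exponent bias: 2^(exponentWidth - 1) - 1.
int bias(int exponentWidth) => (1 << (exponentWidth - 1)) - 1;

/// Integer bits needed for lossless conversion (hypothetical helper).
/// E4M3 extends the maximum exponent by one, so it needs one more bit.
int losslessM(int exponentWidth, int mantissaWidth) {
  final isE4M3 = exponentWidth == 4 && mantissaWidth == 3;
  return isE4M3 ? bias(exponentWidth) + 2 : bias(exponentWidth) + 1;
}

/// Fraction bits needed to hold the smallest subnormal value.
int losslessN(int exponentWidth, int mantissaWidth) =>
    bias(exponentWidth) + mantissaWidth - 1;

/// Total signed width: one sign bit plus integer and fraction bits.
int losslessWidth(int exponentWidth, int mantissaWidth) =>
    1 +
    losslessM(exponentWidth, mantissaWidth) +
    losslessN(exponentWidth, mantissaWidth);

void main() {
  print(losslessWidth(5, 2)); // E5M2: 1 + 16 + 16 = 33
  print(losslessWidth(4, 3)); // E4M3: 1 + 9 + 9 = 19
}
```

For E5M2 the largest finite value is 1.75 × 2^15 = 57344, which needs 16 integer bits, and the smallest subnormal is 2^-16, which needs 16 fraction bits; adding the sign bit gives 33.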
2 changes: 2 additions & 0 deletions lib/src/arithmetic/fixed_to_float.dart
@@ -65,6 +65,8 @@ class FixedToFloat extends Module {
.zeroExtend(iWidth)
.named('jBit');

// TODO(desmonddak): refactor to use the roundRNE component

// Extract mantissa
final mantissa = Logic(name: 'mantissa', width: mantissaWidth);
final guard = Logic(name: 'guardBit');
101 changes: 81 additions & 20 deletions lib/src/arithmetic/float_to_fixed.dart
@@ -7,6 +7,8 @@
// 2024 November 1
// Author: Soner Yaldiz <[email protected]>

import 'dart:math';

import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

@@ -24,41 +26,100 @@ class FloatToFixed extends Module {
/// Width of output fractional part.
late final int n;

/// Return true if the conversion overflowed.
Logic? get overflow => tryOutput('overflow');

/// Internal representation of the output port
late final FixedPoint _fixed = FixedPoint(signed: true, m: m, n: n);

/// Output fixed point port
late final FixedPoint fixed = _fixed.clone()..gets(output('fixed'));

/// Constructor
FloatToFixed(FloatingPoint float, {super.name = 'FloatToFixed'})
/// Build a [FloatingPoint] to [FixedPoint] converter.
/// - if [m] and [n] are supplied, an m.n fixed-point output will be produced.
/// Otherwise, the converter will compute a lossless size for [m] and [n] for
/// outputting the floating-point value into a fixed-point value.
/// - [checkOverflow] set to true will cause overflow detection to happen in
/// case that loss can occur and an optional output [overflow] will be
/// produced that returns true when overflow occurs.
FloatToFixed(FloatingPoint float,
{super.name = 'FloatToFixed', int? m, int? n, bool checkOverflow = false})
: super(
definitionName: 'FloatE${float.exponent.width}'
'M${float.mantissa.width}ToFixed') {
float = float.clone()..gets(addInput('float', float, width: float.width));

final bias = FloatingPointValue.computeBias(float.exponent.width);
// E4M3 expands the max exponent by 1.
m = ((float.exponent.width == 4) & (float.mantissa.width == 3))
? bias + 1
: bias;
n = bias + float.mantissa.width - 1;
final outputWidth = m + n + 1;
final noLossM = ((float.exponent.width == 4) & (float.mantissa.width == 3))
? bias + 2
: bias + 1; // accommodate the jBit
final noLossN = bias + float.mantissa.width - 1;

this.m = m ?? noLossM;
this.n = n ?? noLossN;
final outputWidth = this.m + this.n + 1;

final jBit = Logic(name: 'jBit')..gets(float.isNormal);
final shift = Logic(name: 'shift', width: float.exponent.width)
..gets(
mux(jBit, float.exponent - 1, Const(0, width: float.exponent.width)));

final number = Logic(name: 'number', width: outputWidth)
..gets([
Const(0, width: outputWidth - float.mantissa.width - 1),
jBit,
float.mantissa
].swizzle() <<
shift);

_fixed <= mux(float.sign, ~number + 1, number);
final fullMantissa = [jBit, float.mantissa].swizzle().named('fullMantissa');

final eWidth = max(log2Ceil(this.n + this.m), float.exponent.width) + 2;
final shift = Logic(name: 'shift', width: eWidth);
final exp = (float.exponent - 1).zeroExtend(eWidth).named('expMinus1');

if (this.n > noLossN) {
shift <=
mux(jBit, exp, Const(0, width: eWidth)) +
Const(this.n - noLossN, width: eWidth).named('deltaN');
} else if (this.n == noLossN) {
shift <= mux(jBit, exp, Const(0, width: eWidth));
} else {
shift <=
mux(jBit, exp, Const(0, width: eWidth)) -
Const(noLossN - this.n, width: eWidth).named('deltaN');
}
// TODO(desmonddak): Could use signed shifter if we unified shift math
final shiftRight = ((fullMantissa.width > outputWidth)
? (~shift + 1) - (fullMantissa.width - outputWidth)
: (~shift + 1))
.named('shiftRight');

if (checkOverflow & ((this.m < noLossM) | (this.n < noLossN))) {
final overFlow = Logic(name: 'overflow');
final leadDetect = ParallelPrefixPriorityEncoder(fullMantissa.reversed,
name: 'leadone_detector');

final sWidth = max(eWidth, leadDetect.out.width);
final fShift = shift.zeroExtend(sWidth).named('wideShift');
final leadOne = leadDetect.out
.named('leadOneRaw')
.zeroExtend(sWidth)
.named('leadOne');

Combinational([
If(jBit, then: [
overFlow < shift.gte(outputWidth - float.mantissa.width - 1),
], orElse: [
If(fShift.gt(leadOne), then: [
overFlow <
(fShift - leadOne).gte(outputWidth - float.mantissa.width - 1),
], orElse: [
overFlow < Const(0),
]),
]),
]);
addOutput('overflow') <= overFlow;
}
final preNumber = ((outputWidth >= fullMantissa.width)
? fullMantissa.zeroExtend(outputWidth)
: fullMantissa.slice(-1, fullMantissa.width - outputWidth))
.named('newMantissaPreShift');
// TODO(desmonddak): Rounder is needed when shifting right

final number = mux(shift[-1], preNumber >>> shiftRight, preNumber << shift)
.named('number');

_fixed <= mux(float.sign, (~number + 1).named('negNumber'), number);
addOutput('fixed', width: outputWidth) <= _fixed;
}
}
4 changes: 4 additions & 0 deletions lib/src/arithmetic/signals/fixed_point_logic.dart
@@ -42,6 +42,10 @@ class FixedPoint extends Logic {
}
}

/// Retrieve the [FixedPointValue] of this [FixedPoint] logical signal.
FixedPointValue get fixedPointValue =>
FixedPointValue(value: value, signed: signed, m: m, n: n);

/// Clone for I/O ports.
@override
FixedPoint clone({String? name}) => FixedPoint(signed: signed, m: m, n: n);
42 changes: 41 additions & 1 deletion lib/src/arithmetic/values/fixed_point_value.dart
@@ -143,18 +143,58 @@ class FixedPointValue implements Comparable<FixedPointValue> {
return compareTo(other) == 0;
}

/// Constructs [FixedPointValue] of a Dart [double] rounding away from zero.
/// Return a string representation of [FixedPointValue].
/// Returns the sign, integer, and fraction parts as binary strings.
@override
String toString() =>
"(${signed ? '${value[-1].toString(includeWidth: false)} ' : ''}"
"${(m > 0) ? '${value.slice(m + n - 1, n).bitString} ' : ''}"
'${value.slice(n - 1, 0).toString(includeWidth: false)})';

/// Return true if double [val] can be stored in FixedPointValue with [m] and [n]
/// lengths.
static bool canStore(double val,
{required bool signed, required int m, required int n}) {
final w = signed ? 1 + m + n : m + n;
if (val.isFinite) {
final bigIntegerValue = BigInt.from(val * pow(2, n));
final negBigIntegerValue = BigInt.from(-val * pow(2, n));
final l = (val < 0.0)
? max(bigIntegerValue.bitLength, negBigIntegerValue.bitLength)
: bigIntegerValue.bitLength;
return l <= w;
}
return false;
}

/// Constructs [FixedPointValue] from a Dart [double] rounding away from zero.
factory FixedPointValue.ofDouble(double val,
{required bool signed, required int m, required int n}) {
if (!signed & (val < 0)) {
throw RohdHclException('Negative input not allowed with unsigned');
}
if (!canStore(val, signed: signed, m: m, n: n)) {
throw RohdHclException('Double is too long to store in '
'FixedPointValue: $m, $n');
}
final integerValue = (val * pow(2, n)).toInt();
final w = signed ? 1 + m + n : m + n;
final v = LogicValue.ofInt(integerValue, w);
return FixedPointValue(value: v, signed: signed, m: m, n: n);
}

/// Constructs [FixedPointValue] from a Dart [double] without rounding.
factory FixedPointValue.ofDoubleUnrounded(double val,
{required bool signed, required int m, required int n}) {
if (!signed & (val < 0)) {
throw RohdHclException('Negative input not allowed with unsigned');
}
final integerValue = (val * pow(2, n + 1)).toInt();
final w = signed ? 1 + m + n : m + n;
final v = LogicValue.ofInt(integerValue >> 1, w);
return FixedPointValue(value: v, signed: signed, m: m, n: n);
}

/// Converts a fixed-point value to a Dart [double].
double toDouble() {
if (m + n > 52) {
30 changes: 28 additions & 2 deletions test/arithmetic/fixed_to_float_test.dart
@@ -29,6 +29,34 @@ void main() async {
reason: 'mantissa mismatch');
});

test('FixedToFloat: exhaustive', () async {
final fixed = FixedPoint(signed: true, m: 8, n: 8);
final dut = FixedToFloat(fixed, exponentWidth: 8, mantissaWidth: 16);
await dut.build();
for (var val = 0; val < pow(2, fixed.width); val++) {
final fixedValue = FixedPointValue(
value: LogicValue.ofInt(val, fixed.width),
signed: true,
m: fixed.m,
n: fixed.n);
fixed.put(fixedValue);
final fpv = dut.float.floatingPointValue;
final fpvExpected = FloatingPointValue.ofDouble(fixedValue.toDouble(),
exponentWidth: dut.exponentWidth, mantissaWidth: dut.mantissaWidth);
final newFixed = FixedPointValue.ofDouble(fpv.toDouble(),
signed: true, m: fixed.m, n: fixed.n);
expect(newFixed, equals(fixedValue), reason: '''
fpvdbl=${fpv.toDouble()} $fpv
${newFixed.toDouble()} $newFixed
${fixedValue.toDouble()} $fixedValue
${fixed.fixedPointValue.toDouble()} ${fixed.fixedPointValue}
''');
expect(fpv.sign, fpvExpected.sign);
expect(fpv.exponent, fpvExpected.exponent, reason: 'exponent');
expect(fpv.mantissa, fpvExpected.mantissa, reason: 'mantissa');
}
});

test('Q16.16 to E5M2 < pow(2,14)', () async {
final fixed = FixedPoint(signed: true, m: 16, n: 16);
final dut = FixedToFloat(fixed, exponentWidth: 5, mantissaWidth: 2);
@@ -139,8 +167,6 @@ void main() async {
}
});

// TODO(desmonddak): complete this test as now
// FloatingPointValue.ofDouble handles infinities.
test('Signed Q7.0 to E3M2', () async {
final fixed = FixedPoint(signed: true, m: 7, n: 0);
final dut = FixedToFloat(fixed, exponentWidth: 3, mantissaWidth: 2);