Skip to content

Commit

Permalink
Replace E4M3-to-float with a lookup table for a 20x speed boost. (#76)
Browse files Browse the repository at this point in the history
* Replace E4M3-to-float with a lookup table for speed.

* Cut down the lookup table size.

* Lay out lookup table more efficiently.

* Inline E4M3::operator float().
  • Loading branch information
psobot authored Jul 27, 2024
1 parent 177d12c commit 6580a4b
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 16 deletions.
274 changes: 267 additions & 7 deletions cpp/E4M3.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,275 @@

#include <cmath>

/**
 * Lookup table giving the float value of every possible E4M3 byte.
 *
 * The table is indexed by the raw byte of an E4M3 object, which is laid out
 * as 0bMMMEEEES: the sign is the lowest bit, followed by the four exponent
 * bits, with the three mantissa bits on top. As a result:
 *   - each line below is a (+value, -value) pair sharing exponent and mantissa,
 *   - each run of 16 lines shares the same mantissa bits and walks through
 *     exponent bits 0b0000..0b1111 in order.
 * The final pair (all exponent and mantissa bits set) is NaN.
 *
 * Negative zero is written as -0.0f on purpose: the integer expression `-0`
 * is simply 0 and would convert to +0.0f, silently dropping the sign bit.
 */
static constexpr float ALL_E4M3_VALUES[256] = {
    // Mantissa bits 0b000
    0.0f, -0.0f,
    0.015625, -0.015625,
    0.031250, -0.031250,
    0.0625, -0.0625,
    0.1250, -0.1250,
    0.2500, -0.2500,
    0.5000, -0.5000,
    1, -1,
    2, -2,
    4, -4,
    8, -8,
    16, -16,
    32, -32,
    64, -64,
    128, -128,
    256, -256,
    // Mantissa bits 0b001
    0.0009765625, -0.0009765625,
    0.0175781250, -0.0175781250,
    0.03515625, -0.03515625,
    0.07031250, -0.07031250,
    0.140625, -0.140625,
    0.281250, -0.281250,
    0.5625, -0.5625,
    1.1250, -1.1250,
    2.2500, -2.2500,
    4.5000, -4.5000,
    9, -9,
    18, -18,
    36, -36,
    72, -72,
    144, -144,
    288, -288,
    // Mantissa bits 0b010
    0.0019531250, -0.0019531250,
    0.01953125, -0.01953125,
    0.03906250, -0.03906250,
    0.078125, -0.078125,
    0.156250, -0.156250,
    0.3125, -0.3125,
    0.6250, -0.6250,
    1.2500, -1.2500,
    2.5000, -2.5000,
    5, -5,
    10, -10,
    20, -20,
    40, -40,
    80, -80,
    160, -160,
    320, -320,
    // Mantissa bits 0b011
    0.0029296875, -0.0029296875,
    0.0214843750, -0.0214843750,
    0.04296875, -0.04296875,
    0.08593750, -0.08593750,
    0.171875, -0.171875,
    0.343750, -0.343750,
    0.6875, -0.6875,
    1.3750, -1.3750,
    2.7500, -2.7500,
    5.5000, -5.5000,
    11, -11,
    22, -22,
    44, -44,
    88, -88,
    176, -176,
    352, -352,
    // Mantissa bits 0b100
    0.00390625, -0.00390625,
    0.02343750, -0.02343750,
    0.046875, -0.046875,
    0.093750, -0.093750,
    0.1875, -0.1875,
    0.3750, -0.3750,
    0.7500, -0.7500,
    1.5000, -1.5000,
    3, -3,
    6, -6,
    12, -12,
    24, -24,
    48, -48,
    96, -96,
    192, -192,
    384, -384,
    // Mantissa bits 0b101
    0.0048828125, -0.0048828125,
    0.0253906250, -0.0253906250,
    0.05078125, -0.05078125,
    0.10156250, -0.10156250,
    0.203125, -0.203125,
    0.406250, -0.406250,
    0.8125, -0.8125,
    1.6250, -1.6250,
    3.2500, -3.2500,
    6.5000, -6.5000,
    13, -13,
    26, -26,
    52, -52,
    104, -104,
    208, -208,
    416, -416,
    // Mantissa bits 0b110
    0.0058593750, -0.0058593750,
    0.02734375, -0.02734375,
    0.05468750, -0.05468750,
    0.109375, -0.109375,
    0.218750, -0.218750,
    0.4375, -0.4375,
    0.8750, -0.8750,
    1.7500, -1.7500,
    3.5000, -3.5000,
    7, -7,
    14, -14,
    28, -28,
    56, -56,
    112, -112,
    224, -224,
    448, -448,
    // Mantissa bits 0b111
    0.0068359375, -0.0068359375,
    0.0292968750, -0.0292968750,
    0.05859375, -0.05859375,
    0.11718750, -0.11718750,
    0.234375, -0.234375,
    0.468750, -0.468750,
    0.9375, -0.9375,
    1.8750, -1.8750,
    3.7500, -3.7500,
    7.5000, -7.5000,
    15, -15,
    30, -30,
    60, -60,
    120, -120,
    240, -240,
    // Exponent 0b1111 with mantissa 0b111 is reserved for NaN (either sign).
    NAN, NAN,
};

/**
* An 8-bit floating point format with a 4-bit exponent and 3-bit mantissa.
* Inspired by: https://arxiv.org/pdf/2209.05433.pdf
*/
class E4M3 {
public:
uint8_t sign : 1, exponent : 4, mantissa : 3;
// Note: This actually ends up laid out in a byte as: 0bMMMEEEES
uint8_t sign : 1;
uint8_t exponent : 4;
uint8_t mantissa : 3;

E4M3() : E4M3(0, 0, 0) {}

Expand Down Expand Up @@ -138,12 +400,10 @@ class E4M3 {
}
}

operator float() const {
if (exponent == 0b1111 && mantissa == 0b111) {
return NAN;
}

return (sign ? -1 : 1) * powf(2, effectiveExponent()) * effectiveMantissa();
inline operator float() const {
  // Convert to float by reading this object's first byte and using it as an
  // index into ALL_E4M3_VALUES, a 256-entry float lookup table (1 KiB,
  // assuming 4-byte floats), instead of computing the value arithmetically.
  // NOTE(review): this relies on the sign/exponent/mantissa bit-fields being
  // packed into a single byte in the order the table expects; bit-field
  // layout is implementation-defined, so confirm on each target compiler.
  // Note that the Python tests ensure that this matches the expected logic.
  return ALL_E4M3_VALUES[*(const uint8_t *)this];
}

int8_t effectiveExponent() const { return -7 + exponent; }
Expand Down
7 changes: 3 additions & 4 deletions python/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,13 @@ memory usage and index size by a factor of 4 compared to :py:class:`Float32`.
.def_property_readonly(
"raw_exponent", [](E4M3 &self) { return self.exponent; },
"The raw value of the exponent part of this E4M3 number, expressed "
"as "
"an integer.")
"as an integer.")
.def_property_readonly(
"mantissa", [](E4M3 &self) { return self.mantissa; },
"mantissa", [](E4M3 &self) { return self.effectiveMantissa(); },
"The effective mantissa (non-exponent part) of this E4M3 number, "
"expressed as a floating point value.")
.def_property_readonly(
"raw_mantissa", [](E4M3 &self) { return self.effectiveMantissa(); },
"raw_mantissa", [](E4M3 &self) { return self.mantissa; },
"The raw value of the mantissa (non-exponent part) of this E4M3 "
"number, expressed as an integer.")
.def_property_readonly(
Expand Down
20 changes: 15 additions & 5 deletions python/tests/test_e4m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import math

import numpy as np

import pytest
from voyager import E4M3T, Index, Space, StorageDataType


RANGES_AND_EXPECTED_ERRORS = [
# (Min value, max value, step), expected maximum error
((-448, 448, 1), 16),
Expand All @@ -47,6 +46,17 @@ def test_range():
assert roundtrip_value > 0


@pytest.mark.parametrize("_input", list(range(256)))
def test_to_float(_input: int):
    """Check float conversion for every one of the 256 possible byte values.

    The byte pattern whose raw exponent and raw mantissa bits are all set
    must convert to NaN. Every other byte must convert to exactly
    ``sign * 2 ** exponent * mantissa``, computed from the effective
    exponent and mantissa exposed by the wrapper.
    """
    wrapper = E4M3T.from_char(_input)

    # Guard clause: the all-ones exponent/mantissa encoding is NaN.
    if wrapper.raw_exponent == 0b1111 and wrapper.raw_mantissa == 0b111:
        assert np.isnan(float(wrapper))
        return

    sign_factor = -1 if wrapper.sign else 1
    scale = math.pow(2, wrapper.exponent)
    expected = sign_factor * scale * wrapper.mantissa
    assert float(wrapper) == expected, f"Expected {expected}, but got {wrapper}"


@pytest.mark.parametrize("_input", list(range(256)))
def test_rounding_exact(_input: int):
expected = E4M3T.from_char(_input)
Expand Down Expand Up @@ -85,8 +95,8 @@ def test_rounding():
expected = min([closest_above, closest_below], key=lambda v: abs(v - _input))
if closest_above != closest_below and abs(closest_above - _input) == abs(closest_below - _input):
# Round to nearest, ties to even:
above_is_even = E4M3T(closest_above).mantissa % 2 == 0
below_is_even = E4M3T(closest_below).mantissa % 2 == 0
above_is_even = E4M3T(closest_above).raw_mantissa % 2 == 0
below_is_even = E4M3T(closest_below).raw_mantissa % 2 == 0
if above_is_even and below_is_even:
raise NotImplementedError(
"Both numbers above and below the target are even!"
Expand Down

0 comments on commit 6580a4b

Please sign in to comment.