Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce IEEE P3109 dtypes #122

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: debug-statements
- repo: https://github.com/google/pyink
rev: 23.3.1
rev: 23.10.0
hooks:
- id: pyink
language_version: python3.9
Expand Down
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* `float8_e4m3fnuz`
* `float8_e5m2`
* `float8_e5m2fnuz`
* `float8_p3109_p<p>`
- `int4` and `uint4`: low precision integer types.

See below for specifications of these number formats.
Expand Down Expand Up @@ -107,6 +108,20 @@ This type has the following characteristics:
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000`
* denormals when exponent is 0

### float8_p3109_p<p>

These types represent the types under discussion in IEEE working group P3109,
"Arithmetic Formats for Machine Learning ", parameterized by precision $p$.

These type has the following characteristics:
* Precision $p$: $2 < p < 6$
* Exponent bits, E: $8-p$
* Exponent bias: 2 ^ (E-1)
* Infinities: +Inf, -Inf
* No negative zero
* Single NaN in the -0 position: `0b10000000` == `0x80`
* Denormals when exponent is 0

## `int4` and `uint4`

4-bit integer types, where each element is represented unpacked (i.e., padded up
Expand Down
33 changes: 21 additions & 12 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '0.3.1' # Keep in sync with pyproject.toml:version
__version__ = "0.3.1" # Keep in sync with pyproject.toml:version
__all__ = [
'__version__',
'bfloat16',
'finfo',
'float8_e4m3b11fnuz',
'float8_e4m3fn',
'float8_e4m3fnuz',
'float8_e5m2',
'float8_e5m2fnuz',
'iinfo',
'int4',
'uint4',
"__version__",
"bfloat16",
"finfo",
"float8_e4m3b11fnuz",
"float8_e4m3fn",
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_p3109_p3",
"float8_p3109_p4",
"float8_p3109_p5",
"iinfo",
"int4",
"uint4",
]

from typing import Type
Expand All @@ -37,6 +40,9 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_p3109_p3
from ml_dtypes._ml_dtypes_ext import float8_p3109_p4
from ml_dtypes._ml_dtypes_ext import float8_p3109_p5
from ml_dtypes._ml_dtypes_ext import int4
from ml_dtypes._ml_dtypes_ext import uint4
import numpy as np
Expand All @@ -47,6 +53,9 @@
float8_e4m3fnuz: Type[np.generic]
float8_e5m2: Type[np.generic]
float8_e5m2fnuz: Type[np.generic]
float8_p3109_p3: Type[np.generic]
float8_p3109_p4: Type[np.generic]
float8_p3109_p5: Type[np.generic]
int4: Type[np.generic]
uint4: Type[np.generic]

Expand Down
190 changes: 141 additions & 49 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_p3109_p3
from ml_dtypes._ml_dtypes_ext import float8_p3109_p4
from ml_dtypes._ml_dtypes_ext import float8_p3109_p5

import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
Expand All @@ -30,6 +34,9 @@
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz)
_float8_p3109_p3_dtype = np.dtype(float8_p3109_p3)
_float8_p3109_p4_dtype = np.dtype(float8_p3109_p4)
_float8_p3109_p5_dtype = np.dtype(float8_p3109_p5)


class _Bfloat16MachArLike:
Expand Down Expand Up @@ -86,6 +93,29 @@ def __init__(self):
self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal)


class _Float8IEEEMachArLike:

def __init__(self, p):
# These are hard-coded in order to independently test against the computed values in the C++ implementation
if p == 3:
smallest_normal = float.fromhex("0x1p-15")
self.smallest_normal = float8_p3109_p3(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-17")
self.smallest_subnormal = float8_p3109_p3(smallest_subnormal)

if p == 4:
smallest_normal = float.fromhex("0x1p-7")
self.smallest_normal = float8_p3109_p4(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-10")
self.smallest_subnormal = float8_p3109_p4(smallest_subnormal)

if p == 5:
smallest_normal = float.fromhex("0x1p-3")
self.smallest_normal = float8_p3109_p5(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-7")
self.smallest_subnormal = float8_p3109_p5(smallest_subnormal)


class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
Expand Down Expand Up @@ -360,55 +390,117 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_p3109_p_finfo(p):
def float_to_str(f):
return "%6.2e" % float(f)

# pylint: disable=protected-access
obj = object.__new__(np.finfo)

if p == 3:
dtype = float8_p3109_p3
obj.dtype = _float8_p3109_p3_dtype
elif p == 4:
dtype = float8_p3109_p4
obj.dtype = _float8_p3109_p4_dtype
elif p == 5:
dtype = float8_p3109_p5
obj.dtype = _float8_p3109_p5_dtype
else:
raise NotImplementedError()

obj._machar = _Float8IEEEMachArLike(p)

bias = 2 ** (7 - p)
tiny = obj._machar.smallest_normal
machep = 1 - p
eps = 2.0**machep
negep = -p
epsneg = 2.0**negep
max_ = (1 - 2 ** (1 - p)) * 2**bias # 1'0000 - 0'0010 = 0'1110

if p == 3:
assert tiny == float.fromhex("0x1p-15")
assert eps == float.fromhex("0x1p-2")
assert epsneg == float.fromhex("0x1p-3")
assert max_ == float.fromhex("0x1.8p15")
elif p == 4:
assert tiny == float.fromhex("0x1p-7")
assert eps == float.fromhex("0x1p-3")
assert epsneg == float.fromhex("0x1p-4")
assert max_ == float.fromhex("0x1.Cp7")
elif p == 5:
assert tiny == float.fromhex("0x1p-3")
assert eps == float.fromhex("0x1p-4")
assert epsneg == float.fromhex("0x1p-5")
assert max_ == float.fromhex("0x1.Ep3")
else:
raise NotImplementedError()

obj.bits = 8

# nextafter(1.0, Inf) - 1.0
obj.eps = dtype(eps)

# The exponent that yields eps.
obj.machep = machep

# 1.0 = nextafter(1.0, -Inf)
obj.epsneg = dtype(epsneg)

# The exponent that yields epsneg.
obj.negep = negep

# The largest representable number.
obj.max = dtype(max_)

# The smallest representable number, typically -max.
obj.min = dtype(-max_)

obj.nexp = 8 - p
obj.nmant = p - 1
obj.iexp = obj.nexp
obj.maxexp = bias
obj.minexp = 1 - bias

# The approximate number of decimal digits to which this kind of float is precise.
obj.precision = 1 if p < 4 else 2

# The approximate decimal resolution of this type, i.e., 10**-precision.
obj.resolution = dtype(10**-obj.precision)

if not hasattr(obj, "tiny"):
obj.tiny = dtype(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal

obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max_)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(obj.resolution)
# pylint: enable=protected-access
return obj

def __new__(cls, dtype):
if (
isinstance(dtype, str)
and dtype == "bfloat16"
or dtype == _bfloat16_dtype
):
if _bfloat16_dtype not in cls._finfo_cache:
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
return cls._finfo_cache[_bfloat16_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3b11fnuz"
or dtype == _float8_e4m3b11fnuz_dtype
for ty, constructor in (
(_bfloat16_dtype, cls._bfloat16_finfo),
(_float8_e4m3b11fnuz_dtype, cls._float8_e4m3b11fnuz_finfo),
(_float8_e4m3fn_dtype, cls._float8_e4m3fn_finfo),
(_float8_e4m3fnuz_dtype, cls._float8_e4m3fnuz_finfo),
(_float8_e5m2_dtype, cls._float8_e5m2_finfo),
(_float8_e5m2fnuz_dtype, cls._float8_e5m2fnuz_finfo),
(_float8_p3109_p3_dtype, lambda: cls._float8_p3109_p_finfo(3)),
(_float8_p3109_p4_dtype, lambda: cls._float8_p3109_p_finfo(4)),
(_float8_p3109_p5_dtype, lambda: cls._float8_p3109_p_finfo(5)),
):
if _float8_e4m3b11fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3b11fnuz_dtype] = (
cls._float8_e4m3b11fnuz_finfo()
)
return cls._finfo_cache[_float8_e4m3b11fnuz_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3fn"
or dtype == _float8_e4m3fn_dtype
):
if _float8_e4m3fn_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo()
return cls._finfo_cache[_float8_e4m3fn_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3fnuz"
or dtype == _float8_e4m3fnuz_dtype
):
if _float8_e4m3fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3fnuz_dtype] = cls._float8_e4m3fnuz_finfo()
return cls._finfo_cache[_float8_e4m3fnuz_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e5m2"
or dtype == _float8_e5m2_dtype
):
if _float8_e5m2_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo()
return cls._finfo_cache[_float8_e5m2_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e5m2fnuz"
or dtype == _float8_e5m2fnuz_dtype
):
if _float8_e5m2fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo()
return cls._finfo_cache[_float8_e5m2fnuz_dtype]
if isinstance(dtype, str) and dtype == ty.name or dtype == ty:
if ty not in cls._finfo_cache:
cls._finfo_cache[ty] = constructor()
return cls._finfo_cache[ty]

return super().__new__(cls, dtype)
Loading