From 4131860d74b419fe498d08b4807b167bafeda173 Mon Sep 17 00:00:00 2001 From: Andrew Fitzgibbon Date: Fri, 10 Nov 2023 09:40:11 +0000 Subject: [PATCH 1/4] Introduce IEEE P3109 dtypes --- .pre-commit-config.yaml | 4 +- README.md | 15 +++ ml_dtypes/__init__.py | 33 +++--- ml_dtypes/_finfo.py | 148 +++++++++++++++++++++++++ ml_dtypes/_src/dtypes.cc | 152 ++++++++++++++++++-------- ml_dtypes/_src/ufuncs.h | 4 +- ml_dtypes/include/float8.h | 154 +++++++++++++++++++++++++-- ml_dtypes/tests/custom_float_test.py | 46 +++++--- ml_dtypes/tests/finfo_test.py | 3 + ml_dtypes/tests/float8_test.cc | 117 +++++++++++++++++++- 10 files changed, 595 insertions(+), 81 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 24ad948d..5f5a3066 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace - id: debug-statements - repo: https://github.com/google/pyink - rev: 23.3.1 + rev: 23.10.0 hooks: - id: pyink language_version: python3.9 diff --git a/README.md b/README.md index ba34d55e..9250905c 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ * `float8_e4m3fnuz` * `float8_e5m2` * `float8_e5m2fnuz` + * `float8_p3109_p

` - `int4` and `uint4`: low precision integer types. See below for specifications of these number formats. @@ -107,6 +108,20 @@ This type has the following characteristics: * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000` * denormals when exponent is 0 +### float8_p3109_p<p>
+ +These types represent the types under discussion in IEEE working group P3109, +"Arithmetic Formats for Machine Learning", parameterized by precision $p$. + +These types have the following characteristics: + * Precision $p$: $2 < p < 6$ + * Exponent bits, E: $8-p$ + * Exponent bias: $2^{E-1}$ + * Infinities: +Inf, -Inf + * No negative zero + * Single NaN in the -0 position: `0b10000000` == `0x80` + * Denormals when exponent is 0 + ## `int4` and `uint4` 4-bit integer types, where each element is represented unpacked (i.e., padded up diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index 7546ba96..16a6d8fb 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -12,19 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '0.3.1' # Keep in sync with pyproject.toml:version +__version__ = "0.3.1" # Keep in sync with pyproject.toml:version __all__ = [ - '__version__', - 'bfloat16', - 'finfo', - 'float8_e4m3b11fnuz', - 'float8_e4m3fn', - 'float8_e4m3fnuz', - 'float8_e5m2', - 'float8_e5m2fnuz', - 'iinfo', - 'int4', - 'uint4', + "__version__", + "bfloat16", + "finfo", + "float8_e4m3b11fnuz", + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fnuz", + "float8_p3109_p3", + "float8_p3109_p4", + "float8_p3109_p5", + "iinfo", + "int4", + "uint4", ] from typing import Type @@ -37,6 +40,9 @@ from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz from ml_dtypes._ml_dtypes_ext import float8_e5m2 from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz +from ml_dtypes._ml_dtypes_ext import float8_p3109_p3 +from ml_dtypes._ml_dtypes_ext import float8_p3109_p4 +from ml_dtypes._ml_dtypes_ext import float8_p3109_p5 from ml_dtypes._ml_dtypes_ext import int4 from ml_dtypes._ml_dtypes_ext import uint4 import numpy as np @@ -47,6 +53,9 @@ float8_e4m3fnuz: Type[np.generic] float8_e5m2: Type[np.generic] float8_e5m2fnuz: Type[np.generic] +float8_p3109_p3: Type[np.generic] 
+float8_p3109_p4: Type[np.generic] +float8_p3109_p5: Type[np.generic] int4: Type[np.generic] uint4: Type[np.generic] diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index 451f2766..09bf7e8c 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -22,6 +22,10 @@ from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz from ml_dtypes._ml_dtypes_ext import float8_e5m2 from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz +from ml_dtypes._ml_dtypes_ext import float8_p3109_p3 +from ml_dtypes._ml_dtypes_ext import float8_p3109_p4 +from ml_dtypes._ml_dtypes_ext import float8_p3109_p5 + import numpy as np _bfloat16_dtype = np.dtype(bfloat16) @@ -30,6 +34,9 @@ _float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz) _float8_e5m2_dtype = np.dtype(float8_e5m2) _float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz) +_float8_p3109_p3_dtype = np.dtype(float8_p3109_p3) +_float8_p3109_p4_dtype = np.dtype(float8_p3109_p4) +_float8_p3109_p5_dtype = np.dtype(float8_p3109_p5) class _Bfloat16MachArLike: @@ -86,6 +93,29 @@ def __init__(self): self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal) +class _Float8IEEEMachArLike: + + def __init__(self, p): + # These are hard-coded in order to independently test against the computed values in the C++ implementation + if p == 3: + smallest_normal = float.fromhex("0x1p-15") + self.smallest_normal = float8_p3109_p3(smallest_normal) + smallest_subnormal = float.fromhex("0x1p-17") + self.smallest_subnormal = float8_p3109_p3(smallest_subnormal) + + if p == 4: + smallest_normal = float.fromhex("0x1p-7") + self.smallest_normal = float8_p3109_p4(smallest_normal) + smallest_subnormal = float.fromhex("0x1p-10") + self.smallest_subnormal = float8_p3109_p4(smallest_subnormal) + + if p == 5: + smallest_normal = float.fromhex("0x1p-3") + self.smallest_normal = float8_p3109_p5(smallest_normal) + smallest_subnormal = float.fromhex("0x1p-7") + self.smallest_subnormal = float8_p3109_p5(smallest_subnormal) + + class finfo(np.finfo): # pylint: 
disable=invalid-name,missing-class-docstring __doc__ = np.finfo.__doc__ _finfo_cache: Dict[np.dtype, np.finfo] = {} @@ -360,6 +390,114 @@ def float_to_str(f): # pylint: enable=protected-access return obj + @staticmethod + def _float8_p3109_p_finfo(p): + def float_to_str(f): + return "%6.2e" % float(f) + + # pylint: disable=protected-access + obj = object.__new__(np.finfo) + + if p == 3: + dtype = float8_p3109_p3 + obj.dtype = _float8_p3109_p3_dtype + elif p == 4: + dtype = float8_p3109_p4 + obj.dtype = _float8_p3109_p4_dtype + elif p == 5: + dtype = float8_p3109_p5 + obj.dtype = _float8_p3109_p5_dtype + else: + raise NotImplementedError() + + obj._machar = _Float8IEEEMachArLike(p) + + bias = 2 ** (7 - p) + tiny = obj._machar.smallest_normal + machep = 1 - p + eps = 2.0**machep + negep = -p + epsneg = 2.0**negep + max_ = (1 - 2 ** (1 - p)) * 2**bias # 1'0000 - 0'0010 = 0'1110 + + if p == 3: + assert tiny == float.fromhex("0x1p-15") + assert eps == float.fromhex("0x1p-2") + assert epsneg == float.fromhex("0x1p-3") + assert max_ == float.fromhex("0x1.8p15") + elif p == 4: + assert tiny == float.fromhex("0x1p-7") + assert eps == float.fromhex("0x1p-3") + assert epsneg == float.fromhex("0x1p-4") + assert max_ == float.fromhex("0x1.Cp7") + elif p == 5: + assert tiny == float.fromhex("0x1p-3") + assert eps == float.fromhex("0x1p-4") + assert epsneg == float.fromhex("0x1p-5") + assert max_ == float.fromhex("0x1.Ep3") + else: + raise NotImplementedError() + + obj.bits = 8 + + # nextafter(1.0, Inf) - 1.0 + obj.eps = dtype(eps) + + # The exponent that yields eps. + obj.machep = machep + + # 1.0 = nextafter(1.0, -Inf) + obj.epsneg = dtype(epsneg) + + # The exponent that yields epsneg. + obj.negep = negep + + # The largest representable number. + obj.max = dtype(max_) + + # The smallest representable number, typically -max. 
+ obj.min = dtype(-max_) + + obj.nexp = 8 - p + obj.nmant = p - 1 + obj.iexp = obj.nexp + obj.maxexp = bias + obj.minexp = 1 - bias + + # The approximate number of decimal digits to which this kind of float is precise. + obj.precision = 1 if p < 4 else 2 + + # The approximate decimal resolution of this type, i.e., 10**-precision. + obj.resolution = dtype(10**-obj.precision) + + if not hasattr(obj, "tiny"): + obj.tiny = dtype(tiny) + if not hasattr(obj, "smallest_normal"): + obj.smallest_normal = obj._machar.smallest_normal + obj.smallest_subnormal = obj._machar.smallest_subnormal + + obj._str_tiny = float_to_str(tiny) + obj._str_smallest_normal = float_to_str(tiny) + obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) + obj._str_max = float_to_str(max_) + obj._str_epsneg = float_to_str(epsneg) + obj._str_eps = float_to_str(eps) + obj._str_resolution = float_to_str(obj.resolution) + # pylint: enable=protected-access + return obj + + @staticmethod + def _float8_p3109_p3_finfo(): + return finfo._float8_p3109_p_finfo(3) + + @staticmethod + def _float8_p3109_p4_finfo(): + return finfo._float8_p3109_p_finfo(4) + + @staticmethod + def _float8_p3109_p5_finfo(): + return finfo._float8_p3109_p_finfo(5) + def __new__(cls, dtype): if ( isinstance(dtype, str) @@ -411,4 +549,14 @@ def __new__(cls, dtype): if _float8_e5m2fnuz_dtype not in cls._finfo_cache: cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo() return cls._finfo_cache[_float8_e5m2fnuz_dtype] + for type_str, test_dtype, finfo in ( + ("float8_p3109_p3", _float8_p3109_p3_dtype, cls._float8_p3109_p3_finfo), + ("float8_p3109_p4", _float8_p3109_p4_dtype, cls._float8_p3109_p4_finfo), + ("float8_p3109_p5", _float8_p3109_p5_dtype, cls._float8_p3109_p5_finfo), + ): + if isinstance(dtype, str) and dtype == type_str or dtype == test_dtype: + if test_dtype not in cls._finfo_cache: + cls._finfo_cache[test_dtype] = finfo() + return cls._finfo_cache[test_dtype] + return super().__new__(cls, dtype) 
diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index 31ac72da..859cd84f 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -147,6 +147,51 @@ struct TypeDescriptor : CustomFloatType { static constexpr char kNpyDescrByteorder = '='; }; +template <> +struct TypeDescriptor<float8_p3109_p<3>> : CustomFloatType<float8_p3109_p<3>> { + typedef float8_p3109_p<3> T; + static constexpr bool is_floating = true; + static constexpr bool is_integral = false; + static constexpr const char* kTypeName = "float8_p3109_p3"; + static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_p3109_p3"; + static constexpr const char* kTpDoc = "float8_p3109_p3 floating-point values"; + static constexpr char kNpyDescrKind = 'V'; + // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type + // character is unique. + static constexpr char kNpyDescrType = 'P'; + static constexpr char kNpyDescrByteorder = '='; +}; + +template <> +struct TypeDescriptor<float8_p3109_p<4>> : CustomFloatType<float8_p3109_p<4>> { + typedef float8_p3109_p<4> T; + static constexpr bool is_floating = true; + static constexpr bool is_integral = false; + static constexpr const char* kTypeName = "float8_p3109_p4"; + static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_p3109_p4"; + static constexpr const char* kTpDoc = "float8_p3109_p4 floating-point values"; + static constexpr char kNpyDescrKind = 'V'; + // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type + // character is unique. 
+ static constexpr char kNpyDescrType = 'Q'; + static constexpr char kNpyDescrByteorder = '='; +}; + +template <> +struct TypeDescriptor> : CustomFloatType> { + typedef float8_p3109_p<5> T; + static constexpr bool is_floating = true; + static constexpr bool is_integral = false; + static constexpr const char* kTypeName = "float8_p3109_p5"; + static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_p3109_p5"; + static constexpr const char* kTpDoc = "float8_p3109_p5 floating-point values"; + static constexpr char kNpyDescrKind = 'V'; + // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type + // character is unique. + static constexpr char kNpyDescrType = 'R'; + static constexpr char kNpyDescrByteorder = '='; +}; + template <> struct TypeDescriptor : Int4TypeDescriptor { typedef int4 T; @@ -251,6 +296,21 @@ bool Initialize() { numpy.get(), &float8_e5m2fnuz_already_registered)) { return false; } + bool float8_p3109_p3_already_registered; + if (!ml_dtypes::RegisterFloatDtype>( + numpy.get(), &float8_p3109_p3_already_registered)) { + return false; + } + bool float8_p3109_p4_already_registered; + if (!ml_dtypes::RegisterFloatDtype>( + numpy.get(), &float8_p3109_p4_already_registered)) { + return false; + } + bool float8_p3109_p5_already_registered; + if (!ml_dtypes::RegisterFloatDtype>( + numpy.get(), &float8_p3109_p5_already_registered)) { + return false; + } if (!ml_dtypes::RegisterInt4Dtype(numpy.get())) { return false; @@ -285,6 +345,7 @@ bool Initialize() { success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); @@ -293,6 +354,31 @@ bool Initialize() { success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); + + success &= RegisterTwoWayCustomCast, bfloat16>(); + success &= RegisterTwoWayCustomCast, 
float8_e4m3b11fnuz>(); + success &= RegisterTwoWayCustomCast, float8_e4m3fn>(); + success &= RegisterTwoWayCustomCast, float8_e4m3fnuz>(); + success &= RegisterTwoWayCustomCast, float8_e5m2>(); + success &= RegisterTwoWayCustomCast, float8_e5m2fnuz>(); + + success &= RegisterTwoWayCustomCast, bfloat16>(); + success &= RegisterTwoWayCustomCast, float8_e4m3b11fnuz>(); + success &= RegisterTwoWayCustomCast, float8_e4m3fn>(); + success &= RegisterTwoWayCustomCast, float8_e4m3fnuz>(); + success &= RegisterTwoWayCustomCast, float8_e5m2>(); + success &= RegisterTwoWayCustomCast, float8_e5m2fnuz>(); + success &= RegisterTwoWayCustomCast, float8_p3109_p<3>>(); + + success &= RegisterTwoWayCustomCast, bfloat16>(); + success &= RegisterTwoWayCustomCast, float8_e4m3b11fnuz>(); + success &= RegisterTwoWayCustomCast, float8_e4m3fn>(); + success &= RegisterTwoWayCustomCast, float8_e4m3fnuz>(); + success &= RegisterTwoWayCustomCast, float8_e5m2>(); + success &= RegisterTwoWayCustomCast, float8_e5m2fnuz>(); + success &= RegisterTwoWayCustomCast, float8_p3109_p<3>>(); + success &= RegisterTwoWayCustomCast, float8_p3109_p<4>>(); + return success; } @@ -301,6 +387,10 @@ static PyModuleDef module_def = { "_ml_dtypes_ext", }; +typedef float8_p3109_p<3> float8_p3109_p3; +typedef float8_p3109_p<4> float8_p3109_p4; +typedef float8_p3109_p<5> float8_p3109_p5; + // TODO(phawkins): PyMODINIT_FUNC handles visibility correctly in Python 3.9+. // Just use PyMODINIT_FUNC after dropping Python 3.8 support. 
#if defined(WIN32) || defined(_WIN32) @@ -309,6 +399,12 @@ static PyModuleDef module_def = { #define EXPORT_SYMBOL __attribute__((visibility("default"))) #endif +template bool py_set(Safe_PyObjectPtr &m, char const *str) { + PyObject *type_ptr = + reinterpret_cast(TypeDescriptor::type_ptr); + return PyObject_SetAttrString(m.get(), str, type_ptr) >= 0; +} + extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def)); if (!m) { @@ -321,50 +417,20 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { return nullptr; } - if (PyObject_SetAttrString( - m.get(), "float8_e4m3b11fnuz", - reinterpret_cast( - TypeDescriptor::type_ptr)) < 0) { + bool ok = (py_set(m, "float8_e4m3b11fnuz") && + py_set(m, "float8_e4m3fn") && + py_set(m, "float8_e4m3fnuz") && + py_set(m, "float8_e5m2") && + py_set(m, "float8_e5m2fnuz") && + py_set(m, "float8_p3109_p3") && + py_set(m, "float8_p3109_p4") && + py_set(m, "float8_p3109_p5") && + py_set(m, "bfloat16") && + py_set(m, "int4") && + py_set(m, "uint4")); + if (!ok) return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e4m3fn", - reinterpret_cast( - TypeDescriptor::type_ptr)) < - 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e4m3fnuz", - reinterpret_cast( - TypeDescriptor::type_ptr)) < - 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e5m2", - reinterpret_cast( - TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e5m2fnuz", - reinterpret_cast( - TypeDescriptor::type_ptr)) < - 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "bfloat16", - reinterpret_cast( - TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString( - m.get(), "int4", - reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString( - m.get(), "uint4", - reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { - return 
nullptr; - } + return m.release(); } } // namespace ml_dtypes diff --git a/ml_dtypes/_src/ufuncs.h b/ml_dtypes/_src/ufuncs.h index e3262091..7261681d 100644 --- a/ml_dtypes/_src/ufuncs.h +++ b/ml_dtypes/_src/ufuncs.h @@ -316,7 +316,7 @@ using BitsType = typename GetUnsignedInteger::type; template std::pair, BitsType> SignAndMagnitude(T x) { - // For types that represent NaN by -0, (i.e. *fnuz), abs(x) remains -0 without + // For types that represent NaN by -0, (i.e. *fnuz, *p3109), abs(x) remains -0 without // flipping the sign. Therefore, we need to explicitly check the // most-significant bit. constexpr BitsType kSignMask = BitsType(1) @@ -682,7 +682,7 @@ struct NextAfter { : static_cast>(1); BitsType out_int = from_rep + magnitude_adjustment; T out = Eigen::numext::bit_cast(out_int); - // Some non-IEEE compatible formats may have a representation for NaN + // Some non-IEEE-754 compatible formats may have a representation for NaN // instead of -0, ensure we return a zero in such cases. 
if constexpr (!std::numeric_limits::is_iec559) { if (Eigen::numext::isnan(out)) { diff --git a/ml_dtypes/include/float8.h b/ml_dtypes/include/float8.h index 8d6d41df..9d95a493 100644 --- a/ml_dtypes/include/float8.h +++ b/ml_dtypes/include/float8.h @@ -48,6 +48,7 @@ class float8_e4m3fnuz; class float8_e4m3b11fnuz; class float8_e5m2; class float8_e5m2fnuz; +template class float8_p3109_p; template class float8_base { @@ -367,13 +368,8 @@ class float8_e5m2fnuz : public float8_base { using Base::Base; public: - explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e5m2& f8) - : float8_e5m2fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3b11fnuz& f8) - : float8_e5m2fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3fn& f8) - : float8_e5m2fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3fnuz& f8) + template = 0> + explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(T f8) : float8_e5m2fnuz(ConvertFrom(f8)) {} constexpr float8_e5m2fnuz operator-() const { @@ -390,6 +386,49 @@ class float8_e5m2fnuz : public float8_base { explicit EIGEN_DEVICE_FUNC operator bool() const { return rep() != 0; } }; +template +class float8_p3109_p : public float8_base> { + // IEEE P3109 WG 8-bit floating point with p bits of precision. + // + // An 8-bit floating point type with 1 sign bit, + // 8-p bits exponent, and p-1 bits mantissa. + // + // This type has the following characteristics: + // * bit encoding: S1E<8-p>M + // * exponent bias: 2^(7-p) + // * infinities: +Inf at 0x7f, -Inf at 0xff + // * NaNs: Single NaN at `0b10000000` + // * denormals when exponent is 0 + + private: + typedef float8_p3109_p

this_t; + using Base = float8_base; + friend class float8_base; + using Base::Base; + + public: + + template = 0> + explicit EIGEN_DEVICE_FUNC float8_p3109_p(T f8) + : float8_p3109_p(this->ConvertFrom(f8)) {} + + constexpr float8_p3109_p

operator-() const { + // TODO: use isnan() + if ((this->rep() & 0x7f) == 0x00) { + return *this; + } + return Base::operator-(); + } + + float8_p3109_p

operator-(const float8_p3109_p

& other) const { + return Base::operator-(other); + } + + explicit EIGEN_DEVICE_FUNC operator bool() const { return this->rep() != 0; } +}; + +// ----------------------------------------- + constexpr double ConstexprAbs(double x) { return x < 0.0 ? -x : x; } constexpr double ConstexprCeil(double x) { @@ -427,6 +466,7 @@ constexpr int MaxDigits10FromDigits(int digits) { // C17 5.2.4.2.2p11: // "minimum negative integer such that 10 raised to that power is in the range // of normalized floating-point numbers" +// TODO: https://en.cppreference.com/w/cpp/types/numeric_limits/max_exponent10 says "representable" // ceil(log10(2**(emin - 1))) == ceil((emin - 1) * log10(2)); constexpr int MinExponent10FromMinExponent(int min_exponent) { return static_cast(ConstexprCeil((min_exponent - 1) * kLog10Of2)); @@ -446,6 +486,8 @@ constexpr int MaxExponent10FromMaxExponentAndDigits(int max_exponent, -0.057991946977686754, // log10(1 - 2**-4) -0.028028723600243537, + // log10(1 - 2**-5) + -0.013788284485633295 }; return static_cast(ConstexprFloor(kLog10OfOnePredecessor[digits - 3] + max_exponent * kLog10Of2)); @@ -764,6 +806,64 @@ struct numeric_limits_float8_e5m2fnuz : public numeric_limits_float8_base { } }; +template +struct numeric_limits_float8_p3109_p : public numeric_limits_float8_base { + private: + static inline constexpr const int kExponentBias = 1 << (7-p); + static inline constexpr const int kMantissaBits = p - 1; + + public: + // NOLINTBEGIN: these names must match std::numeric_limits. 
+ static inline constexpr const int digits = p; + static inline constexpr const int digits10 = Digits10FromDigits(digits); + static inline constexpr const int max_digits10 = + MaxDigits10FromDigits(digits); + static inline constexpr const int min_exponent = (1 - kExponentBias) + 1; + static inline constexpr const int min_exponent10 = + MinExponent10FromMinExponent(min_exponent); + static inline constexpr const int max_exponent = kExponentBias - 1; + static inline constexpr const int max_exponent10 = + MaxExponent10FromMaxExponentAndDigits(max_exponent, digits); + static inline constexpr const bool is_iec559 = false; // TODO + static inline constexpr const bool has_infinity = true; + static inline constexpr const bool has_signaling_NaN = false; + // NOLINTEND + + static constexpr float8_p3109_p

min() { + return float8_p3109_p

::FromRep(1<<(p-1)); + } + static constexpr float8_p3109_p

lowest() { + return float8_p3109_p

::FromRep(0xfe); + } + static constexpr float8_p3109_p

max() { + return float8_p3109_p

::FromRep(0x7e); + } + static constexpr float8_p3109_p

epsilon() { + if constexpr (p < 5) { + constexpr int expeps = (-kMantissaBits + kExponentBias) << kMantissaBits; + return float8_p3109_p

::FromRep(expeps); + } + // p >= 5: eps is subnormal + return float8_p3109_p

::FromRep(uint8_t(1 << (kExponentBias - 1))); + } + static constexpr float8_p3109_p

round_error() { + // Return 0.5 + return float8_p3109_p

::FromRep((-1 + kExponentBias) << kMantissaBits); + } + static constexpr float8_p3109_p

infinity() { + return float8_p3109_p

::FromRep(0x7f); + } + static constexpr float8_p3109_p

quiet_NaN() { + return float8_p3109_p

::FromRep(0x80); + } + static constexpr float8_p3109_p

signaling_NaN() { + return float8_p3109_p

::FromRep(0x80); + } + static constexpr float8_p3109_p

denorm_min() { + return float8_p3109_p

::FromRep(0x01); + } +}; + } // namespace float8_internal } // namespace ml_dtypes @@ -788,6 +888,11 @@ struct numeric_limits template <> struct numeric_limits : public ml_dtypes::float8_internal::numeric_limits_float8_e5m2fnuz {}; + +template +struct numeric_limits> + : public ml_dtypes::float8_internal::numeric_limits_float8_p3109_p

{}; + } // namespace std namespace ml_dtypes { @@ -839,6 +944,16 @@ constexpr inline bool (isnan)(const float8_e5m2fnuz& a) { return a.rep() == 0x80; } +template +constexpr inline bool(isnan)(const float8_p3109_p

& a) { + return a.rep() == 0x80; +} + +template +constexpr inline float8_p3109_p

abs(const float8_p3109_p

& a) { + return isnan(a) ? a : float8_p3109_p

::FromRep(a.rep() & 0x7F); +} + template constexpr inline bool(isinf)(const float8_base& a) { if constexpr (std::numeric_limits::has_infinity) { @@ -919,6 +1034,12 @@ struct Traits : public TraitsBase { static constexpr int kExponentBias = Base::kExponentBias + 1; }; +template +struct Traits> : public TraitsBase> { + using Base = TraitsBase>; + static constexpr int kExponentBias = 1 << (Base::kExponentBits - 1); +}; + template constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) { // Round to nearest even by adding a bias term. @@ -1279,6 +1400,8 @@ using float8_e4m3fnuz = float8_internal::float8_e4m3fnuz; using float8_e4m3b11fnuz = float8_internal::float8_e4m3b11fnuz; using float8_e5m2 = float8_internal::float8_e5m2; using float8_e5m2fnuz = float8_internal::float8_e5m2fnuz; +template +using float8_p3109_p = float8_internal::float8_p3109_p

; } // namespace ml_dtypes @@ -1345,6 +1468,12 @@ EIGEN_DEVICE_FUNC inline bool isinf_impl( return ml_dtypes::float8_internal::isinf(x); } +template +EIGEN_DEVICE_FUNC inline bool isinf_impl(const ml_dtypes::float8_p3109_p

& x) { + return ml_dtypes::float8_internal::isinf(x); +} + + template <> EIGEN_DEVICE_FUNC inline bool isnan_impl( const ml_dtypes::float8_e4m3fn& x) { @@ -1375,6 +1504,12 @@ EIGEN_DEVICE_FUNC inline bool isnan_impl( return ml_dtypes::float8_internal::isnan(x); } +template +EIGEN_DEVICE_FUNC inline bool isnan_impl(const ml_dtypes::float8_p3109_p

& x) { + return ml_dtypes::float8_internal::isnan(x); +} + + template <> EIGEN_DEVICE_FUNC inline bool isfinite_impl( const ml_dtypes::float8_e4m3fn& x) { @@ -1405,6 +1540,11 @@ EIGEN_DEVICE_FUNC inline bool isfinite_impl( return ml_dtypes::float8_internal::isfinite(x); } +template +EIGEN_DEVICE_FUNC inline bool isfinite_impl(const ml_dtypes::float8_p3109_p

& x) { + return ml_dtypes::float8_internal::isfinite(x); +} + } // namespace internal } // namespace Eigen diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index d71ae8b1..176134a6 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -35,6 +35,9 @@ float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz float8_e5m2 = ml_dtypes.float8_e5m2 float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +float8_p3109_p3 = ml_dtypes.float8_p3109_p3 +float8_p3109_p4 = ml_dtypes.float8_p3109_p4 +float8_p3109_p5 = ml_dtypes.float8_p3109_p5 @contextlib.contextmanager @@ -105,6 +108,9 @@ def dtype_has_inf(dtype): float8_e4m3fnuz, float8_e5m2, float8_e5m2fnuz, + float8_p3109_p3, + float8_p3109_p4, + float8_p3109_p5, ] # Values that should round trip exactly to float and back. @@ -118,7 +124,7 @@ def dtype_has_inf(dtype): -0.5, float(ml_dtypes.finfo(dtype).eps), 1.0 + float(ml_dtypes.finfo(dtype).eps), - 1.0 - float(ml_dtypes.finfo(dtype).eps), + 1.0 - float(ml_dtypes.finfo(dtype).eps), # TODO: should be epsneg? 
-1.0 - float(ml_dtypes.finfo(dtype).eps), -1.0 + float(ml_dtypes.finfo(dtype).eps), 3.5, @@ -159,6 +165,21 @@ def dtype_has_inf(dtype): range(1 << n, 2 << n, 1 << max(0, n - 2)) for n in range(16) ) ), + float8_p3109_p3: list( + itertools.chain.from_iterable( + range(1 << n, 2 << n, 1 << max(0, n - 2)) for n in range(16) + ) + )[:-1], + float8_p3109_p4: list( + itertools.chain.from_iterable( + range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(8) + ) + )[:-1], + float8_p3109_p5: list( + itertools.chain.from_iterable( + range(1 << n, 2 << n, 1 << max(0, n - 4)) for n in range(4) + ) + )[:-1], } BITS_TYPE = { @@ -168,6 +189,9 @@ def dtype_has_inf(dtype): float8_e4m3fnuz: np.uint8, float8_e5m2: np.uint8, float8_e5m2fnuz: np.uint8, + float8_p3109_p3: np.uint8, + float8_p3109_p4: np.uint8, + float8_p3109_p5: np.uint8, } @@ -224,19 +248,15 @@ def testRoundTripToNumpy(self, float_type): np.longdouble, ]: with self.subTest(dtype.__name__): - for v in FLOAT_VALUES[float_type]: - np.testing.assert_equal(dtype(v), dtype(float_type(dtype(v)))) + vals = FLOAT_VALUES[float_type] + for v in vals: np.testing.assert_equal(dtype(v), dtype(float_type(dtype(v)))) np.testing.assert_equal( dtype(v), dtype(float_type(np.array(v, dtype))) ) if dtype != float_type: - np.testing.assert_equal( - np.array(FLOAT_VALUES[float_type], dtype), - float_type(np.array(FLOAT_VALUES[float_type], dtype)).astype( - dtype - ), - ) + npvals = np.array(vals, dtype) + np.testing.assert_equal(npvals, float_type(npvals).astype(dtype)) def testCastBetweenCustomTypes(self, float_type): for dtype in FLOAT_DTYPES: @@ -610,9 +630,9 @@ def testArray(self, float_type): self.assertTrue((x == x).all()) def testComparisons(self, float_type): - x = np.array([30, 7, -30], dtype=np.float32) + x = np.array([15, 7, -15], dtype=np.float32) bx = x.astype(float_type) - y = np.array([17, 7, 0], dtype=np.float32) + y = np.array([13, 7, 0], dtype=np.float32) by = y.astype(float_type) np.testing.assert_equal(x == y, bx == by) 
np.testing.assert_equal(x != y, bx != by) @@ -729,8 +749,8 @@ def testArange(self, float_type): np.arange(-0.0, -2.0, -0.25, dtype=float_type), ) np.testing.assert_equal( - np.arange(-16.0, 16.0, 2.0, dtype=np.float32).astype(float_type), - np.arange(-16.0, 16.0, 2.0, dtype=float_type), + np.arange(-14.0, 14.0, 2.0, dtype=np.float32).astype(float_type), + np.arange(-14.0, 14.0, 2.0, dtype=float_type), ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index 855c00ba..343b8bda 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -24,6 +24,9 @@ ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e5m2, ml_dtypes.float8_e5m2fnuz, + ml_dtypes.float8_p3109_p3, + ml_dtypes.float8_p3109_p4, + ml_dtypes.float8_p3109_p5, ] DTYPES_WITH_NO_INFINITY = [ diff --git a/ml_dtypes/tests/float8_test.cc b/ml_dtypes/tests/float8_test.cc index 960f89af..fb3f2cce 100644 --- a/ml_dtypes/tests/float8_test.cc +++ b/ml_dtypes/tests/float8_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include #include -#include "third_party/absl/strings/str_cat.h" +#include "absl/strings/str_cat.h" #include "unsupported/Eigen/CXX11/Tensor" namespace ml_dtypes { @@ -45,6 +45,12 @@ struct Float8TestParamNames { return "float8_e4m3fnuz"; } else if constexpr (std::is_same_v) { return "float8_e5m2fnuz"; + } else if constexpr (std::is_same_v>) { + return "float8_p3109_p<3>"; + } else if constexpr (std::is_same_v>) { + return "float8_p3109_p<4>"; + } else if constexpr (std::is_same_v>) { + return "float8_p3109_p<5>"; } return absl::StrCat(idx); } @@ -52,7 +58,8 @@ struct Float8TestParamNames { using Float8Types = ::testing::Types; + float8_e4m3fnuz, float8_e5m2fnuz, + float8_p3109_p<3>, float8_p3109_p<4>, float8_p3109_p<5>>; TYPED_TEST_SUITE(Float8Test, Float8Types, Float8TestParamNames); TEST(Float8E4m3Test, NumericLimits) { @@ -227,6 +234,106 @@ TEST(Float8E5m2fnuzTest, NumericLimits) { EXPECT_EQ(std::numeric_limits::has_signaling_NaN, false); } +// TODO: Float8E replacements +TEST(Float8IEEEP3Test, NumericLimits) { + EXPECT_TRUE( + Eigen::numext::isnan(std::numeric_limits>::quiet_NaN())); + EXPECT_TRUE(Eigen::numext::isnan( + std::numeric_limits>::signaling_NaN())); + EXPECT_EQ(static_cast(std::numeric_limits>::min()), + std::exp2(-15)); + EXPECT_EQ(static_cast(std::numeric_limits>::max()), + 49152); + EXPECT_EQ(static_cast(std::numeric_limits>::lowest()), + -49152); + EXPECT_EQ(static_cast(std::numeric_limits>::epsilon()), + 0.25); + EXPECT_EQ( + static_cast(std::numeric_limits>::round_error()), + 0.5); + EXPECT_TRUE( + Eigen::numext::isinf(std::numeric_limits>::infinity())); + EXPECT_EQ( + static_cast(std::numeric_limits>::denorm_min()), + std::exp2(-17)); + EXPECT_EQ(std::numeric_limits>::digits, 3); + EXPECT_EQ(std::numeric_limits>::digits10, 0); + EXPECT_EQ(std::numeric_limits>::max_digits10, 2); + EXPECT_EQ(std::numeric_limits>::min_exponent, -14); + EXPECT_EQ(std::numeric_limits>::min_exponent10, -4); + EXPECT_EQ(std::numeric_limits>::max_exponent, 15); + 
EXPECT_EQ(std::numeric_limits>::max_exponent10, 4); + EXPECT_EQ(std::numeric_limits>::is_iec559, false); + EXPECT_EQ(std::numeric_limits>::has_infinity, true); + EXPECT_EQ(std::numeric_limits>::has_signaling_NaN, false); +} + +TEST(Float8IEEEP4Test, NumericLimits) { + EXPECT_TRUE( + Eigen::numext::isnan(std::numeric_limits>::quiet_NaN())); + EXPECT_TRUE(Eigen::numext::isnan( + std::numeric_limits>::signaling_NaN())); + EXPECT_EQ(static_cast(std::numeric_limits>::min()), + std::exp2(-7)); + EXPECT_EQ(static_cast(std::numeric_limits>::max()), + 224); + EXPECT_EQ(static_cast(std::numeric_limits>::lowest()), + -224); + EXPECT_EQ(static_cast(std::numeric_limits>::epsilon()), + 0.125); + EXPECT_EQ( + static_cast(std::numeric_limits>::round_error()), + 0.5); + EXPECT_TRUE( + Eigen::numext::isinf(std::numeric_limits>::infinity())); + EXPECT_EQ( + static_cast(std::numeric_limits>::denorm_min()), + std::exp2(-10)); + EXPECT_EQ(std::numeric_limits>::digits, 4); + EXPECT_EQ(std::numeric_limits>::digits10, 0); + EXPECT_EQ(std::numeric_limits>::max_digits10, 3); + EXPECT_EQ(std::numeric_limits>::min_exponent, -6); + EXPECT_EQ(std::numeric_limits>::min_exponent10, -2); + EXPECT_EQ(std::numeric_limits>::max_exponent, 7); + EXPECT_EQ(std::numeric_limits>::max_exponent10, 2); + EXPECT_EQ(std::numeric_limits>::is_iec559, false); + EXPECT_EQ(std::numeric_limits>::has_infinity, true); + EXPECT_EQ(std::numeric_limits>::has_signaling_NaN, false); +} + +TEST(Float8IEEEP5Test, NumericLimits) { + EXPECT_TRUE( + Eigen::numext::isnan(std::numeric_limits>::quiet_NaN())); + EXPECT_TRUE(Eigen::numext::isnan( + std::numeric_limits>::signaling_NaN())); + EXPECT_EQ(static_cast(std::numeric_limits>::min()), + std::exp2(-3)); + EXPECT_EQ(static_cast(std::numeric_limits>::max()), + 15); + EXPECT_EQ(static_cast(std::numeric_limits>::lowest()), + -15); + EXPECT_EQ(static_cast(std::numeric_limits>::epsilon()), + 0.0625); + EXPECT_EQ( + static_cast(std::numeric_limits>::round_error()), + 0.5); + 
EXPECT_TRUE( + Eigen::numext::isinf(std::numeric_limits>::infinity())); + EXPECT_EQ( + static_cast(std::numeric_limits>::denorm_min()), + std::exp2(-7)); + EXPECT_EQ(std::numeric_limits>::digits, 5); + EXPECT_EQ(std::numeric_limits>::digits10, 1); + EXPECT_EQ(std::numeric_limits>::max_digits10, 3); + EXPECT_EQ(std::numeric_limits>::min_exponent, -2); + EXPECT_EQ(std::numeric_limits>::min_exponent10, 0); + EXPECT_EQ(std::numeric_limits>::max_exponent, 3); + EXPECT_EQ(std::numeric_limits>::max_exponent10, 0); + EXPECT_EQ(std::numeric_limits>::is_iec559, false); + EXPECT_EQ(std::numeric_limits>::has_infinity, true); + EXPECT_EQ(std::numeric_limits>::has_signaling_NaN, false); +} + TYPED_TEST(Float8Test, FromRep) { using Float8 = TypeParam; Float8 x = Float8::FromRep(0x4F); @@ -756,12 +863,18 @@ struct Float8CastTestParamNames { std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ + std::pair>, \ + std::pair>, \ + std::pair>, \ std::pair, std::pair, \ std::pair, std::pair #define GEN_TYPE_PAIRS() \ GEN_DEST_TYPES(float8_e4m3fn), GEN_DEST_TYPES(float8_e4m3b11fnuz), \ GEN_DEST_TYPES(float8_e5m2), GEN_DEST_TYPES(float8_e4m3fnuz), \ + GEN_DEST_TYPES(float8_p3109_p<3>), \ + GEN_DEST_TYPES(float8_p3109_p<4>), \ + GEN_DEST_TYPES(float8_p3109_p<5>), \ GEN_DEST_TYPES(float8_e5m2fnuz) using Float8CastTypePairs = ::testing::Types; From 8f177411e5933b743eb6e1993765fd54bfb5de2c Mon Sep 17 00:00:00 2001 From: Andrew Fitzgibbon Date: Sat, 11 Nov 2023 09:54:24 +0000 Subject: [PATCH 2/4] Address PR comments --- ml_dtypes/_finfo.py | 94 +++++++++++++-------------------------------- 1 file changed, 27 insertions(+), 67 deletions(-) diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index 09bf7e8c..c9907195 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -486,77 +486,37 @@ def float_to_str(f): # pylint: enable=protected-access return obj - @staticmethod - def _float8_p3109_p3_finfo(): - return finfo._float8_p3109_p_finfo(3) - - @staticmethod - def 
_float8_p3109_p4_finfo(): - return finfo._float8_p3109_p_finfo(4) - - @staticmethod - def _float8_p3109_p5_finfo(): - return finfo._float8_p3109_p_finfo(5) - def __new__(cls, dtype): - if ( - isinstance(dtype, str) - and dtype == "bfloat16" - or dtype == _bfloat16_dtype - ): - if _bfloat16_dtype not in cls._finfo_cache: - cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo() - return cls._finfo_cache[_bfloat16_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e4m3b11fnuz" - or dtype == _float8_e4m3b11fnuz_dtype - ): - if _float8_e4m3b11fnuz_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e4m3b11fnuz_dtype] = ( - cls._float8_e4m3b11fnuz_finfo() - ) - return cls._finfo_cache[_float8_e4m3b11fnuz_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e4m3fn" - or dtype == _float8_e4m3fn_dtype - ): - if _float8_e4m3fn_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo() - return cls._finfo_cache[_float8_e4m3fn_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e4m3fnuz" - or dtype == _float8_e4m3fnuz_dtype - ): - if _float8_e4m3fnuz_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e4m3fnuz_dtype] = cls._float8_e4m3fnuz_finfo() - return cls._finfo_cache[_float8_e4m3fnuz_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e5m2" - or dtype == _float8_e5m2_dtype - ): - if _float8_e5m2_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo() - return cls._finfo_cache[_float8_e5m2_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e5m2fnuz" - or dtype == _float8_e5m2fnuz_dtype - ): - if _float8_e5m2fnuz_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo() - return cls._finfo_cache[_float8_e5m2fnuz_dtype] - for type_str, test_dtype, finfo in ( - ("float8_p3109_p3", _float8_p3109_p3_dtype, cls._float8_p3109_p3_finfo), - ("float8_p3109_p4", _float8_p3109_p4_dtype, 
cls._float8_p3109_p4_finfo), - ("float8_p3109_p5", _float8_p3109_p5_dtype, cls._float8_p3109_p5_finfo), + for type_str, test_dtype, constructor in ( + ("bfloat16", _bfloat16_dtype, cls._bfloat16_finfo), + ( + "float8_e4m3b11fnuz", + _float8_e4m3b11fnuz_dtype, + cls._float8_e4m3b11fnuz_finfo, + ), + ("float8_e4m3fn", _float8_e4m3fn_dtype, cls._float8_e4m3fn_finfo), + ("float8_e4m3fnuz", _float8_e4m3fnuz_dtype, cls._float8_e4m3fnuz_finfo), + ("float8_e5m2", _float8_e5m2_dtype, cls._float8_e5m2_finfo), + ("float8_e5m2fnuz", _float8_e5m2fnuz_dtype, cls._float8_e5m2fnuz_finfo), + ( + "float8_p3109_p3", + _float8_p3109_p3_dtype, + lambda: cls._float8_p3109_p_finfo(3), + ), + ( + "float8_p3109_p4", + _float8_p3109_p4_dtype, + lambda: cls._float8_p3109_p_finfo(4), + ), + ( + "float8_p3109_p5", + _float8_p3109_p5_dtype, + lambda: cls._float8_p3109_p_finfo(5), + ), ): if isinstance(dtype, str) and dtype == type_str or dtype == test_dtype: if test_dtype not in cls._finfo_cache: - cls._finfo_cache[test_dtype] = finfo() + cls._finfo_cache[test_dtype] = constructor() return cls._finfo_cache[test_dtype] return super().__new__(cls, dtype) From f209c406acce72aabf7299d3c8de354bece5c568 Mon Sep 17 00:00:00 2001 From: Andrew Fitzgibbon Date: Sat, 11 Nov 2023 10:43:51 +0000 Subject: [PATCH 3/4] Tidy finfo.__new__ --- ml_dtypes/_finfo.py | 59 +++++++++++++++++----------------- ml_dtypes/tests/CMakeLists.txt | 42 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 30 deletions(-) create mode 100644 ml_dtypes/tests/CMakeLists.txt diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index c9907195..7eb4aed4 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -38,6 +38,21 @@ _float8_p3109_p4_dtype = np.dtype(float8_p3109_p4) _float8_p3109_p5_dtype = np.dtype(float8_p3109_p5) +_name_to_dtype = { + dtype.name: dtype + for dtype in ( + _bfloat16_dtype, + _float8_e4m3b11fnuz_dtype, + _float8_e4m3fn_dtype, + _float8_e4m3fnuz_dtype, + _float8_e5m2_dtype, + 
_float8_e5m2fnuz_dtype, + _float8_p3109_p3_dtype, + _float8_p3109_p4_dtype, + _float8_p3109_p5_dtype, + ) +} + class _Bfloat16MachArLike: @@ -487,36 +502,20 @@ def float_to_str(f): return obj def __new__(cls, dtype): - for type_str, test_dtype, constructor in ( - ("bfloat16", _bfloat16_dtype, cls._bfloat16_finfo), - ( - "float8_e4m3b11fnuz", - _float8_e4m3b11fnuz_dtype, - cls._float8_e4m3b11fnuz_finfo, - ), - ("float8_e4m3fn", _float8_e4m3fn_dtype, cls._float8_e4m3fn_finfo), - ("float8_e4m3fnuz", _float8_e4m3fnuz_dtype, cls._float8_e4m3fnuz_finfo), - ("float8_e5m2", _float8_e5m2_dtype, cls._float8_e5m2_finfo), - ("float8_e5m2fnuz", _float8_e5m2fnuz_dtype, cls._float8_e5m2fnuz_finfo), - ( - "float8_p3109_p3", - _float8_p3109_p3_dtype, - lambda: cls._float8_p3109_p_finfo(3), - ), - ( - "float8_p3109_p4", - _float8_p3109_p4_dtype, - lambda: cls._float8_p3109_p_finfo(4), - ), - ( - "float8_p3109_p5", - _float8_p3109_p5_dtype, - lambda: cls._float8_p3109_p_finfo(5), - ), + for ty, constructor in ( + (_bfloat16_dtype, cls._bfloat16_finfo), + (_float8_e4m3b11fnuz_dtype, cls._float8_e4m3b11fnuz_finfo), + (_float8_e4m3fn_dtype, cls._float8_e4m3fn_finfo), + (_float8_e4m3fnuz_dtype, cls._float8_e4m3fnuz_finfo), + (_float8_e5m2_dtype, cls._float8_e5m2_finfo), + (_float8_e5m2fnuz_dtype, cls._float8_e5m2fnuz_finfo), + (_float8_p3109_p3_dtype, lambda: cls._float8_p3109_p_finfo(3)), + (_float8_p3109_p4_dtype, lambda: cls._float8_p3109_p_finfo(4)), + (_float8_p3109_p5_dtype, lambda: cls._float8_p3109_p_finfo(5)), ): - if isinstance(dtype, str) and dtype == type_str or dtype == test_dtype: - if test_dtype not in cls._finfo_cache: - cls._finfo_cache[test_dtype] = constructor() - return cls._finfo_cache[test_dtype] + if isinstance(dtype, str) and dtype == ty.name or dtype == ty: + if ty not in cls._finfo_cache: + cls._finfo_cache[ty] = constructor() + return cls._finfo_cache[ty] return super().__new__(cls, dtype) diff --git a/ml_dtypes/tests/CMakeLists.txt 
b/ml_dtypes/tests/CMakeLists.txt new file mode 100644 index 00000000..3dd07e94 --- /dev/null +++ b/ml_dtypes/tests/CMakeLists.txt @@ -0,0 +1,42 @@ +cmake_minimum_required(VERSION 3.14) +project(my_project) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(GOOGLETEST_DOWNLOAD_URL https://github.com/google/googletest/archive/refs/tags/v1.12.0.zip) + +include(FetchContent) +FetchContent_Declare( + googletest + URL ${GOOGLETEST_DOWNLOAD_URL} +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +set(ABSL_PROPAGATE_CXX_STD ON) +set(ABSL_GOOGLETEST_DOWNLOAD_URL ${GOOGLETEST_DOWNLOAD_URL}) +add_subdirectory(abseil-cpp) + +enable_testing() + +add_executable( + float8_test + float8_test.cc +) +target_include_directories(float8_test PUBLIC + .. + ../.. + ../../third_party/eigen +) + +target_link_libraries( + float8_test + GTest::gtest_main + GTest::gmock_main + absl::strings +) + +include(GoogleTest) +gtest_discover_tests(float8_test) From c4ab9f3d6959ea4514cb7d545223b0c0dedddb62 Mon Sep 17 00:00:00 2001 From: Andrew Fitzgibbon Date: Sat, 11 Nov 2023 10:46:42 +0000 Subject: [PATCH 4/4] Remove extraneous dict --- ml_dtypes/_finfo.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index 7eb4aed4..e4c2fef9 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -38,21 +38,6 @@ _float8_p3109_p4_dtype = np.dtype(float8_p3109_p4) _float8_p3109_p5_dtype = np.dtype(float8_p3109_p5) -_name_to_dtype = { - dtype.name: dtype - for dtype in ( - _bfloat16_dtype, - _float8_e4m3b11fnuz_dtype, - _float8_e4m3fn_dtype, - _float8_e4m3fnuz_dtype, - _float8_e5m2_dtype, - _float8_e5m2fnuz_dtype, - _float8_p3109_p3_dtype, - _float8_p3109_p4_dtype, - _float8_p3109_p5_dtype, - ) -} - class _Bfloat16MachArLike: