Skip to content

Commit

Permalink
Improve int4 constexpr-ness, add more operators, numeric_limits.
Browse files Browse the repository at this point in the history
This is to allow better support for int4 in C++.

PiperOrigin-RevId: 561481591
  • Loading branch information
cantonios authored and The ml_dtypes Authors committed Aug 31, 2023
1 parent 54e375a commit 0804294
Show file tree
Hide file tree
Showing 8 changed files with 682 additions and 102 deletions.
17 changes: 8 additions & 9 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@

from typing import Type

from ml_dtypes._custom_floats import bfloat16
from ml_dtypes._custom_floats import float8_e4m3b11fnuz
from ml_dtypes._custom_floats import float8_e4m3fn
from ml_dtypes._custom_floats import float8_e4m3fnuz
from ml_dtypes._custom_floats import float8_e5m2
from ml_dtypes._custom_floats import float8_e5m2fnuz
from ml_dtypes._custom_floats import int4
from ml_dtypes._custom_floats import uint4
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo

from ml_dtypes._ml_dtypes_lib import bfloat16
from ml_dtypes._ml_dtypes_lib import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_lib import float8_e4m3fn
from ml_dtypes._ml_dtypes_lib import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_lib import float8_e5m2
from ml_dtypes._ml_dtypes_lib import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_lib import int4
from ml_dtypes._ml_dtypes_lib import uint4
import numpy as np

bfloat16: Type[np.generic]
Expand Down
13 changes: 6 additions & 7 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

from typing import Dict

from ml_dtypes._custom_floats import bfloat16
from ml_dtypes._custom_floats import float8_e4m3b11fnuz
from ml_dtypes._custom_floats import float8_e4m3fn
from ml_dtypes._custom_floats import float8_e4m3fnuz
from ml_dtypes._custom_floats import float8_e5m2
from ml_dtypes._custom_floats import float8_e5m2fnuz

from ml_dtypes._ml_dtypes_lib import bfloat16
from ml_dtypes._ml_dtypes_lib import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_lib import float8_e4m3fn
from ml_dtypes._ml_dtypes_lib import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_lib import float8_e5m2
from ml_dtypes._ml_dtypes_lib import float8_e5m2fnuz
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/_iinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

"""Overload of numpy.iinfo to handle dtypes defined in ml_dtypes."""

from ml_dtypes._custom_floats import int4
from ml_dtypes._custom_floats import uint4

from ml_dtypes._ml_dtypes_lib import int4
from ml_dtypes._ml_dtypes_lib import uint4
import numpy as np

_int4_dtype = np.dtype(int4)
Expand Down
9 changes: 5 additions & 4 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ limitations under the License.

#include "Eigen/Core"
#include "_src/custom_float.h"
#include "_src/int4.h"
#include "_src/int4_numpy.h"
#include "include/float8.h"
#include "include/int4.h"

namespace ml_dtypes {

Expand Down Expand Up @@ -297,7 +298,7 @@ bool Initialize() {

static PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_custom_floats",
"_ml_dtypes_lib",
};

// TODO(phawkins): PyMODINIT_FUNC handles visibility correctly in Python 3.9+.
Expand All @@ -308,14 +309,14 @@ static PyModuleDef module_def = {
#define EXPORT_SYMBOL __attribute__((visibility("default")))
#endif

extern "C" EXPORT_SYMBOL PyObject* PyInit__custom_floats() {
extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_lib() {
Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def));
if (!m) {
return nullptr;
}
if (!Initialize()) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_RuntimeError, "cannot load _custom_floats module.");
PyErr_SetString(PyExc_RuntimeError, "cannot load _ml_dtypes_lib module.");
}
return nullptr;
}
Expand Down
82 changes: 4 additions & 78 deletions ml_dtypes/_src/int4.h → ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,95 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ML_DTYPES_INT4_H_
#define ML_DTYPES_INT4_H_
#ifndef ML_DTYPES_INT4_NUMPY_H_
#define ML_DTYPES_INT4_NUMPY_H_

// Must be included first
// clang-format off
#include "_src/numpy.h"
// clang-format on

#include <cstdint> //NOLINT
#include <optional> //NOLINT
#include <ostream> //NOLINT
#include <sstream> //NOLINT

#include "Eigen/Core"
#include "_src/common.h" // NOLINT
#include "_src/ufuncs.h" // NOLINT
#include "include/int4.h"

namespace ml_dtypes {

template <typename UnderlyingTy>
struct i4 {
private:
UnderlyingTy v : 4;

public:
i4() : v(0) {}
explicit i4(UnderlyingTy val) : v(val & 0x0F) {}
template <typename T>
explicit i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}
i4(const i4& other) = default;

static constexpr i4 lowest() {
return std::is_signed<UnderlyingTy>::value ? i4(-8) : i4(0);
}
static constexpr i4 highest() {
return std::is_signed<UnderlyingTy>::value ? i4(7) : i4(15);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
explicit operator T() const {
return static_cast<T>(v);
}
// NOLINTNEXTLINE(google-explicit-constructor)
operator std::optional<int64_t>() const { return static_cast<int64_t>(v); }

i4 operator-() const { return i4(-v); }
i4 operator+(const i4& other) const { return i4((v + other.v)); }
i4 operator-(const i4& other) const { return i4((v - other.v)); }
i4 operator*(const i4& other) const { return i4((v * other.v)); }
i4 operator/(const i4& other) const { return i4((v / other.v)); }
i4 operator%(const i4& other) const { return i4((v % other.v)); }

i4 operator>>(const int amount) const { return i4((v >> amount)); }
i4 operator<<(const int amount) const { return i4((v << amount)); }

bool operator==(const i4& other) const { return v == other.v; }
bool operator!=(const i4& other) const { return v != other.v; }
bool operator<(const i4& other) const { return v < other.v; }
bool operator>(const i4& other) const { return v > other.v; }
bool operator<=(const i4& other) const { return v <= other.v; }
bool operator>=(const i4& other) const { return v >= other.v; }

bool operator==(const int64_t other) const { return v == other; }
bool operator!=(const int64_t other) const { return v != other; }
bool operator<(const int64_t other) const { return v < other; }
bool operator>(const int64_t other) const { return v > other; }
bool operator<=(const int64_t other) const { return v <= other; }
bool operator>=(const int64_t other) const { return v >= other; }

i4& operator++() {
v = (v + 1) & 0x0F;
return *this;
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
return os;
}

std::string ToString() const {
std::ostringstream os;
os << static_cast<int16_t>(v);
return os.str();
}
};

using int4 = i4<int8_t>;
using uint4 = i4<uint8_t>;

template <typename T>
struct Int4TypeDescriptor {
static int Dtype() { return npy_type; }
Expand Down Expand Up @@ -878,4 +804,4 @@ bool RegisterInt4Dtype(PyObject* numpy) {

} // namespace ml_dtypes

#endif // ML_DTYPES_INT4_H_
#endif // ML_DTYPES_INT4_NUMPY_H_
Loading

0 comments on commit 0804294

Please sign in to comment.