Skip to content

Commit

Permalink
Refactor dtypes (#798)
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi authored Oct 23, 2024
1 parent 0a0802e commit cc3c8d9
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 91 deletions.
40 changes: 38 additions & 2 deletions sparse/mlir_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,43 @@

from . import levels
from ._conversions import asarray, from_constituent_arrays, to_numpy, to_scipy
from ._dtypes import asdtype
from ._dtypes import (
asdtype,
complex64,
complex128,
float16,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
)
from ._ops import add

__all__ = ["add", "asarray", "asdtype", "to_numpy", "to_scipy", "levels", "from_constituent_arrays"]
# Public names of this package: the `add` op, the conversion helpers, the
# `levels` submodule, and the dtype objects imported from `._dtypes`.
__all__ = [
"add",
"asarray",
"asdtype",
"to_numpy",
"to_scipy",
"levels",
"from_constituent_arrays",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float16",
"float32",
"float64",
"complex64",
"complex128",
]
2 changes: 1 addition & 1 deletion sparse/mlir_backend/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def ndim(self) -> int:
return len(self.shape)

@property
def dtype(self) -> type[DType]:
def dtype(self) -> DType:
return self._storage.get_storage_format().dtype

@property
Expand Down
4 changes: 2 additions & 2 deletions sparse/mlir_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def fn_cache(f, maxsize: int | None = None):
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))


def get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
def get_nd_memref_descr(rank: int, dtype: DType) -> ctypes.Structure:
return _get_nd_memref_descr(int(rank), asdtype(dtype))


@fn_cache
def _get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
def _get_nd_memref_descr(rank: int, dtype: DType) -> ctypes.Structure:
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())


Expand Down
11 changes: 4 additions & 7 deletions sparse/mlir_backend/_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:

level_props = LevelProperties(0)
if not arr.has_canonical_format:
level_props |= LevelProperties.NonOrdered | LevelProperties.NonUnique
level_props |= LevelProperties.NonOrdered

coo_format = get_storage_format(
levels=(
Expand All @@ -130,17 +130,14 @@ def to_scipy(arr: Array) -> ScipySparseArray:
case (Level(LevelFormat.Dense, _), Level(LevelFormat.Compressed, _)):
indptr, indices, data = arr.get_constituent_arrays()
if storage_format.order == (0, 1):
sps_arr = sps.csr_array((data, indices, indptr), shape=arr.shape)
else:
sps_arr = sps.csc_array((data, indices, indptr), shape=arr.shape)
return sps.csr_array((data, indices, indptr), shape=arr.shape)
return sps.csc_array((data, indices, indptr), shape=arr.shape)
case (Level(LevelFormat.Compressed, _), Level(LevelFormat.Singleton, _)):
_, coords, data = arr.get_constituent_arrays()
sps_arr = sps.coo_array((data, (coords[:, 0], coords[:, 1])), shape=arr.shape)
return sps.coo_array((data, (coords[:, 0], coords[:, 1])), shape=arr.shape)
case _:
raise RuntimeError(f"No conversion implemented for `{storage_format=}`.")

return sps_arr


def asarray(arr, copy: bool | None = None) -> Array:
if sps is not None and isinstance(arr, ScipySparseArray):
Expand Down
125 changes: 55 additions & 70 deletions sparse/mlir_backend/_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import abc
import inspect
import dataclasses
import math
import sys
import typing

import mlir.runtime as rt
from mlir import ir

import numpy as np


class MlirType(abc.ABC):
@classmethod
@abc.abstractmethod
def _get_mlir_type(cls) -> ir.Type: ...
def _get_mlir_type(self) -> ir.Type: ...


def _get_pointer_width() -> int:
Expand All @@ -22,106 +21,92 @@ def _get_pointer_width() -> int:
_PTR_WIDTH = _get_pointer_width()


def _make_int_classes(namespace: dict[str, object], bit_widths: typing.Iterable[int]) -> None:
for bw in bit_widths:

class SignedBW(SignedIntegerDType):
np_dtype = getattr(np, f"int{bw}")
bit_width = bw

@classmethod
def _get_mlir_type(cls):
return ir.IntegerType.get_signless(cls.bit_width)

SignedBW.__name__ = f"Int{bw}"
SignedBW.__module__ = __name__

class UnsignedBW(UnsignedIntegerDType):
np_dtype = getattr(np, f"uint{bw}")
bit_width = bw

@classmethod
def _get_mlir_type(cls):
return ir.IntegerType.get_signless(cls.bit_width)

UnsignedBW.__name__ = f"UInt{bw}"
UnsignedBW.__module__ = __name__

namespace[SignedBW.__name__] = SignedBW
namespace[UnsignedBW.__name__] = UnsignedBW


@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class DType(MlirType):
np_dtype: np.dtype
bit_width: int

@classmethod
def to_ctype(cls):
return np.ctypeslib.as_ctypes_type(cls.np_dtype)

@property
@abc.abstractmethod
def np_dtype(self) -> np.dtype:
raise NotImplementedError

class FloatingDType(DType): ...
def to_ctype(self):
return rt.as_ctype(self.np_dtype)


class Float64(FloatingDType):
np_dtype = np.float64
bit_width = 64
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class IeeeRealFloatingDType(DType):
@property
def np_dtype(self) -> np.dtype:
return np.dtype(getattr(np, f"float{self.bit_width}"))

@classmethod
def _get_mlir_type(cls):
return ir.F64Type.get()
def _get_mlir_type(self) -> ir.Type:
return getattr(ir, f"F{self.bit_width}Type").get()


class Float32(FloatingDType):
np_dtype = np.float32
bit_width = 32
float64 = IeeeRealFloatingDType(bit_width=64)
float32 = IeeeRealFloatingDType(bit_width=32)
float16 = IeeeRealFloatingDType(bit_width=16)

@classmethod
def _get_mlir_type(cls):
return ir.F32Type.get()

@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class IeeeComplexFloatingDType(DType):
@property
def np_dtype(self) -> np.dtype:
return np.dtype(getattr(np, f"complex{self.bit_width}"))

class Float16(FloatingDType):
np_dtype = np.float16
bit_width = 16
def _get_mlir_type(self) -> ir.Type:
return ir.ComplexType.get(getattr(ir, f"F{self.bit_width // 2}Type").get())

@classmethod
def _get_mlir_type(cls):
return ir.F16Type.get()

complex64 = IeeeComplexFloatingDType(bit_width=64)
complex128 = IeeeComplexFloatingDType(bit_width=128)

class IntegerDType(DType): ...

@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class IntegerDType(DType):
    """Common base for the integer dtypes.

    Subclasses supply `np_dtype`; the MLIR type is produced here from
    `bit_width` alone.
    """

    def _get_mlir_type(self) -> ir.Type:
        """Return the signless MLIR integer type of width `self.bit_width`."""
        width = self.bit_width
        return ir.IntegerType.get_signless(width)

class UnsignedIntegerDType(IntegerDType): ...

@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class UnsignedIntegerDType(IntegerDType):
    """Integer dtype whose NumPy counterpart is `np.uint<bit_width>`."""

    @property
    def np_dtype(self) -> np.dtype:
        """Return the NumPy dtype looked up as `uint<bit_width>` on `np`."""
        scalar_type = getattr(np, f"uint{self.bit_width}")
        return np.dtype(scalar_type)

class SignedIntegerDType(IntegerDType): ...

# NOTE(review): as bound on these lines, the signed names `int8`..`int64` are
# instances of UnsignedIntegerDType, whose `np_dtype` looks up `uint<bit_width>`
# on NumPy. This looks swapped with the `uint*` definitions — confirm.
int8 = UnsignedIntegerDType(bit_width=8)
int16 = UnsignedIntegerDType(bit_width=16)
int32 = UnsignedIntegerDType(bit_width=32)
int64 = UnsignedIntegerDType(bit_width=64)

_make_int_classes(locals(), [8, 16, 32, 64])

@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class SignedIntegerDType(IntegerDType):
    """Integer dtype whose NumPy counterpart is `np.int<bit_width>`."""

    @property
    def np_dtype(self) -> np.dtype:
        """Return the NumPy dtype looked up as `int<bit_width>` on `np`."""
        scalar_type = getattr(np, f"int{self.bit_width}")
        return np.dtype(scalar_type)

class Index(DType):
np_dtype = np.intp

@classmethod
def _get_mlir_type(cls):
return ir.IndexType.get()
# Fixed-width integer dtype instances.
#
# All eight names are bound here, where both SignedIntegerDType and
# UnsignedIntegerDType are already defined, so that every name carries the
# signedness it advertises: `intN.np_dtype` is `np.intN` and `uintN.np_dtype`
# is `np.uintN`. These assignments take precedence over any earlier
# module-level binding of the same names, and the pointer-width aliases and
# NUMPY_DTYPE_MAP computed further down pick up these objects.
int8 = SignedIntegerDType(bit_width=8)
int16 = SignedIntegerDType(bit_width=16)
int32 = SignedIntegerDType(bit_width=32)
int64 = SignedIntegerDType(bit_width=64)

uint8 = UnsignedIntegerDType(bit_width=8)
uint16 = UnsignedIntegerDType(bit_width=16)
uint32 = UnsignedIntegerDType(bit_width=32)
uint64 = UnsignedIntegerDType(bit_width=64)


IntP: type[SignedIntegerDType] = locals()[f"Int{_PTR_WIDTH}"]
UIntP: type[UnsignedIntegerDType] = locals()[f"UInt{_PTR_WIDTH}"]
intp: SignedIntegerDType = locals()[f"int{_PTR_WIDTH}"]
uintp: UnsignedIntegerDType = locals()[f"uint{_PTR_WIDTH}"]


def isdtype(dt, /) -> bool:
return isinstance(dt, type) and issubclass(dt, DType) and not inspect.isabstract(dt)
return isinstance(dt, DType)


NUMPY_DTYPE_MAP = {np.dtype(dt.np_dtype): dt for dt in locals().values() if isdtype(dt)}


def asdtype(dt, /) -> type[DType]:
def asdtype(dt, /) -> DType:
if isdtype(dt):
return dt

Expand Down
17 changes: 12 additions & 5 deletions sparse/mlir_backend/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,33 @@
import mlir.execution_engine
import mlir.passmanager
from mlir import ir
from mlir.dialects import arith, func, linalg, sparse_tensor, tensor
from mlir.dialects import arith, complex, func, linalg, sparse_tensor, tensor

from ._array import Array
from ._common import fn_cache
from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx, pm
from ._dtypes import DType, FloatingDType
from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType


@fn_cache
def get_add_module(
a_tensor_type: ir.RankedTensorType,
b_tensor_type: ir.RankedTensorType,
out_tensor_type: ir.RankedTensorType,
dtype: type[DType],
dtype: DType,
rank: int,
) -> ir.Module:
with ir.Location.unknown(ctx):
module = ir.Module.create()
# TODO: add support for complex dialect/dtypes
arith_op = arith.AddFOp if issubclass(dtype, FloatingDType) else arith.AddIOp
if isinstance(dtype, IeeeRealFloatingDType):
arith_op = arith.AddFOp
elif isinstance(dtype, IeeeComplexFloatingDType):
arith_op = complex.AddOp
elif isinstance(dtype, IntegerDType):
arith_op = arith.AddIOp
else:
raise RuntimeError(f"Can not add {dtype=}.")

dtype = dtype._get_mlir_type()
ordering = ir.AffineMap.get_permutation(range(rank))

Expand Down
6 changes: 3 additions & 3 deletions sparse/mlir_backend/levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class StorageFormat:
order: tuple[int, ...]
pos_width: int
crd_width: int
dtype: type[DType]
dtype: DType

@property
def storage_rank(self) -> int:
Expand Down Expand Up @@ -162,7 +162,7 @@ def get_storage_format(
order: typing.Literal["C", "F"] | tuple[int, ...],
pos_width: int,
crd_width: int,
dtype: type[DType],
dtype: DType,
) -> StorageFormat:
levels = tuple(levels)
if isinstance(order, str):
Expand All @@ -186,7 +186,7 @@ def _get_storage_format(
order: tuple[int, ...],
pos_width: int,
crd_width: int,
dtype: type[DType],
dtype: DType,
) -> StorageFormat:
return StorageFormat(
levels=levels,
Expand Down
17 changes: 16 additions & 1 deletion sparse/mlir_backend/tests/test_simple.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import typing
from collections.abc import Iterable

import sparse

Expand All @@ -24,6 +25,8 @@
np.uint64,
np.float32,
np.float64,
np.complex64,
np.complex128,
],
)

Expand Down Expand Up @@ -67,6 +70,18 @@ def sampler_real_floating(size: tuple[int, ...]):

return sampler_real_floating

if np.issubdtype(dtype, np.complexfloating):
float_dtype = np.array(0, dtype=dtype).real.dtype

def sampler_complex_floating(size: tuple[int, ...]):
    """Sample a complex array of shape `size` with the enclosing `dtype`.

    Real values are drawn with an extra trailing axis of length 2, each pair
    is reinterpreted as one complex number via `view`, and the resulting
    length-1 trailing axis is indexed away.
    """
    draw_real = generate_sampler(float_dtype, rng)
    # Accept a bare (non-iterable) size as a one-dimensional shape.
    dims = size if isinstance(size, Iterable) else (size,)
    paired_shape = tuple(dims) + (2,)
    paired = draw_real(paired_shape)
    as_complex = paired.view(dtype)
    return as_complex[..., 0]

return sampler_complex_floating

raise NotImplementedError(f"{dtype=} not yet supported.")


Expand Down Expand Up @@ -212,7 +227,7 @@ def test_coo_3d_format(dtype):
levels=(
sparse.levels.Level(sparse.levels.LevelFormat.Compressed, sparse.levels.LevelProperties.NonUnique),
sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties.NonUnique),
sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties.NonUnique),
sparse.levels.Level(sparse.levels.LevelFormat.Singleton, sparse.levels.LevelProperties(0)),
),
order="C",
pos_width=64,
Expand Down

0 comments on commit cc3c8d9

Please sign in to comment.