Skip to content

Commit

Permalink
Format specification API (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi authored Oct 22, 2024
1 parent ef53f7d commit 0a0802e
Show file tree
Hide file tree
Showing 10 changed files with 564 additions and 733 deletions.
13 changes: 11 additions & 2 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ mkdocs-jupyter = "*"

[feature.tests.tasks]
test = "pytest --pyargs sparse -n auto"
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto", env = { SPARSE_BACKEND = "Finch" }, depends-on = ["precompile"] }
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v" }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", depends-on = ["precompile"] }

[feature.tests.dependencies]
pytest = ">=3.5"
Expand All @@ -51,10 +51,19 @@ precompile = "python -c 'import finch'"
scipy = ">=0.19"
finch-tensor = ">=0.1.31"

[feature.finch.activation.env]
SPARSE_BACKEND = "Finch"

[feature.finch.target.osx-arm64.activation.env]
PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"

[feature.mlir.dependencies]
scipy = ">=0.19"
mlir-python-bindings = "19.*"

[feature.mlir.activation.env]
SPARSE_BACKEND = "MLIR"

[environments]
tests = ["tests", "extras"]
docs = ["docs", "extras"]
Expand Down
27 changes: 7 additions & 20 deletions sparse/mlir_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,17 @@
try:
import mlir # noqa: F401

del mlir
except ModuleNotFoundError as e:
raise ImportError(
"MLIR Python bindings not installed. Run "
"`conda install conda-forge::mlir-python-bindings` "
"to enable MLIR backend."
) from e

from ._constructors import (
PackedArgumentTuple,
asarray,
)
from ._dtypes import (
asdtype,
)
from ._ops import (
add,
broadcast_to,
reshape,
)
from . import levels
from ._conversions import asarray, from_constituent_arrays, to_numpy, to_scipy
from ._dtypes import asdtype
from ._ops import add

__all__ = [
"add",
"broadcast_to",
"asarray",
"asdtype",
"reshape",
"PackedArgumentTuple",
]
__all__ = ["add", "asarray", "asdtype", "to_numpy", "to_scipy", "levels", "from_constituent_arrays"]
45 changes: 45 additions & 0 deletions sparse/mlir_backend/_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np

from ._dtypes import DType
from .levels import StorageFormat


class Array:
def __init__(self, *, storage, shape: tuple[int, ...]) -> None:
storage_rank = storage.get_storage_format().rank
if len(shape) != storage_rank:
raise ValueError(f"Mismatched rank, `{storage_rank=}`, `{shape=}`")

self._storage = storage
self._shape = shape

@property
def shape(self) -> tuple[int, ...]:
return self._shape

@property
def ndim(self) -> int:
return len(self.shape)

@property
def dtype(self) -> type[DType]:
return self._storage.get_storage_format().dtype

@property
def format(self) -> StorageFormat:
return self._storage.get_storage_format()

def _get_mlir_type(self):
return self.format._get_mlir_type(shape=self.shape)

def _to_module_arg(self):
return self._storage.to_module_arg()

def copy(self) -> "Array":
from ._conversions import from_constituent_arrays

arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)

def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
return self._storage.get_constituent_arrays()
55 changes: 29 additions & 26 deletions sparse/mlir_backend/_common.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,50 @@
import abc
import ctypes
import functools
import weakref
from dataclasses import dataclass

from mlir import ir
import mlir.runtime as rt

import numpy as np

class MlirType(abc.ABC):
@classmethod
@abc.abstractmethod
def get_mlir_type(cls) -> ir.Type: ...
from ._core import libc
from ._dtypes import DType, asdtype


@dataclass
class PackedArgumentTuple:
contents: tuple
def fn_cache(f, maxsize: int | None = None):
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))

def __getitem__(self, index):
return self.contents[index]

def __iter__(self):
yield from self.contents
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
return _get_nd_memref_descr(int(rank), asdtype(dtype))

def __len__(self):
return len(self.contents)

@fn_cache
def _get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())


def numpy_to_ranked_memref(arr: np.ndarray) -> ctypes.Structure:
memref = rt.get_ranked_memref_descriptor(arr)
memref_descr = get_nd_memref_descr(arr.ndim, asdtype(arr.dtype))
# Required due to ctypes type checks
return memref_descr(
allocated=memref.allocated,
aligned=memref.aligned,
offset=memref.offset,
shape=memref.shape,
strides=memref.strides,
)

def fn_cache(f, maxsize: int | None = None):
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))

def ranked_memref_to_numpy(ref: ctypes.Structure) -> np.ndarray:
return rt.ranked_memref_to_numpy([ref])

def _hold_self_ref_in_ret(fn):
@functools.wraps(fn)
def wrapped(self, *a, **kw):
ret = fn(self, *a, **kw)
_take_owneship(ret, self)
return ret

return wrapped
def free_memref(obj: ctypes.Structure) -> None:
libc.free(ctypes.cast(obj.allocated, ctypes.c_void_p))


def _take_owneship(owner, obj):
def _hold_ref(owner, obj):
ptr = ctypes.py_object(obj)
ctypes.pythonapi.Py_IncRef(ptr)

Expand Down
Loading

0 comments on commit 0a0802e

Please sign in to comment.