Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Format specification API #792

Merged
merged 16 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ mkdocs-jupyter = "*"

[feature.tests.tasks]
test = "pytest --pyargs sparse -n auto"
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto", env = { SPARSE_BACKEND = "Finch" }, depends-on = ["precompile"] }
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v" }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", depends-on = ["precompile"] }

[feature.tests.dependencies]
pytest = ">=3.5"
Expand All @@ -51,10 +51,19 @@ precompile = "python -c 'import finch'"
scipy = ">=0.19"
finch-tensor = ">=0.1.31"

[feature.finch.activation.env]
SPARSE_BACKEND = "Finch"

[feature.finch.target.osx-arm64.activation.env]
PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"

[feature.mlir.dependencies]
scipy = ">=0.19"
mlir-python-bindings = "19.*"

[feature.mlir.activation.env]
SPARSE_BACKEND = "MLIR"

[environments]
tests = ["tests", "extras"]
docs = ["docs", "extras"]
Expand Down
27 changes: 7 additions & 20 deletions sparse/mlir_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,17 @@
try:
import mlir # noqa: F401

del mlir
except ModuleNotFoundError as e:
raise ImportError(
"MLIR Python bindings not installed. Run "
"`conda install conda-forge::mlir-python-bindings` "
"to enable MLIR backend."
) from e

from ._constructors import (
PackedArgumentTuple,
asarray,
)
from ._dtypes import (
asdtype,
)
from ._ops import (
add,
broadcast_to,
reshape,
)
from . import levels
from ._conversions import asarray, from_constituent_arrays, to_numpy, to_scipy
from ._dtypes import asdtype
from ._ops import add

__all__ = [
"add",
"broadcast_to",
"asarray",
"asdtype",
"reshape",
"PackedArgumentTuple",
]
__all__ = ["add", "asarray", "asdtype", "to_numpy", "to_scipy", "levels", "from_constituent_arrays"]
45 changes: 45 additions & 0 deletions sparse/mlir_backend/_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np

from ._dtypes import DType
from .levels import StorageFormat


class Array:
def __init__(self, *, storage, shape: tuple[int, ...]) -> None:
storage_rank = storage.get_storage_format().rank
if len(shape) != storage_rank:
raise ValueError(f"Mismatched rank, `{storage_rank=}`, `{shape=}`")

self._storage = storage
self._shape = shape

Check warning on line 14 in sparse/mlir_backend/_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_array.py#L11-L14

Added lines #L11 - L14 were not covered by tests

@property
def shape(self) -> tuple[int, ...]:
return self._shape

@property
def ndim(self) -> int:
return len(self.shape)

@property
def dtype(self) -> type[DType]:
return self._storage.get_storage_format().dtype

@property
def format(self) -> StorageFormat:
return self._storage.get_storage_format()

def _get_mlir_type(self):
return self.format._get_mlir_type(shape=self.shape)

def _to_module_arg(self):
return self._storage.to_module_arg()

def copy(self) -> "Array":
from ._conversions import from_constituent_arrays

arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)

def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
return self._storage.get_constituent_arrays()
55 changes: 29 additions & 26 deletions sparse/mlir_backend/_common.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,50 @@
import abc
import ctypes
import functools
import weakref
from dataclasses import dataclass

from mlir import ir
import mlir.runtime as rt

import numpy as np

class MlirType(abc.ABC):
@classmethod
@abc.abstractmethod
def get_mlir_type(cls) -> ir.Type: ...
from ._core import libc
from ._dtypes import DType, asdtype


@dataclass
class PackedArgumentTuple:
contents: tuple
def fn_cache(f, maxsize: int | None = None):
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))

def __getitem__(self, index):
return self.contents[index]

def __iter__(self):
yield from self.contents
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
return _get_nd_memref_descr(int(rank), asdtype(dtype))

def __len__(self):
return len(self.contents)

@fn_cache
def _get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())


def numpy_to_ranked_memref(arr: np.ndarray) -> ctypes.Structure:
memref = rt.get_ranked_memref_descriptor(arr)
memref_descr = get_nd_memref_descr(arr.ndim, asdtype(arr.dtype))
# Required due to ctypes type checks
return memref_descr(
allocated=memref.allocated,
aligned=memref.aligned,
offset=memref.offset,
shape=memref.shape,
strides=memref.strides,
)

def fn_cache(f, maxsize: int | None = None):
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))

def ranked_memref_to_numpy(ref: ctypes.Structure) -> np.ndarray:
return rt.ranked_memref_to_numpy([ref])

def _hold_self_ref_in_ret(fn):
@functools.wraps(fn)
def wrapped(self, *a, **kw):
ret = fn(self, *a, **kw)
_take_owneship(ret, self)
return ret

return wrapped
def free_memref(obj: ctypes.Structure) -> None:
libc.free(ctypes.cast(obj.allocated, ctypes.c_void_p))


def _take_owneship(owner, obj):
def _hold_ref(owner, obj):
ptr = ctypes.py_object(obj)
ctypes.pythonapi.Py_IncRef(ptr)

Expand Down
Loading
Loading