Skip to content

Commit

Permalink
Add SciPy conversions.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Oct 17, 2024
1 parent 1a20262 commit 6ab8320
Show file tree
Hide file tree
Showing 10 changed files with 311 additions and 644 deletions.
4 changes: 2 additions & 2 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ mkdocs-jupyter = "*"

[feature.tests.tasks]
test = "pytest --pyargs sparse -n auto"
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -vvv", env = { SPARSE_BACKEND = "Finch", PYTHONFAULTHANDLER = "${HOME}/faulthandler.log" }, depends-on = ["precompile"] }
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v", env = { SPARSE_BACKEND = "MLIR" } }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", env = { SPARSE_BACKEND = "Finch", PYTHONFAULTHANDLER = "${HOME}/faulthandler.log" }, depends-on = ["precompile"] }

[feature.tests.dependencies]
pytest = ">=3.5"
Expand Down
12 changes: 4 additions & 8 deletions sparse/mlir_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,20 @@
"to enable MLIR backend."
) from e

from ._constructors import (
PackedArgumentTuple,
asarray,
)
from ._common import PackedArgumentTuple
from ._conversions import asarray, to_numpy, to_scipy
from ._dtypes import (
asdtype,
)
from ._ops import (
add,
broadcast_to,
reshape,
)

__all__ = [
"add",
"broadcast_to",
"asarray",
"asdtype",
"reshape",
"PackedArgumentTuple",
"to_numpy",
"to_scipy",
]
36 changes: 36 additions & 0 deletions sparse/mlir_backend/_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import dataclasses

from ._common import _hold_ref, numpy_to_ranked_memref, ranked_memref_to_numpy
from ._levels import StorageFormat


class Array:
def __init__(self, *, storage, shape: tuple[int, ...]) -> None:
storage_rank = storage.get_storage_format().rank
if len(shape) != storage_rank:
raise ValueError(f"Mismatched rank, `{storage_rank=}`, `{shape=}`")

Check warning on line 11 in sparse/mlir_backend/_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_array.py#L11

Added line #L11 was not covered by tests

self._storage = storage
self._shape = shape

@property
def shape(self) -> tuple[int, ...]:
return self._shape

@property
def ndim(self) -> int:
return len(self.shape)

Check warning on line 22 in sparse/mlir_backend/_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_array.py#L22

Added line #L22 was not covered by tests

@property
def dtype(self):
return self._storage.get_storage_format().dtype

Check warning on line 26 in sparse/mlir_backend/_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_array.py#L26

Added line #L26 was not covered by tests

def copy(self):
storage_format: StorageFormat = dataclasses.replace(self._storage.get_storage_format(), owns_memory=False)

Check warning on line 29 in sparse/mlir_backend/_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_array.py#L29

Added line #L29 was not covered by tests

fields = self._storage.get__fields_()
arrs = [ranked_memref_to_numpy(f).copy() for f in fields]
memrefs = [numpy_to_ranked_memref(arr) for arr in arrs]
arr = Array(storage=storage_format.get_ctypes_type()(*memrefs), shape=self.shape)
for carr in arrs:
_hold_ref(arr, carr)

Check warning on line 36 in sparse/mlir_backend/_array.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_array.py#L31-L36

Added lines #L31 - L36 were not covered by tests
18 changes: 7 additions & 11 deletions sparse/mlir_backend/_common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import abc
import ctypes
import functools
import weakref
from dataclasses import dataclass

import mlir.runtime as rt
from mlir import ir

import numpy as np

Expand All @@ -17,8 +15,12 @@ def fn_cache(f, maxsize: int | None = None):
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))


@fn_cache
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
return _get_nd_memref_descr(int(rank), asdtype(dtype))


@fn_cache
def _get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())


Expand All @@ -43,12 +45,6 @@ def free_memref(obj: ctypes.Structure) -> None:
libc.free(ctypes.cast(obj.allocated, ctypes.c_void_p))

Check warning on line 45 in sparse/mlir_backend/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_common.py#L45

Added line #L45 was not covered by tests


class MlirType(abc.ABC):
@classmethod
@abc.abstractmethod
def get_mlir_type(cls) -> ir.Type: ...


@dataclass
class PackedArgumentTuple:
contents: tuple
Expand All @@ -67,13 +63,13 @@ def _hold_self_ref_in_ret(fn):
@functools.wraps(fn)
def wrapped(self, *a, **kw):
ret = fn(self, *a, **kw)
_take_owneship(ret, self)
_hold_ref(ret, self)

Check warning on line 66 in sparse/mlir_backend/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/mlir_backend/_common.py#L66

Added line #L66 was not covered by tests
return ret

return wrapped


def _take_owneship(owner, obj):
def _hold_ref(owner, obj):
ptr = ctypes.py_object(obj)
ctypes.pythonapi.Py_IncRef(ptr)

Expand Down
Loading

0 comments on commit 6ab8320

Please sign in to comment.