Skip to content

Commit

Permalink
Introduced wrap generic to dispatch construction to third-party subcl…
Browse files Browse the repository at this point in the history
…asses. (#54)

This is necessary as we can't dispatch to different subclasses  in DelayedArray's
own __init__ method. So, developers need to write their own method for wrap()
in order to construct a custom DelayedArray subclass based on their seed type.
  • Loading branch information
LTLA authored Sep 27, 2023
1 parent 27e529d commit 102bae5
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## Version 0.3.1

- Added a `wrap()` method that dispatches to different `DelayedArray` subclasses based on the seed.

## Version 0.3.0

- Replace the `__DelayedArray` methods with generics, for easier extensibility to classes outside of our control.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ We can wrap this in a `DelayedArray` class:

```python
import delayedarray
d = delayedarray.DelayedArray(x)
d = delayedarray.wrap(x)
## <100 x 20> DelayedArray object of type 'float64'
## [[0.58969193, 0.36342181, 0.03111773, ..., 0.72036247, 0.40297173,
## 0.48654955],
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
autodoc_default_options = {
'special-members': True,
'undoc-members': False,
'exclude-members': '__weakref__, __dict__, __str__, __module__, __init__'
'exclude-members': '__weakref__, __dict__, __str__, __module__'
}

autosummary_generate = True
Expand Down
15 changes: 10 additions & 5 deletions src/delayedarray/DelayedArray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union, Any
import numpy
from numpy import array, dtype, integer, issubdtype, ndarray, prod, array2string

Expand Down Expand Up @@ -85,12 +85,17 @@ class DelayedArray:
- A method for the
:py:meth:`~delayedarray.create_dask_array.create_dask_array` generic,
if the seed is not already compatible with the **dask** package.
Attributes:
seed: Any array-like object that satisfies the seed contract.
"""

def __init__(self, seed):
def __init__(self, seed: Any):
"""Construct a ``DelayedArray`` object from a seed.
Most users are advised to use :py:meth:`~delayedarray.wrap.wrap`
instead, as this can be specialized by developers to construct
subclasses that are optimized for custom seed types.
Args:
seed: Any array-like object that satisfies the seed contract.
"""
self._seed = seed

@property
Expand Down
1 change: 1 addition & 0 deletions src/delayedarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@
from .chunk_shape import chunk_shape
from .is_pristine import is_pristine
from .guess_iteration_block_size import guess_iteration_block_size
from .wrap import wrap
27 changes: 27 additions & 0 deletions src/delayedarray/wrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from functools import singledispatch
from typing import Any

from .DelayedArray import DelayedArray


@singledispatch
def wrap(x: Any) -> DelayedArray:
"""Create a :py:class:`~delayedarray.DelayedArray.DelayedArray` from an
object satisfying the seed contract. Developers can implement methods for
this generic to create ``DelayedArray`` subclasses based on the seed type.
Args:
x:
Any object satisfiying the seed contract, see documentation for
:py:class:`~delayedarray.DelayedArray.DelayedArray` for details.
Returns:
A ``DelayedArray`` or one of its subclasses.
"""
return DelayedArray(x)


@wrap.register
def wrap_DelayedArray(x: DelayedArray):
"""See :py:meth:`~delayedarray.wrap.wrap`."""
return x
9 changes: 9 additions & 0 deletions tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,12 @@ def test_ndarray_colmajor():

out = str(x)
assert out.find("<40 x 30> DelayedArray object of type 'float64'") != -1


def test_ndarray_wrap():
raw = numpy.random.rand(30, 40)
x = delayedarray.wrap(raw)
assert isinstance(x, delayedarray.DelayedArray)
assert x.shape == raw.shape
x = delayedarray.wrap(x)
assert isinstance(x, delayedarray.DelayedArray)

0 comments on commit 102bae5

Please sign in to comment.