-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: is_lazy_array() #228
base: main
Are you sure you want to change the base?
ENH: is_lazy_array() #228
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -824,6 +824,63 @@ def is_writeable_array(x) -> bool: | |
return True | ||
|
||
|
||
def is_lazy_array(x) -> bool:
    """Tell whether eagerly reading the contents of ``x`` may be impossible or costly.

    Return True if ``x`` is potentially a future, or if materialising its values
    (e.g. by calling ``bool(x)`` or ``float(x)``) may fail or be expensive
    regardless of the size of the array.

    Return False otherwise, i.e. when ``bool(x)`` and similar conversions are
    guaranteed to succeed and to be cheap as long as the array has the right dtype.

    Note
    ----
    Array types that may or may not be lazy, e.g. JAX arrays, are always reported
    as lazy: this function errs on the side of caution.
    """
    # Backends that are special-cased as eager.
    known_eager = (
        is_numpy_array(x)
        or is_cupy_array(x)
        or is_torch_array(x)
        or is_pydata_sparse_array(x)
    )
    if known_eager:
        return False

    # **JAX note:** one could find out whether the code is running inside or
    # outside jax.jit by inspecting the subclass of the jax.Array object, or by
    # probing bool() as done further down for unknown arrays; JAX best practices
    # advise against doing so.

    # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, etc.
    # That behaviour cannot be changed without breaking backwards compatibility,
    # yet it is highly detrimental to performance because the whole graph ends up
    # being computed multiple times.

    known_lazy = is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x)
    if known_lazy:
        return True

    # From here on x is an unknown Array API compatible object. Beware that the
    # probe below may have dire performance consequences, e.g. for a lazy object
    # that eagerly computes its graph on __bool__ (dask is one such example, but
    # it is special-cased above).
    # NOTE(review): array_api_strict reaches this point.

    # Pick a single point of the array.
    n_elems = size(x)
    # TODO: to be simplified after #231
    if n_elems is None or math.isnan(n_elems):
        return True
    xp = array_namespace(x)
    probe = xp.reshape(x, (-1,))[0] if n_elems > 1 else x
    # Cast to dtype=bool and deal with size 0 arrays.
    probe = xp.any(probe)

    # The Array API standard dictates that __bool__ should raise TypeError if the
    # output cannot be defined. Arbitrary exceptions are accepted here as well,
    # e.g. like Dask raises.
    try:
        bool(probe)
    except Exception:
        return True
    return False
|
||
|
||
__all__ = [ | ||
"array_namespace", | ||
"device", | ||
|
@@ -845,6 +902,7 @@ def is_writeable_array(x) -> bool: | |
"is_pydata_sparse_array", | ||
"is_pydata_sparse_namespace", | ||
"is_writeable_array", | ||
"is_lazy_array", | ||
"size", | ||
"to_device", | ||
] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,10 @@ | ||
import math | ||
|
||
import pytest | ||
import numpy as np | ||
import array | ||
from numpy.testing import assert_allclose | ||
|
||
from array_api_compat import ( # noqa: F401 | ||
is_numpy_array, is_cupy_array, is_torch_array, | ||
is_dask_array, is_jax_array, is_pydata_sparse_array, | ||
|
@@ -6,15 +13,10 @@ | |
) | ||
|
||
from array_api_compat import ( | ||
device, is_array_api_obj, is_writeable_array, size, to_device | ||
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device | ||
) | ||
from ._helpers import import_, wrapped_libraries, all_libraries | ||
|
||
import pytest | ||
import numpy as np | ||
import array | ||
from numpy.testing import assert_allclose | ||
|
||
is_array_functions = { | ||
'numpy': 'is_numpy_array', | ||
'cupy': 'is_cupy_array', | ||
|
@@ -115,6 +117,45 @@ def test_size_none(library): | |
assert size(x) in (None, 5) | ||
|
||
|
||
@pytest.mark.parametrize("library", all_libraries)
def test_is_lazy_array(library):
    """is_lazy_array() returns a plain bool for an array of each parametrized library."""
    xp = import_(library)
    verdict = is_lazy_array(xp.asarray([1, 2, 3]))
    assert isinstance(verdict, bool)
|
||
|
||
@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)])
def test_is_lazy_array_nan_size(shape, monkeypatch):
    """Check is_lazy_array() on an unknown Array API compliant object whose
    shape contains NaN (like Dask) or None (like ndonnx).
    """
    # TODO: to be simplified after #231
    strict = import_("array_api_strict")
    arr = strict.asarray(1)
    # With its real shape the array is reported as eager.
    assert not is_lazy_array(arr)
    # Once the shape holds an unknown dimension it must be reported as lazy.
    monkeypatch.setattr(type(arr), "shape", shape)
    assert is_lazy_array(arr)
|
||
|
||
@pytest.mark.parametrize("exc", [TypeError, AssertionError])
def test_is_lazy_array_bool_raises(exc, monkeypatch):
    """Check is_lazy_array() on an unknown Array API compliant object whose
    bool() raises:

    - TypeError: e.g. like jitted JAX. This is the proper exception that lazy
      arrays should raise as per the Array API specification.
    - something else: e.g. like Dask, where bool() triggers compute(), which can
      result in any kind of exception being raised.
    """
    strict = import_("array_api_strict")
    arr = strict.asarray(1)
    # Before patching, bool() works and the array is reported as eager.
    assert not is_lazy_array(arr)

    def raising_bool(self):
        raise exc("Hello world")

    monkeypatch.setattr(type(arr), "__bool__", raising_bool)
    assert is_lazy_array(arr)
|
||
|
||
@pytest.mark.parametrize("library", all_libraries) | ||
def test_device(library): | ||
xp = import_(library, wrapper=True) | ||
|
@@ -172,6 +213,7 @@ def test_asarray_cross_library(source_library, target_library, request): | |
|
||
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" | ||
|
||
|
||
@pytest.mark.parametrize("library", wrapped_libraries) | ||
def test_asarray_copy(library): | ||
# Note, we have this test here because the test suite currently doesn't | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good argument for replacing everything below this point with a blind `return False`. Please discuss if you'd prefer it that way.