Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: is_lazy_array() #228

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,63 @@ def is_writeable_array(x) -> bool:
return True


def is_lazy_array(x) -> bool:
"""Return True if x is potentially a future or it may be otherwise impossible or
expensive to eagerly read its contents, regardless of their size, e.g. by
calling ``bool(x)`` or ``float(x)``.

Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
cheap as long as the array has the right dtype.

Note
----
This function errs on the side of caution for array types that may or may not be
lazy, e.g. JAX arrays, by always returning True for them.
"""
if (
is_numpy_array(x)
or is_cupy_array(x)
or is_torch_array(x)
or is_pydata_sparse_array(x)
):
return False

# **JAX note:** while it is possible to determine if you're inside or outside
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
# as we do below for unknown arrays, this is not recommended by JAX best practices.

# **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
# This behaviour, while impossible to change without breaking backwards
# compatibility, is highly detrimental to performance as the whole graph will end
# up being computed multiple times.

if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
return True

# Unknown Array API compatible object. Note that this test may have dire consequences
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
# on __bool__ (dask is one such example, which however is special-cased above).
Comment on lines +860 to +862
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good argument for replacing everything below this point with a blind return False. Please discuss if you'd prefer it that way.


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

array_api_strict reaches this point

# Select a single point of the array
s = size(x)
if s is None or math.isnan(s):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be simplified after #231

return True
xp = array_namespace(x)
if s > 1:
x = xp.reshape(x, (-1,))[0]
# Cast to dtype=bool and deal with size 0 arrays
x = xp.any(x)

try:
bool(x)
return False
# The Array API standard dictactes that __bool__ should raise TypeError if the
# output cannot be defined.
# Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
except Exception:
return True


__all__ = [
"array_namespace",
"device",
Expand All @@ -845,6 +902,7 @@ def is_writeable_array(x) -> bool:
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"is_writeable_array",
"is_lazy_array",
"size",
"to_device",
]
Expand Down
1 change: 1 addition & 0 deletions docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ yet.
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
.. autofunction:: is_writeable_array
.. autofunction:: is_lazy_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
Expand Down
54 changes: 48 additions & 6 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import math

import pytest
import numpy as np
import array
from numpy.testing import assert_allclose

from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
Expand All @@ -6,15 +13,10 @@
)

from array_api_compat import (
device, is_array_api_obj, is_writeable_array, size, to_device
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
)
from ._helpers import import_, wrapped_libraries, all_libraries

import pytest
import numpy as np
import array
from numpy.testing import assert_allclose

is_array_functions = {
'numpy': 'is_numpy_array',
'cupy': 'is_cupy_array',
Expand Down Expand Up @@ -115,6 +117,45 @@ def test_size_none(library):
assert size(x) in (None, 5)


@pytest.mark.parametrize("library", all_libraries)
def test_is_lazy_array(library):
lib = import_(library)
x = lib.asarray([1, 2, 3])
assert isinstance(is_lazy_array(x), bool)


@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)])
def test_is_lazy_array_nan_size(shape, monkeypatch):
"""Test is_lazy_array() on an unknown Array API compliant object
with NaN (like Dask) or None (like ndonnx) in its shape
Comment on lines +127 to +130
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be simplified after #231

"""
xp = import_("array_api_strict")
x = xp.asarray(1)
assert not is_lazy_array(x)
monkeypatch.setattr(type(x), "shape", shape)
assert is_lazy_array(x)


@pytest.mark.parametrize("exc", [TypeError, AssertionError])
def test_is_lazy_array_bool_raises(exc, monkeypatch):
"""Test is_lazy_array() on an unknown Array API compliant object
where calling bool() raises:
- TypeError: e.g. like jitted JAX. This is the proper exception which
lazy arrays should raise as per the Array API specification
- something else: e.g. like Dask, where bool() triggers compute()
which can result in any kind of exception to be raised
"""
xp = import_("array_api_strict")
x = xp.asarray(1)
assert not is_lazy_array(x)

def __bool__(self):
raise exc("Hello world")

monkeypatch.setattr(type(x), "__bool__", __bool__)
assert is_lazy_array(x)


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down Expand Up @@ -172,6 +213,7 @@ def test_asarray_cross_library(source_library, target_library, request):

assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"


@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
# Note, we have this test here because the test suite currently doesn't
Expand Down
Loading