def is_lazy_array(x) -> bool:
    """Tell whether eagerly reading the contents of ``x`` may be impossible or costly.

    Return True if ``x`` is potentially a future, or if materialising its values
    (e.g. through ``bool(x)`` or ``float(x)``) may fail or be expensive regardless
    of the size of the array.

    Return False otherwise, i.e. when ``bool(x)`` and similar conversions are
    guaranteed to succeed and to be cheap, provided the array has a suitable dtype.

    Note
    ----
    For array types that may or may not be lazy depending on context (e.g. JAX
    arrays), this function is deliberately conservative and always returns True.
    """
    # Backends that are treated as always eager.
    eager_checks = (
        is_numpy_array,
        is_cupy_array,
        is_torch_array,
        is_pydata_sparse_array,
    )
    if any(check(x) for check in eager_checks):
        return False

    # Backends that are always reported as lazy:
    # - JAX: one could distinguish jitted from non-jitted code by inspecting the
    #   jax.Array subclass or by probing bool(), but JAX best practices advise
    #   against doing so.
    # - Dask: __bool__, __float__, etc. eagerly compute the graph. This cannot be
    #   changed without breaking backwards compatibility, yet it is very harmful
    #   for performance because the whole graph ends up computed multiple times.
    lazy_checks = (is_jax_array, is_dask_array, is_ndonnx_array)
    if any(check(x) for check in lazy_checks):
        return True

    # From here on x is an unknown Array API compatible object. Probing it may be
    # very expensive, e.g. for a lazy object that eagerly computes its graph on
    # __bool__ (as Dask does; Dask itself is special-cased above).

    # An unknown (None) or NaN size means the shape is not known upfront.
    n_elems = size(x)
    if n_elems is None or math.isnan(n_elems):
        return True

    xp = array_namespace(x)
    # Reduce the probe to a single element of the array...
    probe = xp.reshape(x, (-1,))[0] if n_elems > 1 else x
    # ...then cast it to dtype=bool; this also takes care of size 0 arrays.
    probe = xp.any(probe)

    try:
        bool(probe)
    except Exception:
        # The Array API standard mandates TypeError when __bool__ cannot produce
        # a value. Any other exception is accepted too, as e.g. Dask may raise
        # arbitrary errors while computing.
        return True
    else:
        return False
@pytest.mark.parametrize("library", all_libraries)
def test_is_lazy_array(library):
    """is_lazy_array() returns a plain bool for arrays of every tested library."""
    namespace = import_(library)
    arr = namespace.asarray([1, 2, 3])
    outcome = is_lazy_array(arr)
    assert isinstance(outcome, bool)


@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)])
def test_is_lazy_array_nan_size(shape, monkeypatch):
    """An unknown Array API compliant object is lazy when its shape holds
    NaN (like Dask) or None (like ndonnx).
    """
    strict = import_("array_api_strict")
    arr = strict.asarray(1)
    # Baseline: with its real shape the array is reported as eager.
    assert not is_lazy_array(arr)

    # Override the shape on the array class with the parametrized value.
    monkeypatch.setattr(type(arr), "shape", shape)
    assert is_lazy_array(arr)


@pytest.mark.parametrize("exc", [TypeError, AssertionError])
def test_is_lazy_array_bool_raises(exc, monkeypatch):
    """An unknown Array API compliant object is lazy when bool() raises.

    - TypeError: e.g. like jitted JAX. This is the exception that lazy arrays
      should raise according to the Array API specification.
    - anything else: e.g. like Dask, where bool() triggers compute(), which
      may raise any kind of exception.
    """
    strict = import_("array_api_strict")
    arr = strict.asarray(1)
    # Baseline: before patching, bool() works and the array is eager.
    assert not is_lazy_array(arr)

    def raising_bool(self):
        raise exc("Hello world")

    monkeypatch.setattr(type(arr), "__bool__", raising_bool)
    assert is_lazy_array(arr)