diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 6b6c7e3e..d967daa4 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -10,7 +10,7 @@ from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, just, lists, none, one_of, - sampled_from, shared, builds) + sampled_from, shared, builds, nothing) from . import _array_module as xp, api_version from . import array_helpers as ah @@ -200,11 +200,11 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes: real_floating_dtypes = sampled_from(dh.real_float_dtypes) numeric_dtypes = sampled_from(dh.numeric_dtypes) # Note: this always returns complex dtypes, even if api_version < 2022.12 -complex_dtypes: SearchStrategy[Any] | None = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else None +complex_dtypes: SearchStrategy[Any] = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else nothing() def all_floating_dtypes() -> SearchStrategy[DataType]: strat = floating_dtypes - if api_version >= "2022.12" and complex_dtypes is not None: + if api_version >= "2022.12" and not complex_dtypes.is_empty: strat |= complex_dtypes return strat diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 2298eab5..fecf9d91 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1062,6 +1062,7 @@ def refimpl(_x, _min, _max): @pytest.mark.min_version("2022.12") +@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_conj(x): out = xp.conj(x) @@ -1264,6 +1265,7 @@ def test_hypot(x1, x2): @pytest.mark.min_version("2022.12") +@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_imag(x): out = xp.imag(x) @@ -1559,6 +1561,7 @@ def test_pow(ctx, data): @pytest.mark.min_version("2022.12") +@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_real(x): out = xp.real(x)