From 4ac92551f4e56a8ee5ae50b15906db91ab453a99 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 17:31:59 +0200 Subject: [PATCH] ENH: allow mean(complex) in 2024.12 --- array_api_strict/_statistical_functions.py | 11 ++++++++-- .../tests/test_statistical_functions.py | 20 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 6ea9746..f06785c 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -3,6 +3,7 @@ from ._dtypes import ( _real_floating_dtypes, _real_numeric_dtypes, + _floating_dtypes, _numeric_dtypes, ) from ._array_object import Array @@ -65,8 +66,14 @@ def mean( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: - if x.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in mean") + + if get_array_api_strict_flags()['api_version'] > '2023.12': + allowed_dtypes = _floating_dtypes + else: + allowed_dtypes = _real_floating_dtypes + + if x.dtype not in allowed_dtypes: + raise TypeError("Only floating-point dtypes are allowed in mean") return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims), device=x.device) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py index 7f2a457..c97670d 100644 --- a/array_api_strict/tests/test_statistical_functions.py +++ b/array_api_strict/tests/test_statistical_functions.py @@ -1,3 +1,4 @@ +import cmath import pytest from .._flags import set_array_api_strict_flags @@ -37,3 +38,22 @@ def test_sum_prod_trace_2023_12(func_name): assert func(a_real).dtype == xp.float32 assert func(a_complex).dtype == xp.complex64 assert func(a_int).dtype == xp.int64 + + +# mean(complex-valued array) is allowed from 2024.12 onwards +def test_mean_complex(): + a = xp.asarray([1j, 2j, 3j]) + + set_array_api_strict_flags(api_version='2023.12') + with pytest.raises(TypeError): + xp.mean(a) + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2024.12') + m = xp.mean(a) + assert cmath.isclose(complex(m), 2j) + + # mean of integer arrays is still not allowed + with pytest.raises(TypeError): + xp.mean(xp.arange(3)) +