Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: allow mean(complex) in 2024.12 #110

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions array_api_strict/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ._dtypes import (
_real_floating_dtypes,
_real_numeric_dtypes,
_floating_dtypes,
_numeric_dtypes,
)
from ._array_object import Array
Expand Down Expand Up @@ -65,8 +66,14 @@ def mean(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> Array:
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")

if get_array_api_strict_flags()['api_version'] > '2023.12':
allowed_dtypes = _floating_dtypes
else:
allowed_dtypes = _real_floating_dtypes

if x.dtype not in allowed_dtypes:
raise TypeError("Only floating-point dtypes are allowed in mean")
return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims), device=x.device)


Expand Down
20 changes: 20 additions & 0 deletions array_api_strict/tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import cmath
import pytest

from .._flags import set_array_api_strict_flags
Expand Down Expand Up @@ -37,3 +38,22 @@ def test_sum_prod_trace_2023_12(func_name):
assert func(a_real).dtype == xp.float32
assert func(a_complex).dtype == xp.complex64
assert func(a_int).dtype == xp.int64


# mean(complex-valued array) is allowed from 2024.12 onwards
def test_mean_complex():
a = xp.asarray([1j, 2j, 3j])

set_array_api_strict_flags(api_version='2023.12')
with pytest.raises(TypeError):
xp.mean(a)

with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2024.12')
m = xp.mean(a)
assert cmath.isclose(complex(m), 2j)

# mean of integer arrays is still not allowed
with pytest.raises(TypeError):
xp.mean(xp.arange(3))

Loading