Skip to content

Commit

Permalink
Merge pull request #190 from asmeurer/sign-fix
Browse files Browse the repository at this point in the history
Add a wrapper for sign for NumPy-likes
  • Loading branch information
asmeurer authored Oct 29, 2024
2 parents 5affae5 + 8dee4d6 commit 522a608
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 10 deletions.
19 changes: 17 additions & 2 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import NamedTuple
import inspect

from ._helpers import array_namespace, _check_device, device, is_torch_array
from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace

# These functions are modified from the NumPy versions.

Expand Down Expand Up @@ -530,11 +530,26 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
raise ValueError("Input array must be at least 1-d.")
return tuple(xp.moveaxis(x, axis, 0))

# numpy 1.26 does not use the standard definition for sign on complex numbers

def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
if isdtype(x.dtype, 'complex floating', xp=xp):
out = (x/xp.abs(x, **kwargs))[...]
# sign(0) = 0 but the above formula would give nan
out[x == 0+0j] = 0+0j
else:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
'unstack']
'unstack', 'sign']
8 changes: 1 addition & 7 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
sign = get_xp(cp)(_aliases.sign)

_copy_default = object()

Expand Down Expand Up @@ -109,13 +110,6 @@ def asarray(

return cp.array(obj, dtype=dtype, **kwargs)

def sign(x: ndarray, /) -> ndarray:
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
out = cp.sign(x)
out[cp.isnan(x)] = cp.nan
return out

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _dask_arange(
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)

sign = get_xp(np)(_aliases.sign)

# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
Expand Down
1 change: 1 addition & 0 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)

def _supports_buffer_protocol(obj):
try:
Expand Down
4 changes: 4 additions & 0 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]

# inverse trig functions are too inaccurate on CPU
array_api_tests/test_operators_and_elementwise_functions.py::test_acos
array_api_tests/test_operators_and_elementwise_functions.py::test_atan
array_api_tests/test_operators_and_elementwise_functions.py::test_asin

# overflow near float max
array_api_tests/test_operators_and_elementwise_functions.py::test_log1p
Expand Down

0 comments on commit 522a608

Please sign in to comment.