Skip to content

Commit

Permalink
Fix cupy sign nan handling
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Oct 24, 2024
1 parent 2539057 commit 55e3f71
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import NamedTuple
import inspect

from ._helpers import array_namespace, _check_device, device, is_torch_array
from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace

# These functions are modified from the NumPy versions.

Expand Down Expand Up @@ -541,7 +541,8 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
out[xp.isnan(x)] = xp.nan
if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
Expand Down

0 comments on commit 55e3f71

Please sign in to comment.