Skip to content

Commit

Permalink
ENH: astype: add device kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 23, 2025
1 parent 8a79994 commit e1da0d6
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 24 deletions.
7 changes: 1 addition & 6 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,6 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
**kwargs,
)

def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
if not copy and dtype == x.dtype:
return x
return x.astype(dtype=dtype, copy=copy)

# These functions have different keyword argument names

def std(
Expand Down Expand Up @@ -549,7 +544,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
'unstack', 'sign']
22 changes: 18 additions & 4 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import cupy as cp

from ..common import _aliases
from ..common import _aliases, _helpers
from .._internal import get_xp

from ._info import __array_namespace_info__
Expand Down Expand Up @@ -46,7 +46,6 @@
unique_counts = get_xp(cp)(_aliases.unique_counts)
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
unique_values = get_xp(cp)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
Expand Down Expand Up @@ -110,6 +109,21 @@ def asarray(

return cp.array(obj, dtype=dtype, **kwargs)


def astype(
x: ndarray,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> ndarray:
if device is None:
return x.astype(dtype=dtype, copy=copy)
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
return out.copy() if copy and out is x else x


# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand All @@ -127,10 +141,10 @@ def asarray(
else:
unstack = get_xp(cp)(_aliases.unstack)

__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'concat', 'pow', 'sign']
'bool', 'concat', 'pow', 'sign']

_all_ignore = ['cp', 'get_xp']
2 changes: 1 addition & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _isscalar(a):

_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'astype', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
Expand Down
17 changes: 14 additions & 3 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
unique_counts = get_xp(np)(_aliases.unique_counts)
unique_inverse = get_xp(np)(_aliases.unique_inverse)
unique_values = get_xp(np)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(np)(_aliases.std)
var = get_xp(np)(_aliases.var)
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
Expand Down Expand Up @@ -115,6 +114,18 @@ def asarray(

return np.array(obj, copy=copy, dtype=dtype, **kwargs)


def astype(
x: ndarray,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> ndarray:
return x.astype(dtype=dtype, copy=copy)


# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, 'vecdot'):
Expand All @@ -132,10 +143,10 @@ def asarray(
else:
unstack = get_xp(np)(_aliases.unstack)

__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'concat', 'pow']
'bool', 'concat', 'pow']

_all_ignore = ['np', 'get_xp']
15 changes: 13 additions & 2 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,19 @@ def triu(x: array, /, *, k: int = 0) -> array:
def expand_dims(x: array, /, *, axis: int = 0) -> array:
return torch.unsqueeze(x, axis)

def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
return x.to(dtype, copy=copy)

def astype(
x: array,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> array:
if device is not None:
return x.to(device, dtype=dtype, copy=copy)
return x.to(dtype=dtype, copy=copy)


def broadcast_arrays(*arrays: array) -> List[array]:
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
Expand Down
1 change: 0 additions & 1 deletion cupy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,4 @@ array_api_tests/test_fft.py::test_irfftn
# cupy.ndaray cannot be specified as `repeats` argument.
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
1 change: 0 additions & 1 deletion dask-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,3 @@ array_api_tests/test_statistical_functions.py::test_prod
# 2023.12 support
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[astype]
1 change: 0 additions & 1 deletion numpy-1-21-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is
# 2023.12 support
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
array_api_tests/test_manipulation_functions.py::test_repeat
1 change: 0 additions & 1 deletion numpy-1-26-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ array_api_tests/test_statistical_functions.py::test_prod
# 2023.12 support
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
array_api_tests/test_manipulation_functions.py::test_repeat
1 change: 0 additions & 1 deletion numpy-dev-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]

# 2023.12 support
# Argument 'device' missing from signature
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
Expand Down
1 change: 0 additions & 1 deletion numpy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
# 2023.12 support
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
array_api_tests/test_manipulation_functions.py::test_repeat
2 changes: 0 additions & 2 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,3 @@ array_api_tests/test_signatures.py::test_func_signature[repeat]
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
# Argument 'max_version' missing from signature
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# Argument 'device' missing from signature
array_api_tests/test_signatures.py::test_func_signature[astype]

0 comments on commit e1da0d6

Please sign in to comment.