Skip to content

Commit

Permalink
Add cumulative_sum to torch
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Sep 24, 2024
1 parent cb9acd4 commit c0dd5b0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
4 changes: 3 additions & 1 deletion array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ def cumulative_sum(
include_initial: bool = False,
**kwargs
) -> ndarray:
wrapped_xp = array_namespace(x)

# TODO: The standard is not clear about what should happen when x.ndim == 0.
if axis is None:
if x.ndim > 1:
Expand All @@ -290,7 +292,7 @@ def cumulative_sum(
initial_shape = list(x.shape)
initial_shape[axis] = 1
res = xp.concatenate(
[xp.zeros_like(res, shape=initial_shape), res],
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
axis=axis,
)
return res
Expand Down
18 changes: 11 additions & 7 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
from builtins import all as _builtin_all, any as _builtin_any

from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
vecdot as _aliases_vecdot, clip as
_aliases_clip, unstack as _aliases_unstack,)
vecdot as _aliases_vecdot,
clip as _aliases_clip,
unstack as _aliases_unstack,
cumulative_sum as _aliases_cumulative_sum,
)
from .._internal import get_xp

from ._info import __array_namespace_info__
Expand Down Expand Up @@ -198,6 +201,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep

clip = get_xp(torch)(_aliases_clip)
unstack = get_xp(torch)(_aliases_unstack)
cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)

# torch.sort also returns a tuple
# https://github.com/pytorch/pytorch/issues/70921
Expand Down Expand Up @@ -732,11 +736,11 @@ def sign(x: array, /) -> array:
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort',
'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat',
'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where',
'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros',
'empty', 'tril', 'triu', 'expand_dims', 'astype',
'remainder', 'subtract', 'max', 'min', 'clip', 'unstack',
'cumulative_sum', 'sort', 'prod', 'sum', 'any', 'all', 'mean',
'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll',
'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult',
'UniqueInverseResult', 'unique_all', 'unique_counts',
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
Expand Down

0 comments on commit c0dd5b0

Please sign in to comment.