Skip to content

Commit

Permalink
cp.generic is an alias for np.generic
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Nov 28, 2024
1 parent ee25aae commit 5b2b3b4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def is_cupy_array(x):
import cupy as cp

# TODO: Should we reject ndarray subclasses?
return isinstance(x, (cp.ndarray, cp.generic))
return isinstance(x, cp.ndarray)

def is_torch_array(x):
"""
Expand Down
19 changes: 19 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func):
assert is_func(lib) == (func == is_namespace_functions[library])


@pytest.mark.parametrize('library', all_libraries)
def test_xp_is_array_generics(library):
"""
Test that scalar selection on a xp.ndarray always returns
an object that matches with exactly one among the is_*_array
function of the same library and is_numpy_array.
"""
lib = import_(library)
x = lib.asarray([1, 2, 3])
x0 = x[0]

matches = []
for library2, func in is_array_functions.items():
is_func = globals()[func]
if is_func(x0):
matches.append(library2)
assert matches in ([library], ["numpy"])


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down

0 comments on commit 5b2b3b4

Please sign in to comment.