Skip to content

Commit

Permalink
Merge pull request #170 from asmeurer/numpy-rewrap
Browse files Browse the repository at this point in the history
Re-enable wrapping unconditionally for NumPy
  • Loading branch information
asmeurer authored Aug 6, 2024
2 parents 158622e + 29afe3a commit f905d8c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['1.21', '1.26', 'dev']
numpy-version: ['1.21', '1.26', '2.0', 'dev']
exclude:
- python-version: '3.11'
numpy-version: '1.21'
Expand Down
9 changes: 3 additions & 6 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def is_ndonnx_array(x):

import ndonnx as ndx

return isinstance(x, ndx.Array)
return isinstance(x, ndx.Array)

def is_dask_array(x):
"""
Expand Down Expand Up @@ -340,12 +340,9 @@ def your_function(x, y):
elif use_compat is False:
namespaces.add(np)
else:
# numpy 2.0 has __array_namespace__ and is fully array API
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
# compatible.
if hasattr(np.empty(0), '__array_namespace__'):
namespaces.add(np.empty(0).__array_namespace__(api_version=api_version))
else:
namespaces.add(numpy_namespace)
namespaces.add(numpy_namespace)
elif is_cupy_array(x):
if _use_compat:
_check_api_version(api_version)
Expand Down
7 changes: 2 additions & 5 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@

import pytest

wrapped_libraries = ["cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"]
import numpy as np
if np.__version__[0] == '1':
wrapped_libraries.append("numpy")
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["jax.numpy"]

# `sparse` added array API support as of Python 3.10.
if sys.version_info >= (3, 10):
Expand Down

0 comments on commit f905d8c

Please sign in to comment.