Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update __array_api_version__ to 2023.12 #191

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def is_torch_namespace(xp) -> bool:
is_array_api_strict_namespace
"""
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp):
"""
Expand Down Expand Up @@ -415,10 +415,11 @@ def is_array_api_strict_namespace(xp):
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):
if api_version == '2021.12':
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
elif api_version is not None and api_version != '2022.12':
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
if api_version in ['2021.12', '2022.12']:
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
elif api_version is not None and api_version not in ['2021.12', '2022.12',
'2023.12']:
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")

def array_namespace(*xs, api_version=None, use_compat=None):
"""
Expand All @@ -431,7 +432,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):

api_version: str
The newest version of the spec that you need support for (currently
the compat library wrapped APIs support v2022.12).
the compat library wrapped APIs support v2023.12).

use_compat: bool or None
If None (the default), the native namespace will be returned if it is
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

from ..common._helpers import * # noqa: F401,F403

__array_api_version__ = '2022.12'
__array_api_version__ = '2023.12'
2 changes: 1 addition & 1 deletion array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403

__array_api_version__ = '2022.12'
__array_api_version__ = '2023.12'

__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
2 changes: 1 addition & 1 deletion array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
except ImportError:
pass

__array_api_version__ = '2022.12'
__array_api_version__ = '2023.12'
2 changes: 1 addition & 1 deletion array_api_compat/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@

from ..common._helpers import * # noqa: F403

__array_api_version__ = '2022.12'
__array_api_version__ = '2023.12'
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ each array library itself fully compatible with the array API, but this
requires making backwards incompatible changes in many cases, so this will
take some time.

Currently all libraries here are implemented against the [2022.12
version](https://data-apis.org/array-api/2022.12/) of the standard.
Currently all libraries here are implemented against the [2023.12
version](https://data-apis.org/array-api/2023.12/) of the standard.

## Installation

Expand Down
14 changes: 10 additions & 4 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ._helpers import import_, all_libraries, wrapped_libraries

@pytest.mark.parametrize("use_compat", [True, False, None])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
def test_array_namespace(library, api_version, use_compat):
xp = import_(library)
Expand Down Expand Up @@ -94,14 +94,20 @@ def test_array_namespace_errors_torch():
def test_api_version():
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
assert array_namespace(x, api_version="2022.12") == torch_
assert array_namespace(x, api_version="2023.12") == torch_
assert array_namespace(x, api_version=None) == torch_
assert array_namespace(x) == torch_
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
assert array_namespace(x, api_version="2021.12") == torch_
assert len(w) == 1
assert "2021.12" in str(w[0].message)
assert len(w) == 1
assert "2021.12" in str(w[0].message)

# Should issue a warning
with warnings.catch_warnings(record=True) as w:
assert array_namespace(x, api_version="2022.12") == torch_
assert len(w) == 1
assert "2022.12" in str(w[0].message)

pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))

Expand Down
Loading