Skip to content

Commit

Permalink
MAINT: make __array__ raise on python < 3.12
Browse files Browse the repository at this point in the history
Otherwise, on python 3.11 and below, np.array(array_api_strict_array)
becomes a 0D object array.
  • Loading branch information
ev-br committed Jan 24, 2025
1 parent 77400d0 commit 8115901
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
9 changes: 7 additions & 2 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,18 @@ def __repr__(self: Array, /) -> str:
# Instead of `__array__` we now implement the buffer protocol.
# Note that it makes array-apis-strict requiring python>=3.12
def __buffer__(self, flags):
print('__buffer__')
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
return memoryview(self._array)
def __release_buffer(self, buffer):
print('__release__')
# XXX anything to do here?
pass

def __array__(self, *args, **kwds):
# a stub for python < 3.12; otherwise numpy silently produces object arrays
raise TypeError(
f"Interoperation with NumPy requires python >= 3.12. Please upgrade."

Check failure on line 171 in array_api_strict/_array_object.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F541)

array_api_strict/_array_object.py:171:13: F541 f-string without any placeholders
)

# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
Expand Down
14 changes: 14 additions & 0 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,20 @@ def test_array_conversion():
with pytest.raises((RuntimeError, TypeError)):
asarray([a])

# __buffer__ should work for now for conversion to numpy
a = ones((2, 3))
na = np.array(a)
assert na.shape == (2, 3)
assert na.dtype == np.float64

@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312,
reason="conversion to numpy errors out unless python >= 3.12"
)
def test_array_conversion_2():
a = ones((2, 3))
with pytest.raises(TypeError):
np.array(a)


def test_allow_newaxis():
a = ones(5)
Expand Down

0 comments on commit 8115901

Please sign in to comment.