Skip to content

Commit

Permalink
Merge pull request #176 from scipy/numpy_2.2_abbrev
Browse files Browse the repository at this point in the history
Numpy 2.2 abbreviations: handle both numpy<2.2 and numpy>=2.2 abbreviations, depending on 'got':

In other words, if a doctest is written for NumPy 2.2, i.e. its expected output has the new `shape=` fake kwarg,
- when run on numpy<2.2, the `shape=` part is ignored
- when run on numpy>=2.2, the shape is checked.
  • Loading branch information
ev-br authored Dec 9, 2024
2 parents 78b24e3 + dc9a25c commit 4a1ff04
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 15 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/pip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.10', '3.11']
python-version: ['3.11', '3.12']
numpy: ['"numpy<2.2"', 'numpy']
os: [ubuntu-latest]
pytest: ['"pytest<8.0"', pytest]
pre: ['', '--pre']
Expand All @@ -35,6 +36,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install ${{matrix.pytest}} ${{matrix.pre}}
python -m pip install -e . ${{matrix.pre}}
python -m pip install ${{matrix.numpy}}
- name: Echo versions
run: |
Expand Down
43 changes: 38 additions & 5 deletions scipy_doctest/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,12 @@ def __init__(self, *, # DTChecker configuration
'masked_array': np.ma.masked_array,
'int64': np.int64,
'uint64': np.uint64,
'int8': np.int8,
'int32': np.int32,
'uint32': np.uint32,
'int16': np.int16,
'uint16': np.uint16,
'int8': np.int8,
'uint8': np.uint8,
'float32': np.float32,
'float64': np.float64,
'dtype': np.dtype,
Expand Down Expand Up @@ -262,6 +266,29 @@ def has_masked(got):
return 'masked_array' in got and '--' in got


def try_split_shape_from_abbrv(s_got):
    """Split the NumPy 2.2 ``shape=(...)`` part off an abbreviated array repr.

    NumPy 2.2 added ``shape=(123,)`` to the repr of abbreviated arrays.
    If it is present, it is cut out of the string and returned separately.
    In all cases the ``...,`` abbreviation markers are removed from the
    returned array string.

    Parameters
    ----------
    s_got : str
        Textual array repr, for example
        ``"array([0, 1, ..., 8, 9], shape=(10,), dtype=uint16)"``.

    Returns
    -------
    tuple of (str, str)
        ``(array, shape)``: the repr without the ``shape=(...)`` part and
        without ``...,`` markers, and the shape as a string such as
        ``"(10,)"``. The shape is an empty string when no ``shape=(...)``
        could be split off.
    """
    shape_str = ''

    if "shape=" in s_got:
        # handle
        #    array(..., shape=(1000,))
        #    array(..., shape=(100, 100))
        #    array(..., shape=(100, 100), dtype=uint16)
        # `\s*` (rather than `\s+`) also accepts `...],shape=(10,))`, i.e. no
        # whitespace between the comma and `shape=`.
        # DOTALL lets `.` span the line breaks of multi-line reprs.
        match = re.match(
            r'(.+),\s*shape=\(([\d\s,]+)\)(.+)', s_got, flags=re.DOTALL
        )
        if match:
            head, shape, tail = match.groups()

            # Glue the text around the removed `shape=(...)` back together
            # and collapse a doubled comma, should the removal leave one.
            s_got = (head + tail).replace(',,', ',')
            shape_str = f'({shape})'

    # Drop the abbreviation markers whether or not a shape was found.
    return ''.join(s_got.split('...,')), shape_str


class DTChecker(doctest.OutputChecker):
obj_pattern = re.compile(r'at 0x[0-9a-fA-F]+>')
vanilla = doctest.OutputChecker()
Expand Down Expand Up @@ -325,11 +352,17 @@ def check_output(self, want, got, optionflags):
return self.check_output(s_want, s_got, optionflags)

#handle array abbreviation for n-dimensional arrays, n >= 1
ndim_array = (s_want.startswith("array([") and s_want.endswith("])") and
s_got.startswith("array([") and s_got.endswith("])"))
ndim_array = (s_want.startswith("array([") and "..." in s_want and
s_got.startswith("array([") and "..." in s_got)
if ndim_array:
s_want = ''.join(s_want.split('...,'))
s_got = ''.join(s_got.split('...,'))
s_want, want_shape = try_split_shape_from_abbrv(s_want)
s_got, got_shape = try_split_shape_from_abbrv(s_got)

if got_shape:
# NumPy 2.2 output, `with shape=`, check the shapes, too
s_want = f"{s_want}, {want_shape}"
s_got = f"{s_got}, {got_shape}"

return self.check_output(s_want, s_got, optionflags)

# maybe we are dealing with masked arrays?
Expand Down
30 changes: 21 additions & 9 deletions scipy_doctest/tests/module_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,25 @@ def array_abbreviation():
"""
Numpy abbreviates arrays, check that it works.
NB: the implementation might need to change when
numpy finally disallows default-creating ragged arrays.
Currently, `...` gets interpreted as an Ellipsis,
thus the `a_want/a_got` variables in DTChecker are in fact
object arrays.
XXX: check if ... creates ragged arrays, avoid if so.
NumPy 2.2 abbreviations
=======================
NumPy 2.2 adds shape=(...) to abbreviated arrays.
This is not a valid argument to `array(...)`, so it cannot be eval-ed,
and needs to be removed for doctesting.
The implementation handles both formats, and checks the shapes if present
in the actual output. If not present in the output, they are ignored.
>>> import numpy as np
>>> np.arange(10000)
array([0, 1, 2, ..., 9997, 9998, 9999])
array([0, 1, 2, ..., 9997, 9998, 9999], shape=(10000,))
>>> np.arange(10000, dtype=np.uint16)
array([ 0, 1, 2, ..., 9997, 9998, 9999], shape=(10000,), dtype=uint16)
>>> np.diag(np.arange(33)) / 30
array([[0., 0., 0., ..., 0., 0.,0.],
Expand All @@ -180,19 +191,20 @@ def array_abbreviation():
...,
[0., 0., 0., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 1.03333333, 0.],
[0., 0., 0., ..., 0., 0., 1.06666667]])
[0., 0., 0., ..., 0., 0., 1.06666667]], shape=(33, 33))
>>> np.diag(np.arange(1, 1001, dtype=float))
>>> np.diag(np.arange(1, 1001, dtype=np.uint16))
array([[1, 0, 0, ..., 0, 0, 0],
[0, 2, 0, ..., 0, 0, 0],
[0, 0, 3, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 998, 0, 0],
[0, 0, 0, ..., 0, 999, 0],
[0, 0, 0, ..., 0, 0, 1000]])
[0, 0, 0, ..., 0, 0, 1000]], shape=(1000, 1000), dtype=uint16)
"""


def nan_equal():
"""
Test that nans are treated as equal.
Expand Down

0 comments on commit 4a1ff04

Please sign in to comment.