Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into guard-tests-no-co…
Browse files Browse the repository at this point in the history
…mplex-dtypes
  • Loading branch information
cbourjau committed Dec 1, 2024
2 parents 462d0d3 + f1c3ed2 commit 523bf4c
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 48 deletions.
77 changes: 76 additions & 1 deletion array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:


@wraps(xps.arrays)
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
def arrays_no_scalars(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
"""xps.arrays() without the crazy large numbers."""
if isinstance(dtype, SearchStrategy):
return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
Expand All @@ -77,6 +77,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
return xps.arrays(dtype, *args, elements=elements, **kwargs)


def _f(a, flag):
return a[()] if a.ndim==0 and flag else a


@wraps(xps.arrays)
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
"""xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
Is only relevant for numpy: on all other libraries, array[()] is no-op.
"""
return builds(_f, arrays_no_scalars(dtype, *args, elements=elements, **kwargs), booleans())


_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
_sorted_dtypes = [d for category in _dtype_categories for d in category]

Expand Down Expand Up @@ -232,6 +245,68 @@ def shapes(**kw):
lambda shape: math.prod(i for i in shape if i) < MAX_ARRAY_SIZE
)

def _factorize(n: int) -> List[int]:
# Simple prime factorization. Only needs to handle n ~ MAX_ARRAY_SIZE
factors = []
while n % 2 == 0:
factors.append(2)
n //= 2

for i in range(3, int(math.sqrt(n)) + 1, 2):
while n % i == 0:
factors.append(i)
n //= i

if n > 1: # n is a prime number greater than 2
factors.append(n)

return factors

MAX_SIDE = MAX_ARRAY_SIZE // 64
# NumPy only supports up to 32 dims. TODO: Get this from the new inspection APIs
MAX_DIMS = min(MAX_ARRAY_SIZE // MAX_SIDE, 32)


@composite
def reshape_shapes(draw, arr_shape, ndims=integers(1, MAX_DIMS)):
"""
Generate shape tuples whose product equals the product of array_shape.
"""
shape = draw(arr_shape)

array_size = math.prod(shape)

n_dims = draw(ndims)

# Handle special cases
if array_size == 0:
# Generate a random tuple, and ensure at least one of the entries is 0
result = list(draw(shapes(min_dims=n_dims, max_dims=n_dims)))
pos = draw(integers(0, n_dims - 1))
result[pos] = 0
return tuple(result)

if array_size == 1:
return tuple(1 for _ in range(n_dims))

# Get prime factorization
factors = _factorize(array_size)

# Distribute prime factors randomly
result = [1] * n_dims
for factor in factors:
pos = draw(integers(0, n_dims - 1))
result[pos] *= factor

assert math.prod(result) == array_size

# An element of the reshape tuple can be -1, which means it is a stand-in
# for the remaining factors.
if draw(booleans()):
pos = draw(integers(0, n_dims - 1))
result[pos] = -1

return tuple(result)

one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)

Expand Down
3 changes: 2 additions & 1 deletion array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
data=st.data(),
)
def test_asarray_arrays(shape, dtypes, data):
x = data.draw(hh.arrays(dtype=dtypes.input_dtype, shape=shape), label="x")
# generate arrays only since we draw the copy= kwd below (and np.asarray(scalar, copy=False) error out)
x = data.draw(hh.arrays_no_scalars(dtype=dtypes.input_dtype, shape=shape), label="x")
dtypes_strat = st.just(dtypes.input_dtype)
if dtypes.input_dtype == dtypes.result_dtype:
dtypes_strat |= st.none()
Expand Down
34 changes: 28 additions & 6 deletions array_api_tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Union

import pytest
from hypothesis import given
from hypothesis import given, assume
from hypothesis import strategies as st

from . import _array_module as xp
Expand All @@ -23,26 +23,43 @@ def float32(n: Union[int, float]) -> float:
return struct.unpack("!f", struct.pack("!f", float(n)))[0]


def _float_match_complex(complex_dtype):
return xp.float32 if complex_dtype == xp.complex64 else xp.float64


@given(
x_dtype=non_complex_dtypes(),
dtype=non_complex_dtypes(),
x_dtype=hh.all_dtypes,
dtype=hh.all_dtypes,
kw=hh.kwargs(copy=st.booleans()),
data=st.data(),
)
def test_astype(x_dtype, dtype, kw, data):
_complex_dtypes = (xp.complex64, xp.complex128)

if xp.bool in (x_dtype, dtype):
elements_strat = hh.from_dtype(x_dtype)
else:
m1, M1 = dh.dtype_ranges[x_dtype]
m2, M2 = dh.dtype_ranges[dtype]

if dh.is_int_dtype(x_dtype):
cast = int
elif x_dtype == xp.float32:
elif x_dtype in (xp.float32, xp.complex64):
cast = float32
else:
cast = float

real_dtype = x_dtype
if x_dtype in _complex_dtypes:
real_dtype = _float_match_complex(x_dtype)
m1, M1 = dh.dtype_ranges[real_dtype]

real_dtype = dtype
if dtype in _complex_dtypes:
real_dtype = _float_match_complex(x_dtype)
m2, M2 = dh.dtype_ranges[real_dtype]

min_value = cast(max(m1, m2))
max_value = cast(min(M1, M2))

elements_strat = hh.from_dtype(
x_dtype,
min_value=min_value,
Expand All @@ -54,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data):
hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
)

# according to the spec, "Casting a complex floating-point array to a real-valued
# data type should not be permitted."
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes)))

out = xp.astype(x, dtype, **kw)

ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)
Expand Down
28 changes: 15 additions & 13 deletions array_api_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test_fft(x, data):
assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)


if hh.complex_dtypes:
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
def test_ifft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
Expand All @@ -130,6 +131,7 @@ def test_ifft(x, data):
assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)


if hh.complex_dtypes:
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
def test_fftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
Expand All @@ -140,6 +142,7 @@ def test_fftn(x, data):
assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)


if hh.complex_dtypes:
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
def test_ifftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
Expand Down Expand Up @@ -230,21 +233,20 @@ def test_irfftn(x, data):
expected=dh.dtype_components[x.dtype],
)

# TODO: assert shape correctly
# _axes = sh.normalize_axis(axes, x.ndim)
# _s = x.shape if s is None else s
# expected = []
# for i in range(x.ndim):
# if i in _axes:
# side = _s[_axes.index(i)]
# else:
# side = x.shape[i]
# expected.append(side)
# last_axis = max(_axes)
# expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
# ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
_axes = sh.normalize_axis(axes, x.ndim)
_s = x.shape if s is None else s
expected = []
for i in range(x.ndim):
if i in _axes:
side = _s[_axes.index(i)]
else:
side = x.shape[i]
expected.append(side)
expected[_axes[-1]] = 2*(_s[-1] - 1) if s is None else _s[-1]
ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))


if hh.complex_dtypes:
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
def test_hfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
Expand Down
23 changes: 18 additions & 5 deletions array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from . import _array_module as xp
from ._array_module import linalg


def assert_equal(x, y, msg_extra=None):
extra = '' if not msg_extra else f' ({msg_extra})'
if x.dtype in dh.all_float_dtypes:
Expand All @@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None):
else:
assert_exactly_equal(x, y, msg_extra=msg_extra)


def _test_stacks(f, *args, res=None, dims=2, true_val=None,
matrix_axes=(-2, -1),
res_axes=None,
Expand Down Expand Up @@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
if true_val:
assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)


def _test_namedtuple(res, fields, func_name):
"""
Test that res is a namedtuple with the correct fields.
Expand All @@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name):
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"


@pytest.mark.unvectorized
@pytest.mark.xp_extension('linalg')
@given(
Expand Down Expand Up @@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0):

_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)


def _conj(x):
# XXX: replace with xp.dtype when all array libraries implement it
if x.dtype in (xp.complex64, xp.complex128):
return xp.conj(x)
else:
return x


def _test_vecdot(namespace, x1, x2, data):
vecdot = namespace.vecdot
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
Expand All @@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data):
ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape],
out_shape=res.shape, expected=expected_shape)

if x1.dtype in dh.int_dtypes:
def true_val(x, y, axis=-1):
return xp.sum(xp.multiply(x, y), dtype=res.dtype)
else:
true_val = None
def true_val(x, y, axis=-1):
return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype)

_test_stacks(vecdot, x1, x2, res=res, dims=0,
matrix_axes=(axis,), true_val=true_val)
Expand All @@ -944,6 +954,7 @@ def true_val(x, y, axis=-1):
def test_linalg_vecdot(x1, x2, data):
_test_vecdot(linalg, x1, x2, data)


@pytest.mark.unvectorized
@given(
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
Expand All @@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data):
def test_vecdot(x1, x2, data):
_test_vecdot(_array_module, x1, x2, data)


# Insanely large orders might not work. There isn't a limit specified in the
# spec, so we just limit to reasonable values here.
max_ord = 100


@pytest.mark.unvectorized
@pytest.mark.xp_extension('linalg')
@given(
Expand Down
25 changes: 5 additions & 20 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from . import xps
from .typing import Array, Shape

MAX_SIDE = hh.MAX_ARRAY_SIZE // 64
MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims


def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
key = "shape"
Expand Down Expand Up @@ -66,7 +63,7 @@ def test_concat(dtypes, base_shape, data):
shape_strat = hh.shapes()
else:
_axis = axis if axis >= 0 else len(base_shape) + axis
shape_strat = st.integers(0, MAX_SIDE).map(
shape_strat = st.integers(0, hh.MAX_SIDE).map(
lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
)
arrays = []
Expand Down Expand Up @@ -348,26 +345,14 @@ def test_repeat(x, kw, data):
kw=kw)
start = end

@st.composite
def reshape_shapes(draw, shape):
size = 1 if len(shape) == 0 else math.prod(shape)
rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size))
assume(all(side <= MAX_SIDE for side in rshape))
if len(rshape) != 0 and size > 0 and draw(st.booleans()):
index = draw(st.integers(0, len(rshape) - 1))
rshape[index] = -1
return tuple(rshape)

reshape_shape = st.shared(hh.shapes(), key="reshape_shape")

@pytest.mark.unvectorized
@pytest.mark.skip("flaky") # TODO: fix!
@given(
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(max_side=MAX_SIDE)),
data=st.data(),
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
shape=hh.reshape_shapes(reshape_shape),
)
def test_reshape(x, data):
shape = data.draw(reshape_shapes(x.shape))

def test_reshape(x, shape):
out = xp.reshape(x, shape)

ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype)
Expand Down
11 changes: 9 additions & 2 deletions array_api_tests/test_special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from decimal import ROUND_HALF_EVEN, Decimal
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal
from warnings import warn
from warnings import warn, filterwarnings, catch_warnings

import pytest
from hypothesis import given, note, settings, assume
from hypothesis import strategies as st
from hypothesis.errors import NonInteractiveExampleWarning

from array_api_tests.typing import Array, DataType

Expand Down Expand Up @@ -1250,7 +1251,13 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]

@pytest.mark.parametrize("func_name, func, case", unary_params)
def test_unary(func_name, func, case):
in_value = case.cond_from_dtype(xp.float64).example()
with catch_warnings():
# XXX: We are using example here to generate one example draw, but
# hypothesis issues a warning from this. We should consider either
# drawing multiple examples like a normal test, or just hard-coding a
# single example test case without using hypothesis.
filterwarnings('ignore', category=NonInteractiveExampleWarning)
in_value = case.cond_from_dtype(xp.float64).example()
x = xp.asarray(in_value, dtype=xp.float64)
out = func(x)
out_value = float(out)
Expand Down

0 comments on commit 523bf4c

Please sign in to comment.