Skip to content

Commit

Permalink
Merge pull request #314 from ev-br/vecdot_conj
Browse files Browse the repository at this point in the history
ENH: test vecdot values, incl complex conj
  • Loading branch information
ev-br authored Nov 23, 2024
2 parents c2e010e + 6ea8ae2 commit a71b4c0
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from . import _array_module as xp
from ._array_module import linalg


def assert_equal(x, y, msg_extra=None):
extra = '' if not msg_extra else f' ({msg_extra})'
if x.dtype in dh.all_float_dtypes:
Expand All @@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None):
else:
assert_exactly_equal(x, y, msg_extra=msg_extra)


def _test_stacks(f, *args, res=None, dims=2, true_val=None,
matrix_axes=(-2, -1),
res_axes=None,
Expand Down Expand Up @@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
if true_val:
assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)


def _test_namedtuple(res, fields, func_name):
"""
Test that res is a namedtuple with the correct fields.
Expand All @@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name):
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"


@pytest.mark.unvectorized
@pytest.mark.xp_extension('linalg')
@given(
Expand Down Expand Up @@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0):

_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)


def _conj(x):
# XXX: replace with xp.dtype when all array libraries implement it
if x.dtype in (xp.complex64, xp.complex128):
return xp.conj(x)
else:
return x


def _test_vecdot(namespace, x1, x2, data):
vecdot = namespace.vecdot
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
Expand All @@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data):
ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape],
out_shape=res.shape, expected=expected_shape)

if x1.dtype in dh.int_dtypes:
def true_val(x, y, axis=-1):
return xp.sum(xp.multiply(x, y), dtype=res.dtype)
else:
true_val = None
def true_val(x, y, axis=-1):
return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype)

_test_stacks(vecdot, x1, x2, res=res, dims=0,
matrix_axes=(axis,), true_val=true_val)
Expand All @@ -944,6 +954,7 @@ def true_val(x, y, axis=-1):
def test_linalg_vecdot(x1, x2, data):
_test_vecdot(linalg, x1, x2, data)


@pytest.mark.unvectorized
@given(
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
Expand All @@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data):
def test_vecdot(x1, x2, data):
_test_vecdot(_array_module, x1, x2, data)


# Insanely large orders might not work. There isn't a limit specified in the
# spec, so we just limit to reasonable values here.
max_ord = 100


@pytest.mark.unvectorized
@pytest.mark.xp_extension('linalg')
@given(
Expand Down

0 comments on commit a71b4c0

Please sign in to comment.