From c2ee28016787c8e27501ed69465718b0f2fab5a4 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 11 Feb 2025 16:10:04 +0100
Subject: [PATCH] Only do reshapes in `tensordot` when needed

---
 pytensor/tensor/math.py   | 75 ++++++++++++++++++++++-----------------
 tests/tensor/test_math.py |  2 +-
 2 files changed, 44 insertions(+), 33 deletions(-)

diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index f11e33b41d..56517ef980 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -2152,13 +2152,11 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
@@ -2166,48 +2164,61 @@ def tensordot(
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    a_needs_reshape = len(non_summed_axes_a) > 1 or len(axes_a) > 1
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = len(axes_b) > 1 or len(non_summed_axes_b) > 1
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 2d19ef0114..7eda01a313 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -2278,7 +2278,7 @@ def test_type_shape(self):
 
         with pytest.raises(
             ValueError,
-            match="Input arrays have inconsistent broadcastable pattern or type shape",
+            match="Input arrays have inconsistent type shape",
         ):
             tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)
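
Note for reviewers (not part of the patch): the sketch below illustrates, from the user side, the case this change targets. It assumes a PyTensor install where `pytensor.tensor.tensordot` follows the logic in the diff above; the variable names and shapes are purely illustrative.

```python
# Illustrative only; not part of the patch.
import pytensor.tensor as pt

# Matrix-matrix contraction: each operand has exactly one summed and one
# non-summed axis, so a_needs_reshape / b_needs_reshape are both False and
# no Reshape nodes are introduced -- only the (trivial) transposes and the dot.
x = pt.matrix("x")  # shape (m, k)
y = pt.matrix("y")  # shape (k, n)
z = pt.tensordot(x, y, axes=1)

# Contraction over two axes still needs the reshapes: the two summed axes of
# each operand are raveled into a single dimension before the dot, and the
# result is reshaped back to the non-summed dimensions of a and b.
a = pt.tensor("a", shape=(None, None, None))  # (i, j, k)
b = pt.tensor("b", shape=(None, None, None))  # (j, k, l)
w = pt.tensordot(a, b, axes=[[1, 2], [0, 1]])

print(z.type.ndim, w.type.ndim)  # 2 and 2
```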