Only do reshapes in tensordot when needed
ricardoV94 committed Feb 11, 2025 · 1 parent 3cff4f5 · commit c2ee280
Showing 2 changed files with 44 additions and 33 deletions.
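For intuition, here is a small NumPy sketch (not part of the commit) of the transpose/reshape/dot/reshape decomposition that tensordot performs. The diff below makes the reshape steps conditional, since they are no-ops whenever each operand has at most one summed and one non-summed axis.

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.normal(size=(2, 3, 4))
    b = rng.normal(size=(4, 3, 5))

    # Contract a's axes (2, 1) against b's axes (0, 1).
    expected = np.tensordot(a, b, axes=([2, 1], [0, 1]))

    # Move the summed axes to the end of `a` and the front of `b`,
    # collapse each axis group to one dimension, take a single 2D dot,
    # then restore the non-summed dimensions.
    at = a.transpose(0, 2, 1).reshape(2, 4 * 3)
    bt = b.reshape(4 * 3, 5)  # b's axes are already in (summed, non-summed) order
    assert np.allclose((at @ bt).reshape(2, 5), expected)

    # With axes=1 on two matrices both operands are already 2D, so the
    # reshapes do nothing; that is the case this commit avoids.
    m, n = rng.normal(size=(2, 3)), rng.normal(size=(3, 4))
    assert np.allclose(np.tensordot(m, n, axes=1), m @ n)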
75 changes: 43 additions & 32 deletions pytensor/tensor/math.py
@@ -2152,62 +2152,73 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
 
     # The operation is only valid if the original dimensions match in length
     # The ravelling of the dimensions to coerce the operation into a single dot
     # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    a_needs_reshape = len(non_summed_axes_a) > 1 or len(axes_a) > 1
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = len(axes_b) > 1 or len(non_summed_axes_b) > 1
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
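
Outside the diff it is easier to see when the reshapes drop out. Below is a hypothetical NumPy re-implementation of the new code path, for illustration only: the function name tensordot_via_dot is made up, and math.prod stands in for the variadic_mul helper used in the real code.

    import numpy as np
    from math import prod

    def tensordot_via_dot(a, b, axes_a, axes_b):
        # Express tensordot as one 2D dot product, reshaping only when needed.
        non_summed_a = [k for k in range(a.ndim) if k not in axes_a]
        non_summed_b = [k for k in range(b.ndim) if k not in axes_b]
        # A reshape is required only when a group of axes has to be
        # collapsed into a single dimension.
        a_needs_reshape = len(non_summed_a) > 1 or len(axes_a) > 1
        b_needs_reshape = len(axes_b) > 1 or len(non_summed_b) > 1

        at = a.transpose(non_summed_a + list(axes_a))
        if a_needs_reshape:
            at = at.reshape(
                prod(a.shape[k] for k in non_summed_a),
                prod(a.shape[k] for k in axes_a),
            )

        bt = b.transpose(list(axes_b) + non_summed_b)
        if b_needs_reshape:
            bt = bt.reshape(
                prod(b.shape[k] for k in axes_b),
                prod(b.shape[k] for k in non_summed_b),
            )

        res = at @ bt
        if a_needs_reshape or b_needs_reshape:
            res = res.reshape(
                [a.shape[k] for k in non_summed_a]
                + [b.shape[k] for k in non_summed_b]
            )
        return res

    # Agrees with np.tensordot; the axes=1 matrix case takes no reshapes at all.
    rng = np.random.default_rng(0)
    x, y = rng.normal(size=(2, 3, 4)), rng.normal(size=(4, 3, 5))
    assert np.allclose(
        tensordot_via_dot(x, y, [2, 1], [0, 1]),
        np.tensordot(x, y, axes=([2, 1], [0, 1])),
    )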
2 changes: 1 addition & 1 deletion tests/tensor/test_math.py
@@ -2278,7 +2278,7 @@ def test_type_shape(self):

         with pytest.raises(
             ValueError,
-            match="Input arrays have inconsistent broadcastable pattern or type shape",
+            match="Input arrays have inconsistent type shape",
         ):
             tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)

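The updated message is raised at graph construction time whenever the static shapes of the axes being reduced disagree. A minimal standalone reproduction of the updated test, assuming the usual imports (ones and tensordot come from pytensor.tensor, as in the test module):

    import pytest
    from pytensor.tensor import ones, tensordot

    # Axis 1 of the first input has static length 4 while axis 0 of the
    # second has static length 7, so the shape check fails immediately.
    with pytest.raises(ValueError, match="Input arrays have inconsistent type shape"):
        tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)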
