Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support qmatmul with different dims tensors #4438

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from math import prod
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -186,28 +187,29 @@ def quantized_matmul_meta(
X_size = list(X.size())
Y_size = list(Y.size())

assert len(X_size) == len(
Y_size
), "quantized matmul not supported for tensors of different dimensions"

if len(X_size) == 3:
assert (
X_size[0] == Y_size[0]
), "quantized matmul only supported for batch dimension of same size"
if transposed:
assert X_size[2] == Y_size[2], "matrices cannot be multiplied"
out_size = X_size[:2] + [Y_size[1]]
else:
assert X_size[2] == Y_size[1], "matrices cannot be multiplied"
out_size = X_size[:2] + [Y_size[2]]
elif len(X_size) == 2:
if transposed:
assert X_size[1] == Y_size[1], "matrices cannot be multiplied"
out_size = [X_size[0], Y_size[0]]
else:
assert X_size[1] == Y_size[0], "matrices cannot be multiplied"
out_size = [X_size[0], Y_size[1]]
# Get the batch dimensions for both tensors
X_batch_dims = X_size[:-2]
Y_batch_dims = Y_size[:-2]

# If they don't match, check that they're compatible
if X_batch_dims != Y_batch_dims:
assert prod(X_batch_dims) == prod(
Y_batch_dims
), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}"

# Get the matmul output size
if transposed:
assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied"
mat_size = [X_size[-2], Y_size[-2]]
else:
raise AssertionError("quantized matmul only supported for 2D or 3D tensors")
assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied"
mat_size = [X_size[-2], Y_size[-1]]

# Combine the larger batch dimensions with the matmul output size
out_size = (
X_batch_dims + mat_size
if len(X_batch_dims) > len(Y_batch_dims)
else Y_batch_dims + mat_size
)
Comment on lines +208 to +213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why combine the larger of the two batch dims? Perhaps it should just default to the first argument's instead?
This would seem like unpredictable behavior to someone using this function.


return X.new_empty(out_size, dtype=X.dtype)
Loading