diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index c877a7149d..adcf086873 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from math import prod from typing import Optional, Tuple import torch @@ -186,28 +187,29 @@ def quantized_matmul_meta( X_size = list(X.size()) Y_size = list(Y.size()) - assert len(X_size) == len( - Y_size - ), "quantized matmul not supported for tensors of different dimensions" - - if len(X_size) == 3: - assert ( - X_size[0] == Y_size[0] - ), "quantized matmul only supported for batch dimension of same size" - if transposed: - assert X_size[2] == Y_size[2], "matrices cannot be multiplied" - out_size = X_size[:2] + [Y_size[1]] - else: - assert X_size[2] == Y_size[1], "matrices cannot be multiplied" - out_size = X_size[:2] + [Y_size[2]] - elif len(X_size) == 2: - if transposed: - assert X_size[1] == Y_size[1], "matrices cannot be multiplied" - out_size = [X_size[0], Y_size[0]] - else: - assert X_size[1] == Y_size[0], "matrices cannot be multiplied" - out_size = [X_size[0], Y_size[1]] + # Get the batch dimensions for both tensors + X_batch_dims = X_size[:-2] + Y_batch_dims = Y_size[:-2] + + # If they don't match, check that they're compatible + if X_batch_dims != Y_batch_dims: + assert prod(X_batch_dims) == prod( + Y_batch_dims + ), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}" + + # Get the matmul output size + if transposed: + assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied" + mat_size = [X_size[-2], Y_size[-2]] else: - raise AssertionError("quantized matmul only supported for 2D or 3D tensors") + assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied" + mat_size = [X_size[-2], Y_size[-1]] + + # Combine the larger batch dimensions with the matmul output size + out_size = ( + X_batch_dims + mat_size + if len(X_batch_dims) > len(Y_batch_dims) + else Y_batch_dims + mat_size + ) return X.new_empty(out_size, dtype=X.dtype)