From 10878598af33cf23d7ef8343f4b5cb37ff954004 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Mon, 29 Jul 2024 11:56:18 -0700 Subject: [PATCH] Support qmatmul with different dims tensors (#4438) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4438 MobileBERT exposes an issue in our kernel, where tensors have compatible (for PyTorch) but different batch dimensions. This diff changes the meta kernel to support that (the kernel can already do it). Reviewed By: dulinriley Differential Revision: D60314979 --- backends/cadence/aot/ops_registrations.py | 46 ++++++++++++----------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index c877a7149d..adcf086873 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from math import prod from typing import Optional, Tuple import torch @@ -186,28 +187,29 @@ def quantized_matmul_meta( X_size = list(X.size()) Y_size = list(Y.size()) - assert len(X_size) == len( - Y_size - ), "quantized matmul not supported for tensors of different dimensions" - - if len(X_size) == 3: - assert ( - X_size[0] == Y_size[0] - ), "quantized matmul only supported for batch dimension of same size" - if transposed: - assert X_size[2] == Y_size[2], "matrices cannot be multiplied" - out_size = X_size[:2] + [Y_size[1]] - else: - assert X_size[2] == Y_size[1], "matrices cannot be multiplied" - out_size = X_size[:2] + [Y_size[2]] - elif len(X_size) == 2: - if transposed: - assert X_size[1] == Y_size[1], "matrices cannot be multiplied" - out_size = [X_size[0], Y_size[0]] - else: - assert X_size[1] == Y_size[0], "matrices cannot be multiplied" - out_size = [X_size[0], Y_size[1]] + # Get the batch dimensions for both tensors + X_batch_dims = X_size[:-2] + Y_batch_dims = Y_size[:-2] + + # If they don't match, check that they're compatible + if X_batch_dims != Y_batch_dims: + assert prod(X_batch_dims) == prod( + Y_batch_dims + ), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}" + + # Get the matmul output size + if transposed: + assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied" + mat_size = [X_size[-2], Y_size[-2]] else: - raise AssertionError("quantized matmul only supported for 2D or 3D tensors") + assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied" + mat_size = [X_size[-2], Y_size[-1]] + + # Combine the larger batch dimensions with the matmul output size + out_size = ( + X_batch_dims + mat_size + if len(X_batch_dims) > len(Y_batch_dims) + else Y_batch_dims + mat_size + ) return X.new_empty(out_size, dtype=X.dtype)