diff --git a/optimum/quanto/tensor/qtensor_func.py b/optimum/quanto/tensor/qtensor_func.py index 6bb0dbf4..b85f1d55 100644 --- a/optimum/quanto/tensor/qtensor_func.py +++ b/optimum/quanto/tensor/qtensor_func.py @@ -17,6 +17,7 @@ import torch from .qbits import AWQBitsTensor +from .qbytes import QBytesTensor from .qtensor import qfallback @@ -108,6 +109,10 @@ def forward(ctx, input, other, bias): bits=4, group_size=other._group_size, ) + elif isinstance(other, QBytesTensor): + if isinstance(input, QBytesTensor): + output = torch.ops.quanto.qbytes_mm(input._data, other._data, input._scale * other._scale) + output = torch.ops.quanto.qbytes_mm(input, other._data, other._scale) else: output = torch.matmul(input, other.t()) if bias is not None: