Skip to content

Commit

Permalink
feat(qtensor): use qbytes_mm in linear dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Jun 27, 2024
1 parent afdd4ab commit 1277079
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions — optimum/quanto/tensor/qtensor_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch

from .qbits import AWQBitsTensor
from .qbytes import QBytesTensor
from .qtensor import qfallback


Expand Down Expand Up @@ -108,6 +109,10 @@ def forward(ctx, input, other, bias):
                bits=4,
                group_size=other._group_size,
            )
        elif isinstance(other, QBytesTensor):
            if isinstance(input, QBytesTensor):
                output = torch.ops.quanto.qbytes_mm(input._data, other._data, input._scale * other._scale)
            output = torch.ops.quanto.qbytes_mm(input, other._data, other._scale)
        else:
            output = torch.matmul(input, other.t())
if bias is not None:
Expand Down

0 comments on commit 1277079

Please sign in to comment.