Skip to content

Commit

Permalink
fix(qbytes_mm): reshape input
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Oct 1, 2024
1 parent 3303f71 commit 4121d1e
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion optimum/quanto/tensor/weights/qbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def forward(ctx, input, other, bias=None):
if isinstance(input, QBytesTensor):
output = torch.ops.quanto.qbytes_mm(input._data, other._data, input._scale * other._scale)
else:
output = torch.ops.quanto.qbytes_mm(input, other._data, other._scale)
in_features = input.shape[-1]
out_features = other.shape[0]
output_shape = input.shape[:-1] + (out_features,)
output = torch.ops.quanto.qbytes_mm(input.view(-1, in_features), other._data, other._scale)
output = output.view(output_shape)
if bias is not None:
output = output + bias
return output
Expand Down

0 comments on commit 4121d1e

Please sign in to comment.