fix: adjust _convert_weight_to_int4pack_cpu input weights for pytorch>=2.5

Fixes: #274

PyTorch 2.5 changed the expected input weights of _convert_weight_to_int4pack_cpu
from [n][k] int32 to [n][k / 2] uint8. This commit changes the quanto code accordingly.

See: pytorch/pytorch#129940
See: pytorch/pytorch@6f662e9
Signed-off-by: Dmitry Rogozhkin <[email protected]>
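
For context, a minimal sketch of the layout change (illustrative only, not part of the commit): before 2.5 the CPU kernel took an [n][k] int32 tensor holding one int4 value per element; from 2.5 it takes an [n][k / 2] uint8 tensor holding two int4 values per byte, even columns in the high nibble, which is exactly what the repacking expression in the diff below produces.

import torch

# Toy [n][k] weight in the pre-2.5 layout: one int4 value (0..15) per int32 element.
w_int32 = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Repack into the 2.5 layout: [n][k / 2] uint8, even columns in the high
# nibble, odd columns in the low nibble (same expression as in the diff).
w_uint8 = (w_int32[:, ::2] << 4 | w_int32[:, 1::2]).to(torch.uint8)
assert w_uint8.shape == (4, 4)

# Round-trip back to the old layout to confirm the repacking is lossless.
unpacked = torch.stack((w_uint8 >> 4, w_uint8 & 0xF), dim=-1).reshape(4, 8).to(torch.int32)
assert torch.equal(unpacked, w_int32)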
dvrogozh committed Aug 16, 2024
1 parent 3c7a807 commit 37c83b5
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion bench/torch_kernels/test_weight_int4pack_mm.py
@@ -16,6 +16,7 @@
 import timeit
 
 import torch
+from packaging import version
 
 
 def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
@@ -90,7 +91,11 @@ def avg_time(f, it):
 B = torch.rand([3200, 4800], dtype=dtype, device=device)
 group_size = 128
 B_int32, B_scale_and_zeros = _group_quantize_tensor(B, n_bit=4, q_group_size=group_size)
-B_packed = torch._convert_weight_to_int4pack(B_int32, innerKTiles=2)
+if version.parse(torch.__version__).release >= version.parse("2.5.0").release:
+    B_uint8 = (B_int32[::, ::2] << 4 | B_int32[::, 1::2]).to(torch.uint8)
+    B_packed = torch._convert_weight_to_int4pack(B_uint8, innerKTiles=2)
+else:
+    B_packed = torch._convert_weight_to_int4pack(B_int32, innerKTiles=2)
 
 # Check quantized mm is close to float mm
 qout = torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros)
7 changes: 6 additions & 1 deletion optimum/quanto/tensor/qbits/tinygemm/packed.py
@@ -16,6 +16,7 @@
 from copy import copy
 
 import torch
+from packaging import version
 from torch.utils import _pytree as pytree
 
 
@@ -53,7 +54,11 @@ def pack(cls, t):
         """
         inner_ktiles = 2
         t = t.to(torch.int32).contiguous()
-        data = torch._convert_weight_to_int4pack(t, innerKTiles=inner_ktiles)
+        if version.parse(torch.__version__).release >= version.parse("2.5.0").release:
+            t_uint8 = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
+            data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)
+        else:
+            data = torch._convert_weight_to_int4pack(t, innerKTiles=inner_ktiles)
         # We need to store size and stride to make sure the unpacked data has the correct shape
         return TinyGemmPackedTensor(data, t.size(), t.stride())

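A note on the version gate used in both files: the comparison is done on version.parse(...).release, which keeps only the numeric release tuple, so development builds of 2.5 (e.g. a source build reporting "2.5.0a0+git6f662e9", matching the commit referenced above) also take the new uint8 path; a plain Version comparison would rank such a pre-release below 2.5.0. A quick check, assuming packaging is installed:

from packaging import version

# .release strips pre-release/local suffixes, so a 2.5 dev build passes the gate.
assert version.parse("2.5.0a0+git6f662e9").release >= version.parse("2.5.0").release
# A direct Version comparison would treat the same build as older than 2.5.0.
assert version.parse("2.5.0a0+git6f662e9") < version.parse("2.5.0")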