diff --git a/bench/torch_kernels/test_weight_int4pack_mm.py b/bench/torch_kernels/test_weight_int4pack_mm.py index c12ca17d..d6f9a5c8 100644 --- a/bench/torch_kernels/test_weight_int4pack_mm.py +++ b/bench/torch_kernels/test_weight_int4pack_mm.py @@ -16,6 +16,7 @@ import timeit import torch +from packaging import version def _group_quantize_tensor(w, n_bit=4, q_group_size=16): @@ -90,7 +91,11 @@ def avg_time(f, it): B = torch.rand([3200, 4800], dtype=dtype, device=device) group_size = 128 B_int32, B_scale_and_zeros = _group_quantize_tensor(B, n_bit=4, q_group_size=group_size) - B_packed = torch._convert_weight_to_int4pack(B_int32, innerKTiles=2) + if version.parse(torch.__version__).release >= version.parse("2.5.0").release: + B_uint8 = (B_int32[::, ::2] << 4 | B_int32[::, 1::2]).to(torch.uint8) + B_packed = torch._convert_weight_to_int4pack(B_uint8, innerKTiles=2) + else: + B_packed = torch._convert_weight_to_int4pack(B_int32, innerKTiles=2) # Check quantized mm is close to float mm qout = torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros) diff --git a/optimum/quanto/tensor/qbits/tinygemm/packed.py b/optimum/quanto/tensor/qbits/tinygemm/packed.py index 326beab1..b1d15c0e 100644 --- a/optimum/quanto/tensor/qbits/tinygemm/packed.py +++ b/optimum/quanto/tensor/qbits/tinygemm/packed.py @@ -16,6 +16,7 @@ from copy import copy import torch +from packaging import version from torch.utils import _pytree as pytree @@ -53,7 +54,11 @@ def pack(cls, t): """ inner_ktiles = 2 t = t.to(torch.int32).contiguous() - data = torch._convert_weight_to_int4pack(t, innerKTiles=inner_ktiles) + if version.parse(torch.__version__).release >= version.parse("2.5.0").release: + t_uint8 = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles) + else: + data = torch._convert_weight_to_int4pack(t, innerKTiles=inner_ktiles) # We need to store size and stride to make sure the unpacked data has the correct shape return TinyGemmPackedTensor(data, t.size(), t.stride())