From 37c83b578dff5f244956f21351afcd02438601ca Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Fri, 16 Aug 2024 15:52:19 -0700
Subject: [PATCH] fix: adjust _convert_weight_to_int4pack_cpu input weights
 for pytorch>=2.5

Fixes: #274

PyTorch 2.5 changed the input weights expected by
_convert_weight_to_int4pack_cpu from [n][k] int32 to [n][k / 2] uint8.
Change the quanto code accordingly.

See: https://github.com/pytorch/pytorch/pull/129940
See: https://github.com/pytorch/pytorch/commit/6f662e95756333284450ff9c3c6e78c796aa6e77

Signed-off-by: Dmitry Rogozhkin
---
 bench/torch_kernels/test_weight_int4pack_mm.py | 7 ++++++-
 optimum/quanto/tensor/qbits/tinygemm/packed.py | 7 ++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/bench/torch_kernels/test_weight_int4pack_mm.py b/bench/torch_kernels/test_weight_int4pack_mm.py
index c12ca17d..d6f9a5c8 100644
--- a/bench/torch_kernels/test_weight_int4pack_mm.py
+++ b/bench/torch_kernels/test_weight_int4pack_mm.py
@@ -16,6 +16,7 @@
 import timeit
 
 import torch
+from packaging import version
 
 
 def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
@@ -90,7 +91,11 @@ def avg_time(f, it):
     B = torch.rand([3200, 4800], dtype=dtype, device=device)
     group_size = 128
     B_int32, B_scale_and_zeros = _group_quantize_tensor(B, n_bit=4, q_group_size=group_size)
-    B_packed = torch._convert_weight_to_int4pack(B_int32, innerKTiles=2)
+    if version.parse(torch.__version__).release >= version.parse("2.5.0").release:
+        B_uint8 = (B_int32[::, ::2] << 4 | B_int32[::, 1::2]).to(torch.uint8)
+        B_packed = torch._convert_weight_to_int4pack(B_uint8, innerKTiles=2)
+    else:
+        B_packed = torch._convert_weight_to_int4pack(B_int32, innerKTiles=2)
 
     # Check quantized mm is close to float mm
     qout = torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros)
diff --git a/optimum/quanto/tensor/qbits/tinygemm/packed.py b/optimum/quanto/tensor/qbits/tinygemm/packed.py
index 326beab1..b1d15c0e 100644
--- a/optimum/quanto/tensor/qbits/tinygemm/packed.py
+++ b/optimum/quanto/tensor/qbits/tinygemm/packed.py
@@ -16,6 +16,7 @@
 from copy import copy
 
 import torch
+from packaging import version
 from torch.utils import _pytree as pytree
 
 
@@ -53,7 +54,11 @@ def pack(cls, t):
         """
         inner_ktiles = 2
         t = t.to(torch.int32).contiguous()
-        data = torch._convert_weight_to_int4pack(t, innerKTiles=inner_ktiles)
+        if version.parse(torch.__version__).release >= version.parse("2.5.0").release:
+            t_uint8 = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
+            data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)
+        else:
+            data = torch._convert_weight_to_int4pack(t, innerKTiles=inner_ktiles)
         # We need to store size and stride to make sure the unpacked data has the correct shape
         return TinyGemmPackedTensor(data, t.size(), t.stride())
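
Note: below is a minimal standalone sketch (toy values, not part of the
patch) of the nibble packing that the new uint8 code path performs: two
int4 values per byte, even columns in the high nibble, matching the
expression t[::, ::2] << 4 | t[::, 1::2] used above.

import torch

# Toy [n][k] tensor of int4 values (0..15) stored as int32; k must be even.
t = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.int32)

# Pack each pair of columns into one byte -> shape [n][k / 2]:
# even columns become the high nibble, odd columns the low nibble.
t_uint8 = (t[:, ::2] << 4 | t[:, 1::2]).to(torch.uint8)
# t_uint8 == [[0x12, 0x34], [0x56, 0x78]] == [[18, 52], [86, 120]]

# Splitting the nibbles back apart recovers the original layout.
high = (t_uint8 >> 4).to(torch.int32)
low = (t_uint8 & 0xF).to(torch.int32)
assert torch.equal(torch.stack([high, low], dim=-1).reshape(t.shape), t)

A note on the version gate: comparing the .release tuples means
pre-release builds such as 2.5.0.devYYYYMMDD also take the uint8 path,
since packaging's Version.release drops the dev segment; this is
presumably intentional, as the PyTorch change landed in nightlies
before the 2.5.0 release.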