From ce81191f2d76f794327d11f6e0de9c0d39ccc41e Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Wed, 13 Nov 2024 18:37:00 +0000
Subject: [PATCH] Fix uninitialized variable in quantized compressors

Both compressors have a can_quantize() check; if that check ever fails,
compress_weight() raises:

> UnboundLocalError: cannot access local variable 'quantized_weight'
> where it is not associated with a value

Add the obvious fix for this, along with highly artificial test cases
that trigger it.
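
As a minimal standalone sketch of the failure mode (the helpers below
are simplified stand-ins for can_quantize() and quantize(), not the
library's real API):

    import torch

    def can_quantize(weight: torch.Tensor) -> bool:
        # Stand-in: an int8 tensor is treated as already quantized
        return weight.dtype != torch.int8

    def compress_weight(weight: torch.Tensor) -> dict:
        if can_quantize(weight):
            quantized_weight = weight.to(torch.int8)  # stand-in for quantize()
        else:
            quantized_weight = weight  # the fix: pass through unchanged
        # Before the fix there was no else branch, so this return raised
        # UnboundLocalError whenever can_quantize() returned False
        return {"weight": quantized_weight}

    compress_weight(torch.zeros(4, dtype=torch.int8))

Passing the weight through unchanged is the right fallback: the new
int8 test cases (scale 1, zero point 0) model a weight that is already
in quantized form, so there is nothing left to quantize.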
---
 .../quantized_compressors/naive_quantized.py  |  6 ++--
 .../quantized_compressors/pack_quantized.py   |  2 ++
 .../quantized_compressors/test_int_quant.py   | 34 +++++++++++++++----
 .../quantized_compressors/test_pack_quant.py  | 28 ++++++++++++---
 4 files changed, 57 insertions(+), 13 deletions(-)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
index 0267aca4..85eebe00 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
@@ -93,9 +93,11 @@ def compress_weight(
                 args=quantization_args,
                 dtype=quantization_args.pytorch_dtype(),
             )
+        else:
+            quantized_weight = weight
 
-            if device is not None:
-                quantized_weight = quantized_weight.to(device)
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)
 
         return {"weight": quantized_weight}
 
diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
index ce9f0a57..c236f8c9 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -94,6 +94,8 @@ def compress_weight(
                 args=quantization_args,
                 dtype=torch.int8,
             )
+        else:
+            quantized_weight = weight
 
         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
         weight_shape = torch.tensor(weight.shape)
diff --git a/tests/test_compressors/quantized_compressors/test_int_quant.py b/tests/test_compressors/quantized_compressors/test_int_quant.py
index e4921508..e106a372 100644
--- a/tests/test_compressors/quantized_compressors/test_int_quant.py
+++ b/tests/test_compressors/quantized_compressors/test_int_quant.py
@@ -91,32 +91,54 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
 
 
 @pytest.mark.parametrize(
-    "strategy,group_size,sc,zp",
+    "strategy,group_size,sc,zp,int8_weights",
     [
-        [QuantizationStrategy.TENSOR, None, 0.01, 0],
+        [QuantizationStrategy.TENSOR, None, 0.01, 0, False],
+        [QuantizationStrategy.TENSOR, None, 1, 0, True],
         [
             QuantizationStrategy.GROUP,
             128,
             torch.rand((300, 8)) * 0.01,
             torch.zeros((300, 8), dtype=torch.int8),
+            False,
         ],
         [
             QuantizationStrategy.CHANNEL,
             None,
             torch.rand((300, 1)) * 0.01,
             torch.zeros((300, 1), dtype=torch.int8),
+            False,
         ],
     ],
 )
-def test_reload_match(strategy, group_size, sc, zp, tmp_path):
+def test_reload_match(strategy, group_size, sc, zp, int8_weights, tmp_path):
     dense_state_dict = {
         "dummy.weight": torch.rand((300, 1024)),
         "dummy.weight_scale": torch.tensor(sc, dtype=torch.float32),
         "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
-        "dummy2.weight": torch.rand((300, 1024)),
-        "dummy2.weight_scale": torch.tensor(sc, dtype=torch.float32),
-        "dummy2.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
     }
+    if not int8_weights:
+        dense_state_dict.update(
+            {
+                "dummy2.weight": torch.rand((300, 1024)),
+                "dummy2.weight_scale": torch.tensor(sc, dtype=torch.float32),
+                "dummy2.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
+            }
+        )
+    else:
+        dense_state_dict.update(
+            {
+                "dummy2.weight": torch.randint(
+                    torch.iinfo(torch.int8).min,
+                    torch.iinfo(torch.int8).max,
+                    (511, 350),
+                    dtype=torch.int8,
+                ),
+                "dummy2.weight_scale": torch.tensor(sc, dtype=torch.float32),
+                "dummy2.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
+            }
+        )
 
     quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
     compressor = IntQuantizationCompressor(config=quant_config)
diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py
index fde57c4b..efaef66f 100644
--- a/tests/test_compressors/quantized_compressors/test_pack_quant.py
+++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py
@@ -140,16 +140,34 @@ def test_repack_8bit(value):
     assert torch.equal(value, unpacked)
 
 
-@pytest.mark.parametrize("num_bits", [4, 8])
-def test_reload_match(tmp_path, num_bits):
+@pytest.mark.parametrize("num_bits,int8_weights", [(4, False), (8, False), (8, True)])
+def test_reload_match(tmp_path, num_bits, int8_weights):
     dense_state_dict = {
         "dummy.weight": torch.rand((511, 350)),
         "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
         "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int8),
-        "dummy2.weight": torch.rand((128, 280)),
-        "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
-        "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int8),
     }
+    if not int8_weights:
+        dense_state_dict.update(
+            {
+                "dummy2.weight": torch.rand((128, 280)),
+                "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
+                "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int8),
+            }
+        )
+    else:
+        dense_state_dict.update(
+            {
+                "dummy2.weight": torch.randint(
+                    torch.iinfo(torch.int8).min,
+                    torch.iinfo(torch.int8).max,
+                    (511, 350),
+                    dtype=torch.int8,
+                ),
+                "dummy2.weight_scale": torch.tensor(1, dtype=torch.float32),
+                "dummy2.weight_zero_point": torch.tensor(0, dtype=torch.int8),
+            }
+        )
 
     names_to_scheme = {
         "dummy": QuantizationArgs(num_bits=num_bits),