From 2525f699ff2339641b064e52b8fef7d489356f98 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Tue, 25 Jun 2024 18:10:50 +0000
Subject: [PATCH] update g_idx

---
 .../quantization/gptq/utils/gptq_wrapper.py | 29 ++++++++++++++-----
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 5e8052ffe7..43fe205a97 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -14,10 +14,11 @@
 
 import time
 
+from torch.nn import Parameter
+
 from sparseml.modifiers.utils import SPARSITY_THRESHOLD
 from sparseml.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
-from torch.nn import Parameter
 
 try:
     import transformers
@@ -177,13 +178,25 @@ def fasterprune(
                         group_size = quant_scheme.weights.group_size
                         if group_size is None or group_size == -1:
                             group_size = self.layer.weight.shape[1]
-
+
                         if actorder:
-                            g_idx = torch.Tensor([perm[j] // group_size for j in range(self.columns)], dtype=torch.int32, device=invperm.device)
+                            g_idx = torch.tensor(
+                                [perm[j] // group_size for j in range(self.columns)],
+                                dtype=torch.int32,
+                                device=invperm.device,
+                            )
+
                             g_idx = g_idx[invperm]
-                            self.layer.weight_g_idx = Parameter(g_idx, requires_grad=False,)
+                            self.layer.weight_g_idx = Parameter(
+                                g_idx,
+                                requires_grad=False,
+                            )
                         else:
-                            g_idx = torch.Tensor([j // group_size for j in range(self.columns)], dtype=torch.int32, device=W.device)
+                            g_idx = torch.tensor(
+                                [j // group_size for j in range(self.columns)],
+                                dtype=torch.int32,
+                                device=W.device,
+                            )
 
                         from compressed_tensors.quantization import QuantizationStrategy
                         from compressed_tensors.quantization.lifecycle.forward import (
@@ -191,13 +204,14 @@ def fasterprune(
                         )
 
                         strategy = quant_scheme.weights.strategy
 
                         if strategy == QuantizationStrategy.TENSOR:
                             q = fake_quantize(
                                 q,
                                 scale,
                                 zero_point,
                                 self.layer.quantization_scheme.weights,
+                                g_idx,
                             )
                         elif strategy == QuantizationStrategy.CHANNEL:
                             # TODO: for channelwise why isn't this just a 1d tensor?
@@ -205,6 +219,7 @@ def fasterprune(
                                 q,
                                 scale[:, 0],
                                 zero_point[:, 0],
+                                # g_idx,
                                 quant_scheme.weights,
                             )
                         else:  # strategy == QuantizationStrategy.GROUP
@@ -222,6 +237,7 @@ def fasterprune(
                                 q,
                                 scale[:, input_dim_group],
                                 zero_point[:, input_dim_group],
+                                # g_idx,
                                 altered_qargs,
                             )
 
@@ -247,8 +263,7 @@ def fasterprune(
 
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
-
-
+
         if actorder:
             W = W[:, invperm]
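
Note (not part of the patch): below is a minimal, self-contained sketch of the
act-order g_idx bookkeeping that fasterprune stores on the layer. It follows
the construction used in the GPTQ reference implementations (GPTQ-for-LLaMa /
AutoGPTQ), where the group index list is built as j // group_size in permuted
order and then un-permuted with invperm; note this differs from the
perm[j] // group_size form in the actorder hunk above. The values of columns,
group_size, and the random Hessian diagonal are toy stand-ins, not values from
the patch.

import torch

columns, group_size = 8, 4

# Stand-in for the Hessian diagonal that drives activation ordering.
hessian_diag = torch.rand(columns)

# perm[j] = original index of the column placed at permuted slot j;
# invperm[c] = permuted slot of original column c.
perm = torch.argsort(hessian_diag, descending=True)
invperm = torch.argsort(perm)

# Quantization groups are contiguous blocks of *permuted* columns, so the
# column at permuted slot j falls in group j // group_size.
g_idx = torch.tensor(
    [j // group_size for j in range(columns)], dtype=torch.int32
)

# Re-index so entry c refers to original column c; this per-original-column
# layout is what gets stored on the layer (weight_g_idx).
g_idx = g_idx[invperm]

# Sanity check: original column c sits at permuted slot invperm[c], so its
# group is invperm[c] // group_size.
assert torch.equal(g_idx, (invperm // group_size).to(torch.int32))

Stored this way, g_idx lets group-quantization kernels look up the correct
scale and zero point for each original column without physically permuting
the weight matrix at inference time.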