From 9a69ca3cc4d0cbf1896a0c86ecc91a59cb74777a Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 3 Oct 2024 07:23:41 +0100
Subject: [PATCH] fix gptq

---
 src/brevitas/graph/gpxq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index 75992e8fe..d8b436fc1 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -273,7 +273,7 @@ def get_quant_weights(self, i, i1, permutation_list):
                 index = permutation_list[0][i]
                 q = self.layer.quant_weight(quant_input=self.quant_metadata).value.unsqueeze(
                     0)  # [1, OC, 1]
-                q = q[:, :, i:i + 1]  # [groups, OC/groups, 1]
+                q = q[:, :, index:index + 1]  # [groups, OC/groups, 1]
             else:
                 index = permutation_list[0][i]
                 subtensor_slice_list = [None, (index, index + 1)]