From 3908fcdf11b456772f1645584c97f40c90bf2ba7 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Wed, 22 May 2024 09:52:48 +0000 Subject: [PATCH] refactor(qmodule): group_size can only be 128, 64 or 32 --- quanto/nn/qmodule.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/quanto/nn/qmodule.py b/quanto/nn/qmodule.py index 72608af4..95c372ad 100644 --- a/quanto/nn/qmodule.py +++ b/quanto/nn/qmodule.py @@ -123,11 +123,13 @@ def __init__( self.weight_group_size = None if self.weight_qtype in (qint2, qint4): out_features = self.weight.shape[0] - if out_features >= 128: - group_size = self.weight.numel() // out_features - while group_size > 128 and group_size % 2 == 0: - group_size = group_size // 2 - self.weight_group_size = group_size + in_features = self.weight.numel() // out_features + group_size = 128 + if in_features > group_size: + while in_features % group_size != 0 and group_size > 32: + group_size -= 32 + if in_features % group_size == 0: + self.weight_group_size = group_size self.activation_qtype = activations self.optimizer = optimizer self.register_buffer("input_scale", torch.ones(()))