Skip to content

Commit

Permalink
refactor(qmodule): group_size can only be 128, 64 or 32
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed May 23, 2024
1 parent 6c1edd3 commit 3908fcd
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions quanto/nn/qmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,13 @@ def __init__(
self.weight_group_size = None
if self.weight_qtype in (qint2, qint4):
out_features = self.weight.shape[0]
if out_features >= 128:
group_size = self.weight.numel() // out_features
while group_size > 128 and group_size % 2 == 0:
group_size = group_size // 2
self.weight_group_size = group_size
in_features = self.weight.numel() // out_features
group_size = 128
if in_features > group_size:
while in_features % group_size != 0 and group_size > 32:
group_size -= 32
if in_features % group_size == 0:
self.weight_group_size = group_size
self.activation_qtype = activations
self.optimizer = optimizer
self.register_buffer("input_scale", torch.ones(()))
Expand Down

0 comments on commit 3908fcd

Please sign in to comment.