diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index b1d3475601..69b0fb6e99 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -420,10 +420,9 @@ def ffn_or_attn_only(mod, fqn):
             else:
                 quantize_(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization:
+            use_hqq = False
             if "hqq" in quantization:
                 use_hqq = True
-            else:
-                use_hqq = False
             group_size = int(quantization.split("-")[1])
             assert (
                 group_size
@@ -434,7 +433,7 @@ def ffn_or_attn_only(mod, fqn):
                     256,
                 ]
             ), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
-            quantize_(model, int4_weight_only(group_size=group_size))
+            quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
         elif "int8adq-int4w-symm" in quantization:
            from torchao.dtypes import CutlassInt4PackedLayout
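
The patch fixes a silent bug: `use_hqq` was computed from the quantization string but never passed to `int4_weight_only`, so the `-hqq` suffix had no effect. Hoisting the `use_hqq = False` default above the check and threading the flag through makes the spec string behave as documented. Below is a minimal standalone sketch of the post-patch parsing flow; the helper name `parse_int4wo_spec` is hypothetical (the real code runs inline in `generate.py`), and no torchao import is needed to illustrate the control flow.

```python
# Hypothetical helper mirroring the post-patch logic in generate.py.
def parse_int4wo_spec(quantization: str):
    # Default hoisted above the check, as in the patch.
    use_hqq = False
    if "hqq" in quantization:
        # e.g. "int4wo-64-hqq" enables HQQ-based quantization.
        use_hqq = True
    # "int4wo-<group_size>[-hqq]" -> the second dash-separated field.
    group_size = int(quantization.split("-")[1])
    assert group_size in [32, 64, 128, 256], (
        f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
    )
    # Both values now reach int4_weight_only(group_size=..., use_hqq=...).
    return group_size, use_hqq

print(parse_int4wo_spec("int4wo-128"))      # -> (128, False)
print(parse_int4wo_spec("int4wo-128-hqq"))  # -> (128, True)
```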