From b08e0acb39f3bb2b7b021efce1e9f4de4af2166a Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 27 Sep 2024 09:46:06 +0100
Subject: [PATCH 1/5] Fix (llm): small fixes to LLM

---
 src/brevitas/core/function_wrapper/shape.py         | 5 +++++
 src/brevitas/core/stats/stats_op.py                 | 6 +++++-
 src/brevitas_examples/common/generative/quantize.py | 9 ++++++---
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py
index f1dfc7796..e175e4445 100644
--- a/src/brevitas/core/function_wrapper/shape.py
+++ b/src/brevitas/core/function_wrapper/shape.py
@@ -165,6 +165,11 @@ def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None:
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
+        # This case is a bit tricky, but we can end up here in two ways:
+        # - quantizing the zero point, which already has the expanded shape matching the scale (unpadded, but no padding is needed there)
+        # - groupwise HQO quantization, where the weight has already been padded and expanded
+        if len(x.shape) == len(self.expanded_groupwise_shape):
+            return x
         y = torch.nn.functional.pad(
             x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.)
         y = y.view(self.expanded_groupwise_shape)

diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py
index 461aeb3e6..29d4d06e8 100644
--- a/src/brevitas/core/stats/stats_op.py
+++ b/src/brevitas/core/stats/stats_op.py
@@ -692,7 +692,8 @@ def parameter_search(self, xl, x):
             self.set_local_loss_mode(False)
             qt_value = self.input_view_shape_impl(quant_tensor.value)
             qt_scale = self.input_view_shape_impl(quant_tensor.scale)
-            qt_int = self.input_view_shape_impl(quant_tensor.int())
+            qt_zp = self.input_view_shape_impl(quant_tensor.zero_point)
+            qt_int = qt_value / qt_scale + qt_zp
             loss = torch.abs(qt_value - x).mean()
             best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
             if loss >= best_loss:
@@ -700,6 +701,9 @@
             best_loss = torch.min(loss, best_loss)
 
             W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)
+            # Compared to the original formulation, the value we are looking for is:
+            # - scaled by qt_scale
+            # - of opposite sign
             val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)
 
             if self.stats_reduce_dim is None:

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 10f7ce259..f74a91933 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -98,13 +98,16 @@
                 'sym': Int8WeightPerChannelFixedPoint},
             'per_group': {
                 'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}},
+        'hqo': {
+            'per_group': {
+                'asym': MXHQO}},
         'mse': {
             'per_tensor': {
                 'sym': Int8WeightPerTensorFixedPointMSE},
             'per_channel': {
-                'sym': Int8WeightPerChannelFixedPointMSE}},
-            'per_group': {
-                'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}},
+                'sym': Int8WeightPerChannelFixedPointMSE},
+            'per_group': {
+                'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}}},
     'float': {
         'float_scale': {
             'stats': {
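The early return in the groupwise view above is needed because the reshape normally pads the tensor along `group_dim` to a multiple of `group_size` and then views it into `expanded_groupwise_shape`; a quantized zero point (already shaped like the scale) or an HQO weight (already padded and expanded) must pass through untouched. Below is a minimal standalone sketch of that logic — the Brevitas-internal `padding` helper and class plumbing are replaced by an inline computation, so this is an illustration rather than the actual implementation:

```python
import torch

def groupwise_view(x: torch.Tensor, group_size: int, group_dim: int,
                   expanded_groupwise_shape: tuple) -> torch.Tensor:
    # Early exit: input already has the expanded groupwise rank, e.g. a
    # quantized zero point shaped like the scale, or an HQO weight that was
    # padded and expanded earlier in the pipeline.
    if len(x.shape) == len(expanded_groupwise_shape):
        return x
    # Pad group_dim up to the next multiple of group_size (stand-in for the
    # `padding` helper used in shape.py).
    pad_len = (group_size - x.shape[group_dim] % group_size) % group_size
    pad = [0] * (2 * x.dim())
    # F.pad consumes the padding list from the last dimension backwards.
    pad[2 * (x.dim() - 1 - group_dim) + 1] = pad_len
    y = torch.nn.functional.pad(x, pad, mode='constant', value=0.)
    return y.view(expanded_groupwise_shape)

# A [64, 100] weight with group_size=32 on dim 1 pads to [64, 128] and views
# into [64, 4, 32]; calling groupwise_view again on the result is a no-op.
w = torch.randn(64, 100)
g = groupwise_view(w, 32, 1, (64, 4, 32))
assert g.shape == (64, 4, 32) and groupwise_view(g, 32, 1, (64, 4, 32)) is g
```

The `stats_op.py` change follows the same quantization identity in reverse: since `value = (int - zero_point) * scale`, the integer representation can be recovered as `value / scale + zero_point` on the already-viewed tensors, instead of calling `quant_tensor.int()` on the unviewed one.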
From fd51bb898a8793aec90038a032a626ae20ec076e Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 1 Oct 2024 12:17:58 +0100
Subject: [PATCH 2/5] Groupwise MSE support

---
 src/brevitas/quant/base.py                | 53 +++++++++++++------
 .../quant/experimental/mx_quant_ocp.py     |  4 +-
 src/brevitas/quant/shifted_scaled_int.py   |  6 +--
 .../common/generative/quantize.py          |  3 ---
 4 files changed, 41 insertions(+), 25 deletions(-)

diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py
index 92f41b990..e1d118239 100644
--- a/src/brevitas/quant/base.py
+++ b/src/brevitas/quant/base.py
@@ -429,21 +429,7 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
     pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
 
 
-class MSESubInjectorBase(ExtendedInjector):
-
-    @value
-    def inner_stats_input_view_shape_impl(scaling_per_output):
-        if scaling_per_output == ScalingPerOutputType.CHANNEL:
-            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
-        elif scaling_per_output == ScalingPerOutputType.TENSOR:
-            return StatsInputViewShapeImpl.OVER_TENSOR
-        elif scaling_per_output == ScalingPerOutputType.GROUP:
-            raise RuntimeError("Not implemented yet")
-
-    permute_dims = (this << 1).permute_dims
-
-
-class MSESymmetricScaleSubInjector(MSESubInjectorBase):
+class MSESymmetricScaleSubInjector(ExtendedInjector):
     scaling_per_output = (this << 1).scaling_per_output
     proxy_module = (this << 1).proxy_module
     mse_init_op = AbsMax
@@ -451,9 +437,11 @@ class MSESymmetricScaleSubInjector(MSESubInjectorBase):
     stats_reduce_dim = (this << 1).stats_reduce_dim
     device = (this << 1).device
     type = (this << 1).type
+    permute_dims = (this << 1).permute_dims
+    inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl
 
 
-class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
+class MSEAsymmetricScaleSubInjector(ExtendedInjector):
     scaling_per_output = (this << 1).scaling_per_output
     proxy_module = (this << 1).proxy_module
     mse_init_op = AbsMinMax
@@ -461,9 +449,11 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
     stats_reduce_dim = (this << 1).stats_reduce_dim
     device = (this << 1).device
     dtype = (this << 1).dtype
+    permute_dims = (this << 1).permute_dims
+    inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl
 
 
-class MSEZeroPointSubInjector(MSESubInjectorBase):
+class MSEZeroPointSubInjector(ExtendedInjector):
     # zp is per channel when scaling is per channel
     scaling_per_output = (this << 1).scaling_per_output
     proxy_module = (this << 1).proxy_module
@@ -473,6 +463,8 @@ class MSEZeroPointSubInjector(MSESubInjectorBase):
     stats_reduce_dim = (this << 1).stats_reduce_dim
     device = (this << 1).device
     dtype = (this << 1).dtype
+    permute_dims = (this << 1).permute_dims
+    inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl
 
 
 class MSEAsymmetricScale(ExtendedInjector):
@@ -484,6 +476,15 @@ class MSEAsymmetricScale(ExtendedInjector):
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
     scaling_stats_input_view_shape_impl = nn.Identity()
 
+    @value
+    def inner_stats_input_view_shape_impl(scaling_per_output):
+        if scaling_per_output == ScalingPerOutputType.CHANNEL:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+        elif scaling_per_output == ScalingPerOutputType.TENSOR:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+        elif scaling_per_output == ScalingPerOutputType.GROUP:
+            return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
+
     @value
     def scaling_stats_impl():
         return this.mse_scale.stats_impl
@@ -498,6 +499,15 @@ class MSESymmetricScale(ExtendedInjector):
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
     scaling_stats_input_view_shape_impl = nn.Identity()
 
+    @value
+    def inner_stats_input_view_shape_impl(scaling_per_output):
+        if scaling_per_output == ScalingPerOutputType.CHANNEL:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+        elif scaling_per_output == ScalingPerOutputType.TENSOR:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+        elif scaling_per_output == ScalingPerOutputType.GROUP:
+            return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
+
     @value
     def scaling_stats_impl():
         return this.mse_scale.stats_impl
@@ -511,6 +521,15 @@ class MSEZeroPoint(ExtendedInjector):
     mse_zero_point = MSEZeroPointSubInjector
     zero_point_stats_input_view_shape_impl = nn.Identity()
 
+    @value
+    def inner_stats_input_view_shape_impl(scaling_per_output):
+        if scaling_per_output == ScalingPerOutputType.CHANNEL:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+        elif scaling_per_output == ScalingPerOutputType.TENSOR:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+        elif scaling_per_output == ScalingPerOutputType.GROUP:
+            return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
+
     @value
     def zero_point_stats_impl():
         return this.mse_zero_point.stats_impl

diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py
index 6e61d078c..2299c1783 100644
--- a/src/brevitas/quant/experimental/mx_quant_ocp.py
+++ b/src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -123,14 +123,14 @@ class MXInt8Act(MXActMixin, GroupwiseActProxyMixin, IntQuant, MaxStatsScaling, A
     bit_width = 8
 
 
-class MXInt8WeightMSE(MXInt8Weight, MSESymmetricScale):
+class MXInt8WeightMSE(MSESymmetricScale, MXInt8Weight):
     """
     MX Int signed weight quantizer with per-channel MSE-based scaling.
     """
     pass
 
 
-class ShiftedMXUInt8WeightMSE(ShiftedMXUInt8Weight, MSEAsymmetricScale):
+class ShiftedMXUInt8WeightMSE(MSEAsymmetricScale, ShiftedMXUInt8Weight):
     """
     MX Int unsigned weight quantizer with per-channel MSE-based scaling.
     """
     pass

diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py
index d18150a10..d52df53b1 100644
--- a/src/brevitas/quant/shifted_scaled_int.py
+++ b/src/brevitas/quant/shifted_scaled_int.py
@@ -5,7 +5,7 @@
 from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
 from brevitas.quant.base import *
 from brevitas.quant.base import HQOActZeroPoint
-from brevitas.quant.base import HQOZeroPoint
+from brevitas.quant.base import HQOWeightZeroPoint
 from brevitas.quant.solver.act import ActQuantSolver
 from brevitas.quant.solver.weight import WeightQuantSolver
 
@@ -145,7 +145,7 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale,
     pass
 
 
-class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat):
+class ShiftedUint8WeightPerTensorFloatHQO(HQOWeightZeroPoint, ShiftedUint8WeightPerTensorFloat):
     """
     8-bit per-tensor unsigned int weight quantizer with floating-point per-tensor scale factor
     and integer zero point. Zero-point is initialized from HQO local loss.
@@ -157,7 +157,7 @@ class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTen
     quantize_zero_point = False
 
 
-class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat):
+class ShiftedUint8WeightPerChannelFloatHQO(HQOWeightZeroPoint, ShiftedUint8WeightPerChannelFloat):
     """
     8-bit per-channel unsigned int weight quantizer with floating-point per-channel scale factor
     and integer zero point. Zero-point is initialized from HQO local loss.

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index f74a91933..5259c6776 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -98,9 +98,6 @@
                 'sym': Int8WeightPerChannelFixedPoint},
             'per_group': {
                 'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}},
-        'hqo': {
-            'per_group': {
-                'asym': MXHQO}},
         'mse': {
             'per_tensor': {
                 'sym': Int8WeightPerTensorFixedPointMSE},
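Two details of this patch are worth unpacking. First, the shared `MSESubInjectorBase` is gone: each sub-injector now pulls `permute_dims` and `inner_stats_input_view_shape_impl` from its parent, and each user-facing MSE injector defines the `@value` resolver itself, returning `OVER_SUBCHANNEL_BLOCK` for groupwise scaling instead of raising. Second, the MSE mixins moved to the front of the base-class lists of `MXInt8WeightMSE` and `ShiftedMXUInt8WeightMSE`. With injector-style classes, attribute lookup follows Python's MRO, so the leftmost base wins; a rough sketch with plain classes (stand-ins, not the Brevitas API):

```python
# Stand-in classes: plain attributes instead of Brevitas injectors.
class MXInt8Weight:
    scaling_impl_type = 'stats'  # max-based scaling by default

class MSESymmetricScale:
    scaling_impl_type = 'parameter_from_stats'  # MSE-searched scale

class WrongOrder(MXInt8Weight, MSESymmetricScale):
    pass  # MRO resolves MXInt8Weight first -> MSE override silently lost

class RightOrder(MSESymmetricScale, MXInt8Weight):
    pass  # MRO resolves the MSE mixin first -> its attributes win

print(WrongOrder.scaling_impl_type)  # stats
print(RightOrder.scaling_impl_type)  # parameter_from_stats
```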
From 6297b173b73ae98687a42d1a09dfde342880f6eb Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 1 Oct 2024 12:25:41 +0100
Subject: [PATCH 3/5] New groupdim options for LLM

---
 src/brevitas_examples/common/generative/quantize.py | 5 +++++
 src/brevitas_examples/llm/main.py                   | 7 +++++++
 2 files changed, 12 insertions(+)

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 5259c6776..15b47884f 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -210,6 +210,7 @@ def generate_quantizers(
         weight_group_size,
         quantize_weight_zero_point,
         weight_quant_format='int',
+        weight_group_dim=None,
         input_bit_width=None,
         input_quant_format='',
         input_scale_precision=None,
@@ -276,6 +277,10 @@ def generate_quantizers(
                 'narrow_range': False,
                 'quantize_zero_point': quantize_weight_zero_point},
             **weight_float_format)
+
+    if weight_group_dim is not None:
+        weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})
+
     if dtype == torch.float16:
         weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
     if weight_kwargs is not None:

diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e19390774..67c8144c7 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -226,6 +226,7 @@ def main(args):
         weight_quant_type=args.weight_quant_type,
         weight_quant_granularity=args.weight_quant_granularity,
         weight_group_size=args.weight_group_size,
+        weight_group_dim=args.weight_group_dim,
         quantize_weight_zero_point=args.quantize_weight_zero_point,
         weight_quant_format=args.weight_quant_format,
         input_bit_width=args.input_bit_width,
@@ -358,6 +359,12 @@ def parse_args(args):
         default='per_group',
         choices=['per_channel', 'per_tensor', 'per_group'],
         help='Granularity for scales/zero-point of weights. Default: per_group.')
+    parser.add_argument(
+        '--weight-group-dim',
+        type=int,
+        default=None,
+        choices=[1, 0],
+        help='Override default group_dim for groupwise quantization. Default: layer-dependent')
     parser.add_argument(
         '--weight-group-size',
         type=int,
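The new `--weight-group-dim` flag forwards to the quantizer's `group_dim` and picks which weight axis is sliced into groups. For a `nn.Linear` weight of shape `[out_features, in_features]`, `group_dim=1` (the usual default) groups consecutive input channels, while `group_dim=0` groups consecutive output channels. A rough sketch of the two layouts, simplified from the groupwise view shown in patch 1:

```python
import torch

w = torch.randn(8, 16)  # [out_features, in_features]
group_size = 4

# group_dim=1: slice the input dimension; a per-group scale then
# broadcasts over shape [8, 4, 1].
g1 = w.view(8, 16 // group_size, group_size)  # [8, 4, 4]

# group_dim=0: slice the output dimension; a per-group scale then
# broadcasts over shape [2, 1, 16].
g0 = w.view(8 // group_size, group_size, 16)  # [2, 4, 16]

print(g1.shape, g0.shape)
```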
From 7c35004a8b02bb4648793fc427be87484ee3a0f6 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 8 Oct 2024 13:51:48 +0100
Subject: [PATCH 4/5] Fix groupwise param_from_stats

---
 src/brevitas_examples/common/generative/quantize.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 15b47884f..9460fadf1 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -290,9 +290,8 @@ def generate_quantizers(
     if weight_quant_granularity == 'per_group':
         weight_quant = weight_quant.let(**{'group_size': weight_group_size})
     # weight scale is converted to a standalone parameter
-    # This is done already by default in the per_group quantizer
-    if weight_quant_granularity != 'per_group':
-        weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')
+
+    weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')
     # weight zero-point is converted to a standalone parameter
     # This is done already by default in the per_group quantizer
     if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group':
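With this fix, `scaling_impl_type='parameter_from_stats'` is applied regardless of granularity, so the per_group path also gets a scale that is initialized from weight statistics and then registered as a standalone learnable parameter. A rough sketch of the idea — not the Brevitas class, assuming symmetric int8 for simplicity:

```python
import torch
from torch import nn

class ParameterFromStatsScale(nn.Module):
    """Scale initialized from an absmax statistic, then trained as a parameter."""

    def __init__(self, w: torch.Tensor):
        super().__init__()
        init_scale = w.abs().amax() / 127.  # stats-based initialization
        self.scale = nn.Parameter(init_scale)  # standalone parameter from here on

    def forward(self) -> torch.Tensor:
        return self.scale

w = torch.randn(64, 64)
scale_impl = ParameterFromStatsScale(w)
w_q = torch.clamp(torch.round(w / scale_impl()), -128, 127) * scale_impl()
```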
From 84dae4e1aadaa96592d421201ba5edee240f5276 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 8 Oct 2024 13:56:03 +0100
Subject: [PATCH 5/5] README

---
 src/brevitas_examples/llm/README.md | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index cdf708d17..06effc7b9 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -17,11 +17,12 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Curren
 usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
                [--weight-bit-width WEIGHT_BIT_WIDTH]
                [--weight-param-method {stats,mse,hqo}]
                [--weight-scale-precision {float_scale,po2_scale}]
                [--weight-quant-type {sym,asym}]
                [--weight-quant-format WEIGHT_QUANT_FORMAT]
                [--weight-quant-granularity {per_channel,per_tensor,per_group}]
+               [--weight-group-dim {1,0}]
                [--weight-group-size WEIGHT_GROUP_SIZE]
                [--quantize-weight-zero-point]
                [--input-bit-width INPUT_BIT_WIDTH]
                [--input-quant-format INPUT_QUANT_FORMAT]
                [--input-param-method {stats,mse}]
                [--input-scale-precision {float_scale,po2_scale}]
                [--input-scale-type {static,dynamic,no_scale}]
                [--input-quant-type {sym,asym}]
                [--input-quant-granularity {per_tensor,per_row,per_group}]
                [--input-group-size INPUT_GROUP_SIZE]
                [--quantize-input-zero-point] [--quantize-last-layer] [--gptq]
                [--act-calibration] [--bias-corr] [--ln-affine-merge]
                [--no-quantize] [--no-float16] [--replace-mha]
                [--weight-equalization]
                [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
                [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
+               [--export-prefix EXPORT_PREFIX]
                [--checkpoint-name CHECKPOINT_NAME]
 
 options:
   -h, --help            show this help message and exit
   --model MODEL         HF model name. Default: facebook/opt-125m.
   --seed SEED           Seed for sampling the calibration data. Default: 0.
   --nsamples NSAMPLES   Number of calibration data samples. Default: 128.
   --seqlen SEQLEN       Sequence length. Default: 2048.
   --eval                Eval model PPL on the chosen Dataset.
   --dataset {wikitext2,c4}
                         Dataset to use for quantization (default: wikitext2)
   --weight-bit-width WEIGHT_BIT_WIDTH
                         Weight bit width. Default: 8.
   --weight-param-method {stats,mse,hqo}
                         How scales/zero-point are determined. Default: stats.
   --weight-scale-precision {float_scale,po2_scale}
                         Whether scale is a float value or a po2. Default: po2.
   --weight-quant-type {sym,asym}
                         Weight quantization type. Default: asym.
   --weight-quant-format WEIGHT_QUANT_FORMAT
                         Weight quantization type. Either int or eXmY, with
                         X+Y==weight_bit_width-1. It's possible to add
                         float_ocp_ or float_fnuz_ before the exponent/mantissa
                         bitwidth. Default: int.
   --weight-quant-granularity {per_channel,per_tensor,per_group}
                         Granularity for scales/zero-point of weights.
                         Default: per_group.
+  --weight-group-dim {1,0}
+                        Override default group_dim for groupwise quantization.
+                        Default: layer-dependent
   --weight-group-size WEIGHT_GROUP_SIZE
                         Group size for per_group weight quantization.
                         Default: 128.
...
   --load-awq LOAD_AWQ   Load the awq search results.
   --export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}
                         Model export.
+  --export-prefix EXPORT_PREFIX
+                        Path prefix to use for the various export flows. If
+                        None, a path will be derived from the model name
+                        (default: None)
   --checkpoint-name CHECKPOINT_NAME
                         Filename to save checkpoint. If `None`, no checkpoint
                         is saved (default: None)
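Putting the series together, the sketch below shows one plausible way to drive the new options programmatically instead of through `main.py`. The keyword names match the diffs above; parameters not shown in this series (such as `dtype` and `weight_bit_width`) and the shape of the return value are assumptions:

```python
import torch

from brevitas_examples.common.generative.quantize import generate_quantizers

# per_group + mse + po2_scale resolves to MXInt8WeightMSE (patch 2), with the
# group dimension overridden to 0 (patch 3) and the scale converted to a
# standalone parameter initialized from stats (patch 4).
quantizers = generate_quantizers(
    dtype=torch.float16,                # assumed parameter name
    weight_bit_width=8,                 # assumed parameter name
    weight_param_method='mse',
    weight_scale_precision='po2_scale',
    weight_quant_type='sym',
    weight_quant_granularity='per_group',
    weight_group_size=128,
    weight_group_dim=0,                 # new in this series
    quantize_weight_zero_point=False,
    weight_quant_format='int')
```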