From d653001bb0987c53031b828e8ef40e9593482e39 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 28 Aug 2024 13:12:57 +0100 Subject: [PATCH] PTQ update --- .../common/generative/quant_blocks.py | 5 +- .../benchmark/ptq_benchmark_torchvision.py | 4 + .../imagenet_classification/ptq/ptq_common.py | 91 ++++++++++++++----- .../ptq/ptq_evaluate.py | 16 ++-- tests/brevitas/graph/equalization_fixtures.py | 1 + 5 files changed, 82 insertions(+), 35 deletions(-) diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py index 18149578d..696340a2c 100644 --- a/src/brevitas_examples/common/generative/quant_blocks.py +++ b/src/brevitas_examples/common/generative/quant_blocks.py @@ -3,15 +3,12 @@ # SPDX-License-Identifier: BSD-3-Clause """ -from typing import Callable, List, Optional, Tuple +from typing import Callable import torch from torch import Tensor import torch.nn as nn -import brevitas -from brevitas.core.function_wrapper.shape import PermuteDims -from brevitas.core.utils import SliceTensor from brevitas.core.zero_point import _ScaleShiftZeroPoint from brevitas.function.ops_ste import abs_binary_sign_grad diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 668eee22c..69a5f626a 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -89,7 +89,9 @@ def unique(sequence): 'act_bit_width': [8], # Act bit width 'bias_bit_width': [32], # Bias Bit-Width for Po2 scale 'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel + 'act_quant_granularity': ['per_tensor'], # Scaling Per Output Channel 'act_quant_type': ['sym'], # Act Quant Type + 'act_scale_computation_type': ['static'], # Act Quant Type 'act_param_method': ['stats'], # Act Param Method 'weight_param_method': ['mse'], # Weight Quant Type 'bias_corr': [True], # Bias Correction @@ -240,7 +242,9 @@ def ptq_torchvision_models(args): weight_param_method=config_namespace.weight_param_method, act_param_method=config_namespace.act_param_method, bias_bit_width=config_namespace.bias_bit_width, + act_scale_computation_type=config_namespace.act_scale_computation_type, weight_quant_granularity=config_namespace.weight_quant_granularity, + act_quant_granularity=config_namespace.act_quant_granularity, act_quant_percentile=config_namespace.act_quant_percentile, act_quant_type=config_namespace.act_quant_type, scale_factor_type=config_namespace.scale_factor_type, diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 3c6b82243..bac596be5 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -29,6 +29,20 @@ from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloatMSE +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloatMSE +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloatMSE +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight +from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE +from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8Weight +from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8WeightMSE from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint @@ -96,12 +110,16 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'per_tensor': { 'sym': Int8WeightPerTensorFixedPoint}, 'per_channel': { - 'sym': Int8WeightPerChannelFixedPoint},}, + 'sym': Int8WeightPerChannelFixedPoint}, + 'per_group': { + 'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}}, 'mse': { 'per_tensor': { 'sym': Int8WeightPerTensorFixedPointMSE}, 'per_channel': { - 'sym': Int8WeightPerChannelFixedPointMSE}},}}, + 'sym': Int8WeightPerChannelFixedPointMSE}, + 'per_group': { + 'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}},}}, 'float': { 'float_scale': { 'stats': { @@ -113,7 +131,26 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'per_tensor': { 'sym': Fp8e4m3WeightPerTensorFloatMSE}, 'per_channel': { - 'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}} + 'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}, + 'float_ocp': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFloat}}, + 'mse': { + 'per_tensor': { + 'sym': Fp8e4m3OCPWeightPerTensorFloatMSE}, + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}}, + 'po2_scale': { + 'stats': { + 'per_group': { + 'sym': MXFloat8e4m3Weight}}, + 'mse': { + 'per_group': { + 'sym': MXFloat8e4m3WeightMSE}}}}} INPUT_QUANT_MAP = { 'int': { @@ -139,7 +176,10 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'stats': { 'per_tensor': { 'sym': CNNInt8DynamicActPerTensorFloat, - 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}}, + 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}, + 'po2_scale': { + 'stats': { + 'per_group': MXInt8Act}}}}, 'float': { 'static': { 'float_scale': { @@ -148,7 +188,21 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'sym': Fp8e4m3ActPerTensorFloat}}, 'mse': { 'per_tensor': { - 'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}} + 'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}, + 'float_ocp': { + 'static': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPActPerTensorFloat}}, + 'mse': { + 'per_tensor': { + 'sym': Fp8e4m3OCPActPerTensorFloatMSE}}}}, + 'dynamic': { + 'po2_scale': { + 'stats': { + 'per_group': { + 'sym': MXFloat8e4m3Act}}}}}} def quantize_model( @@ -252,14 +306,14 @@ def layerwise_bit_width_fn_weight(module): weight_bit_width_dict['weight_bit_width'] = weight_bit_width act_bit_width_dict['act_bit_width'] = act_bit_width - if quant_format == 'float' and backend == 'layerwise': + if 'float' in quant_format and backend == 'layerwise': weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent - elif quant_format == 'float' and backend != 'layerwise': + elif 'float' in quant_format and backend != 'layerwise': weight_bit_width_dict['weight_bit_width'] = weight_bit_width act_bit_width_dict['act_bit_width'] = act_bit_width weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width @@ -334,12 +388,12 @@ def kwargs_prefix(prefix, weight_kwargs): return {prefix + k: v for k, v in weight_kwargs.items()} weight_bit_width_dict = {'bit_width': weight_bit_width} - if weight_quant_format == 'float': + if 'float' in weight_quant_format: weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width act_bit_width_dict = {'bit_width': act_bit_width} - if act_quant_format == 'float': + if 'float' in act_quant_format: act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width @@ -355,16 +409,12 @@ def kwargs_prefix(prefix, weight_kwargs): # Some activations in MHA should always be symmetric sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][ act_scale_type][act_param_method][act_quant_granularity]['sym'] - # Linear layers with 2d input should always be per tensor - per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][ - act_scale_type][act_param_method]['per_tensor'][act_quant_type] + act_quant = act_quant.let(**act_bit_width_dict) sym_act_quant = sym_act_quant.let(**act_bit_width_dict) - per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict) else: act_quant = None sym_act_quant = None - per_tensor_act_quant = None # Modify the weight quantizer based on the arguments passed in weight_quant = weight_quant.let( @@ -383,13 +433,6 @@ def kwargs_prefix(prefix, weight_kwargs): sym_act_quant = sym_act_quant.let( **{ 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) - if per_tensor_act_quant is not None: - per_tensor_act_quant = per_tensor_act_quant.let( - **{ - 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device}) - if act_quant_type == 'asym' and act_quant_percentile is not None: - per_tensor_act_quant = per_tensor_act_quant.let( - **{'low_percentile_q': 100 - act_quant_percentile}) weight_quant_dict = {'weight_quant': weight_quant} @@ -431,9 +474,9 @@ def kwargs_prefix(prefix, weight_kwargs): unsigned_quant_act_kwargs['signed'] = False # Layerwise is basic quant kwargs + input_quant - layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant} + layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': act_quant} - layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant} + layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': act_quant} quant_layer_map = { torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs), @@ -526,7 +569,7 @@ def apply_gptq(calib_loader, model, act_order=False): dtype = next(model.parameters()).dtype device = next(model.parameters()).device with torch.no_grad(): - with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq: + with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq: gptq_model = gptq.model for i in tqdm(range(gptq.num_layers)): for i, (images, target) in enumerate(calib_loader): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 3a9bb29fa..c960a89e6 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -120,7 +120,12 @@ def parse_type(v, default_type): parser.add_argument( '--weight-quant-granularity', default='per_tensor', - choices=['per_tensor', 'per_channel'], + choices=['per_tensor', 'per_channel', 'per_group'], + help='Weight quantization type (default: per_tensor)') +parser.add_argument( + '--act-quant-granularity', + default='per_tensor', + choices=['per_tensor', 'per_group'], help='Activation quantization type (default: per_tensor)') parser.add_argument( '--weight-quant-calibration-type', @@ -168,11 +173,7 @@ def parse_type(v, default_type): '--export-torch-qcdq', action='store_true', help='If true, export the model in torch qcdq format') -add_bool_arg( - parser, - 'scaling-per-output-channel', - default=True, - help='Weight scaling per output channel (default: enabled)') + add_bool_arg( parser, 'bias-corr', default=True, help='Bias correction after calibration (default: enabled)') add_bool_arg( @@ -189,7 +190,7 @@ def parse_type(v, default_type): parser.add_argument( '--quant-format', default='int', - choices=['int', 'float'], + choices=['int', 'float', 'float_ocp'], help='Quantization format to use for weights and activations (default: int)') parser.add_argument( '--layerwise-first-last-mantissa-bit-width', @@ -409,6 +410,7 @@ def main(): weight_narrow_range=args.weight_narrow_range, weight_param_method=args.weight_quant_calibration_type, weight_quant_granularity=args.weight_quant_granularity, + act_quant_granularity=args.act_quant_granularity, weight_quant_type=args.weight_quant_type, layerwise_first_last_bit_width=args.layerwise_first_last_bit_width, act_bit_width=args.act_bit_width, diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 2719b48a0..81107adb7 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -387,6 +387,7 @@ def forward(self, x): input_quant, weight_quant = pytest_cases.param_fixtures("input_quant, weight_quant", [(None, Int8WeightPerTensorFloat), (Int8ActPerTensorFloat, Int8WeightPerTensorFloat), (MXInt8Act, MXInt8Weight), (MXFloat8e4m3Act, MXFloat8e4m3Weight)]) + @pytest_cases.fixture def quant_conv_with_input_quant_model(input_quant, weight_quant):