
Commit

PTQ update
Giuseppe5 committed Sep 1, 2024
1 parent e97b733 commit d653001
Showing 5 changed files with 82 additions and 35 deletions.
5 changes: 1 addition & 4 deletions src/brevitas_examples/common/generative/quant_blocks.py
@@ -3,15 +3,12 @@
# SPDX-License-Identifier: BSD-3-Clause
"""

from typing import Callable, List, Optional, Tuple
from typing import Callable

import torch
from torch import Tensor
import torch.nn as nn

import brevitas
from brevitas.core.function_wrapper.shape import PermuteDims
from brevitas.core.utils import SliceTensor
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad

@@ -89,7 +89,9 @@ def unique(sequence):
'act_bit_width': [8], # Act bit width
'bias_bit_width': [32], # Bias Bit-Width for Po2 scale
'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel
'act_quant_granularity': ['per_tensor'], # Scaling Per Tensor
'act_quant_type': ['sym'], # Act Quant Type
'act_scale_computation_type': ['static'], # Act Scale Computation Type
'act_param_method': ['stats'], # Act Param Method
'weight_param_method': ['mse'], # Weight Param Method
'bias_corr': [True], # Bias Correction
@@ -240,7 +242,9 @@ def ptq_torchvision_models(args):
weight_param_method=config_namespace.weight_param_method,
act_param_method=config_namespace.act_param_method,
bias_bit_width=config_namespace.bias_bit_width,
act_scale_computation_type=config_namespace.act_scale_computation_type,
weight_quant_granularity=config_namespace.weight_quant_granularity,
act_quant_granularity=config_namespace.act_quant_granularity,
act_quant_percentile=config_namespace.act_quant_percentile,
act_quant_type=config_namespace.act_quant_type,
scale_factor_type=config_namespace.scale_factor_type,
91 changes: 67 additions & 24 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -29,6 +29,20 @@
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloatMSE
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloatMSE
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloatMSE
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE
from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8Weight
from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8WeightMSE
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
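The MX quantizers imported here (MXInt8Weight, MXFloat8e4m3Weight, and friends) differ from the existing per-tensor and per-channel quantizers in that they share one power-of-two scale per small group of elements. As a rough toy illustration only, not Brevitas code, and assuming a group size of 32 along the last dimension plus simple round-to-nearest, per-group po2 int8 quantization of a weight tensor could be sketched as:

import torch

def per_group_po2_int8_quant(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    # Reshape the last dimension into groups and take one statistic per group.
    g = w.reshape(*w.shape[:-1], -1, group_size)
    max_abs = g.abs().amax(dim=-1, keepdim=True)
    # One power-of-two scale per group (the small eps avoids log2(0)).
    scale = 2.0 ** torch.ceil(torch.log2(max_abs / 127.0 + 1e-12))
    q = torch.clamp(torch.round(g / scale), -127, 127)  # symmetric int8 codes
    return (q * scale).reshape_as(w)  # dequantized approximation of w

w = torch.randn(64, 128)
print((w - per_group_po2_int8_quant(w)).abs().max())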
@@ -96,12 +110,16 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'per_tensor': {
'sym': Int8WeightPerTensorFixedPoint},
'per_channel': {
'sym': Int8WeightPerChannelFixedPoint},},
'sym': Int8WeightPerChannelFixedPoint},
'per_group': {
'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}},
'mse': {
'per_tensor': {
'sym': Int8WeightPerTensorFixedPointMSE},
'per_channel': {
'sym': Int8WeightPerChannelFixedPointMSE}},}},
'sym': Int8WeightPerChannelFixedPointMSE},
'per_group': {
'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}},}},
'float': {
'float_scale': {
'stats': {
@@ -113,7 +131,26 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'per_tensor': {
'sym': Fp8e4m3WeightPerTensorFloatMSE},
'per_channel': {
'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}}
'sym': Fp8e4m3WeightPerChannelFloatMSE}}}},
'float_ocp': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e4m3OCPWeightPerTensorFloat},
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFloat}},
'mse': {
'per_tensor': {
'sym': Fp8e4m3OCPWeightPerTensorFloatMSE},
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}},
'po2_scale': {
'stats': {
'per_group': {
'sym': MXFloat8e4m3Weight}},
'mse': {
'per_group': {
'sym': MXFloat8e4m3WeightMSE}}}}}

INPUT_QUANT_MAP = {
'int': {
@@ -139,7 +176,10 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'stats': {
'per_tensor': {
'sym': CNNInt8DynamicActPerTensorFloat,
'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
'asym': CNNShiftedUint8DynamicActPerTensorFloat}}},
'po2_scale': {
'stats': {
'per_group': MXInt8Act}}}},
'float': {
'static': {
'float_scale': {
@@ -148,7 +188,21 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'sym': Fp8e4m3ActPerTensorFloat}},
'mse': {
'per_tensor': {
'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}}
'sym': Fp8e4m3ActPerTensorFloatMSE}}}}},
'float_ocp': {
'static': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e4m3OCPActPerTensorFloat}},
'mse': {
'per_tensor': {
'sym': Fp8e4m3OCPActPerTensorFloatMSE}}}},
'dynamic': {
'po2_scale': {
'stats': {
'per_group': {
'sym': MXFloat8e4m3Act}}}}}}


def quantize_model(
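For orientation, quantize_model selects a quantizer by walking these nested maps: format, then static or dynamic scale computation (activations only), then scale type, parameter estimation method, granularity, and finally sym or asym. A standalone sketch, with plain strings standing in for the Brevitas quantizer classes so it runs on its own:

# Illustrative subset of INPUT_QUANT_MAP from the diff above.
EXAMPLE_INPUT_QUANT_MAP = {
    'float_ocp': {
        'static': {
            'float_scale': {
                'stats': {'per_tensor': {'sym': 'Fp8e4m3OCPActPerTensorFloat'}},
                'mse': {'per_tensor': {'sym': 'Fp8e4m3OCPActPerTensorFloatMSE'}}}},
        'dynamic': {
            'po2_scale': {
                'stats': {'per_group': {'sym': 'MXFloat8e4m3Act'}}}}}}

def select_act_quant(fmt, scale_computation, scale_type, param_method, granularity, quant_type):
    # Same lookup order used later in quantize_model.
    return EXAMPLE_INPUT_QUANT_MAP[fmt][scale_computation][scale_type][param_method][granularity][quant_type]

print(select_act_quant('float_ocp', 'dynamic', 'po2_scale', 'stats', 'per_group', 'sym'))
# -> 'MXFloat8e4m3Act'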
@@ -252,14 +306,14 @@ def layerwise_bit_width_fn_weight(module):
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width

if quant_format == 'float' and backend == 'layerwise':
if 'float' in quant_format and backend == 'layerwise':
weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent
act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa
act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent
elif quant_format == 'float' and backend != 'layerwise':
elif 'float' in quant_format and backend != 'layerwise':
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width
weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width
@@ -334,12 +388,12 @@ def kwargs_prefix(prefix, weight_kwargs):
return {prefix + k: v for k, v in weight_kwargs.items()}

weight_bit_width_dict = {'bit_width': weight_bit_width}
if weight_quant_format == 'float':
if 'float' in weight_quant_format:
weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width
weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width

act_bit_width_dict = {'bit_width': act_bit_width}
if act_quant_format == 'float':
if 'float' in act_quant_format:
act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width
act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width

@@ -355,16 +409,12 @@ def kwargs_prefix(prefix, weight_kwargs):
# Some activations in MHA should always be symmetric
sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
act_scale_type][act_param_method][act_quant_granularity]['sym']
# Linear layers with 2d input should always be per tensor
per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
act_scale_type][act_param_method]['per_tensor'][act_quant_type]

act_quant = act_quant.let(**act_bit_width_dict)
sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
else:
act_quant = None
sym_act_quant = None
per_tensor_act_quant = None

# Modify the weight quantizer based on the arguments passed in
weight_quant = weight_quant.let(
@@ -383,13 +433,6 @@ def kwargs_prefix(prefix, weight_kwargs):
sym_act_quant = sym_act_quant.let(
**{
'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
if per_tensor_act_quant is not None:
per_tensor_act_quant = per_tensor_act_quant.let(
**{
'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
if act_quant_type == 'asym' and act_quant_percentile is not None:
per_tensor_act_quant = per_tensor_act_quant.let(
**{'low_percentile_q': 100 - act_quant_percentile})

weight_quant_dict = {'weight_quant': weight_quant}

@@ -431,9 +474,9 @@ def kwargs_prefix(prefix, weight_kwargs):
unsigned_quant_act_kwargs['signed'] = False

# Layerwise is basic quant kwargs + input_quant
layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant}
layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': act_quant}

layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant}
layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': act_quant}

quant_layer_map = {
torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs),
@@ -526,7 +569,7 @@ def apply_gptq(calib_loader, model, act_order=False):
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
with torch.no_grad():
with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq:
with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
gptq_model = gptq.model
for i in tqdm(range(gptq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
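For context, the use_quant_activations flip above lives inside the example's GPTQ calibration loop. A hedged sketch of that loop, assuming the usual Brevitas gptq_mode pattern where each pass over the calibration data optimizes one layer and gptq.update() moves to the next; the device and dtype handling shown here is an assumption:

import torch
from tqdm import tqdm
from brevitas.graph.gptq import gptq_mode

def apply_gptq(calib_loader, model, act_order=False):
    model.eval()
    dtype = next(model.parameters()).dtype
    device = next(model.parameters()).device
    with torch.no_grad():
        # use_quant_activations=True: GPTQ now runs with quantized activations enabled.
        with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
            gptq_model = gptq.model
            for _ in tqdm(range(gptq.num_layers)):
                for images, _ in calib_loader:
                    images = images.to(device=device, dtype=dtype)
                    gptq_model(images)
                gptq.update()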
@@ -120,7 +120,12 @@ def parse_type(v, default_type):
parser.add_argument(
'--weight-quant-granularity',
default='per_tensor',
choices=['per_tensor', 'per_channel'],
choices=['per_tensor', 'per_channel', 'per_group'],
help='Weight quantization granularity (default: per_tensor)')
parser.add_argument(
'--act-quant-granularity',
default='per_tensor',
choices=['per_tensor', 'per_group'],
help='Activation quantization granularity (default: per_tensor)')
parser.add_argument(
'--weight-quant-calibration-type',
@@ -168,11 +173,7 @@ def parse_type(v, default_type):
'--export-torch-qcdq',
action='store_true',
help='If true, export the model in torch qcdq format')
add_bool_arg(
parser,
'scaling-per-output-channel',
default=True,
help='Weight scaling per output channel (default: enabled)')

add_bool_arg(
parser, 'bias-corr', default=True, help='Bias correction after calibration (default: enabled)')
add_bool_arg(
@@ -189,7 +190,7 @@ def parse_type(v, default_type):
parser.add_argument(
'--quant-format',
default='int',
choices=['int', 'float'],
choices=['int', 'float', 'float_ocp'],
help='Quantization format to use for weights and activations (default: int)')
parser.add_argument(
'--layerwise-first-last-mantissa-bit-width',
Expand Down Expand Up @@ -409,6 +410,7 @@ def main():
weight_narrow_range=args.weight_narrow_range,
weight_param_method=args.weight_quant_calibration_type,
weight_quant_granularity=args.weight_quant_granularity,
act_quant_granularity=args.act_quant_granularity,
weight_quant_type=args.weight_quant_type,
layerwise_first_last_bit_width=args.layerwise_first_last_bit_width,
act_bit_width=args.act_bit_width,
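A minimal, self-contained sketch of how the reworked options parse. The flag names, defaults, and choices come from the diff above; the parser setup and example invocation are illustrative only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--weight-quant-granularity',
    default='per_tensor',
    choices=['per_tensor', 'per_channel', 'per_group'])
parser.add_argument(
    '--act-quant-granularity',
    default='per_tensor',
    choices=['per_tensor', 'per_group'])
parser.add_argument(
    '--quant-format',
    default='int',
    choices=['int', 'float', 'float_ocp'])

# e.g. an MX-style configuration: OCP float format with per-group weights and activations.
args = parser.parse_args([
    '--quant-format', 'float_ocp',
    '--weight-quant-granularity', 'per_group',
    '--act-quant-granularity', 'per_group'])
print(args)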
1 change: 1 addition & 0 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -387,6 +387,7 @@ def forward(self, x):
input_quant, weight_quant = pytest_cases.param_fixtures("input_quant, weight_quant", [(None, Int8WeightPerTensorFloat), (Int8ActPerTensorFloat, Int8WeightPerTensorFloat), (MXInt8Act, MXInt8Weight), (MXFloat8e4m3Act, MXFloat8e4m3Weight)])



@pytest_cases.fixture
def quant_conv_with_input_quant_model(input_quant, weight_quant):

