Feat (mx): PTQ MX + Float support #1010

Merged: 3 commits, Sep 5, 2024
5 changes: 1 addition & 4 deletions src/brevitas_examples/common/generative/quant_blocks.py
@@ -3,15 +3,12 @@
# SPDX-License-Identifier: BSD-3-Clause
"""

- from typing import Callable, List, Optional, Tuple
+ from typing import Callable

import torch
from torch import Tensor
import torch.nn as nn

import brevitas
from brevitas.core.function_wrapper.shape import PermuteDims
from brevitas.core.utils import SliceTensor
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad

26 changes: 13 additions & 13 deletions src/brevitas_examples/imagenet_classification/ptq/README.md
@@ -80,7 +80,8 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--bias-bit-width {32,16,None}]
[--act-quant-type {sym,asym}]
[--weight-quant-type {sym,asym}]
- [--weight-quant-granularity {per_tensor,per_channel}]
+ [--weight-quant-granularity {per_tensor,per_channel,per_group}]
+ [--act-quant-granularity {per_tensor,per_group}]
[--weight-quant-calibration-type {stats,mse}]
[--act-equalization {fx,layerwise,None}]
[--act-quant-calibration-type {stats,mse}]
@@ -90,11 +91,11 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--learned-round-lr LEARNED_ROUND_LR]
[--act-quant-percentile ACT_QUANT_PERCENTILE]
[--export-onnx-qcdq] [--export-torch-qcdq]
- [--scaling-per-output-channel | --no-scaling-per-output-channel]
[--bias-corr | --no-bias-corr]
[--graph-eq-merge-bias | --no-graph-eq-merge-bias]
[--weight-narrow-range | --no-weight-narrow-range]
- [--gpfq-p GPFQ_P] [--quant-format {int,float}]
+ [--gpfq-p GPFQ_P]
+ [--quant-format {int,float,float_ocp}]
[--layerwise-first-last-mantissa-bit-width LAYERWISE_FIRST_LAST_MANTISSA_BIT_WIDTH]
[--layerwise-first-last-exponent-bit-width LAYERWISE_FIRST_LAST_EXPONENT_BIT_WIDTH]
[--weight-mantissa-bit-width WEIGHT_MANTISSA_BIT_WIDTH]
@@ -104,6 +105,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--accumulator-bit-width ACCUMULATOR_BIT_WIDTH]
[--onnx-opset-version ONNX_OPSET_VERSION]
[--channel-splitting-ratio CHANNEL_SPLITTING_RATIO]
+ [--compression-rate COMPRESSION_RATE]
[--gptq | --no-gptq] [--gpfq | --no-gpfq]
[--gpfa2q | --no-gpfa2q]
[--gpxq-act-order | --no-gpxq-act-order]
@@ -115,7 +117,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir

PyTorch ImageNet PTQ Validation

- options:
+ optional arguments:
-h, --help show this help message and exit
--calibration-dir CALIBRATION_DIR
Path to folder containing Imagenet calibration folder
@@ -176,7 +178,9 @@ options:
Activation quantization type (default: sym)
--weight-quant-type {sym,asym}
Weight quantization type (default: sym)
- --weight-quant-granularity {per_tensor,per_channel}
+ --weight-quant-granularity {per_tensor,per_channel,per_group}
Weight quantization type (default: per_tensor)
+ --act-quant-granularity {per_tensor,per_group}
+ Activation quantization type (default: per_tensor)
--weight-quant-calibration-type {stats,mse}
Weight quantization calibration type (default: stats)
@@ -201,12 +205,6 @@
(default: 99.999)
--export-onnx-qcdq If true, export the model in onnx qcdq format
--export-torch-qcdq If true, export the model in torch qcdq format
- --scaling-per-output-channel
- Enable Weight scaling per output channel (default:
- enabled)
- --no-scaling-per-output-channel
- Disable Weight scaling per output channel (default:
- enabled)
--bias-corr Enable Bias correction after calibration (default:
enabled)
--no-bias-corr Disable Bias correction after calibration (default:
@@ -224,7 +222,7 @@ options:
Disable Narrow range for weight quantization (default:
disabled)
--gpfq-p GPFQ_P P parameter for GPFQ (default: 1.0)
- --quant-format {int,float}
+ --quant-format {int,float,float_ocp}
Quantization format to use for weights and activations
(default: int)
--layerwise-first-last-mantissa-bit-width LAYERWISE_FIRST_LAST_MANTISSA_BIT_WIDTH
@@ -252,6 +250,9 @@ options:
--channel-splitting-ratio CHANNEL_SPLITTING_RATIO
Split Ratio for Channel Splitting. When set to 0.0,
Channel Splitting will not be applied. (default: 0.0)
+ --compression-rate COMPRESSION_RATE
+ Specify compression rate < 1.0 for random projection.
+ Default is 0.0 and does not use RP.
--gptq Enable GPTQ (default: disabled)
--no-gptq Disable GPTQ (default: disabled)
--gpfq Enable GPFQ (default: disabled)
@@ -280,7 +281,6 @@ options:
--no-uint_sym_act_for_unsigned_values
Disable Use unsigned act quant when possible (default:
enabled)

```

The script requires you to specify the calibration folder (`--calibration-dir`), from which the calibration samples are taken (the number of samples is configurable with the `--calibration-samples` argument), and a validation folder (`--validation-dir`).
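For example, a hypothetical invocation exercising the new MX/OCP options might look as follows. The paths are placeholders and the flag values come from the choices documented above; depending on the configuration, further flags (e.g. for the scale factor type) may also be needed:

```
python ptq_evaluate.py \
    --calibration-dir /path/to/imagenet/calib \
    --validation-dir /path/to/imagenet/val \
    --quant-format float_ocp \
    --weight-quant-granularity per_group \
    --act-quant-granularity per_group
```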
@@ -89,7 +89,9 @@ def unique(sequence):
'act_bit_width': [8], # Act bit width
'bias_bit_width': [32], # Bias Bit-Width for Po2 scale
'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel
+ 'act_quant_granularity': ['per_tensor'],  # Act Quant Granularity
'act_quant_type': ['sym'], # Act Quant Type
+ 'act_scale_computation_type': ['static'],  # Act Scale Computation Type
'act_param_method': ['stats'], # Act Param Method
'weight_param_method': ['mse'], # Weight Quant Type
'bias_corr': [True], # Bias Correction
@@ -240,7 +242,9 @@ def ptq_torchvision_models(args):
weight_param_method=config_namespace.weight_param_method,
act_param_method=config_namespace.act_param_method,
bias_bit_width=config_namespace.bias_bit_width,
+ act_scale_computation_type=config_namespace.act_scale_computation_type,
weight_quant_granularity=config_namespace.weight_quant_granularity,
+ act_quant_granularity=config_namespace.act_quant_granularity,
act_quant_percentile=config_namespace.act_quant_percentile,
act_quant_type=config_namespace.act_quant_type,
scale_factor_type=config_namespace.scale_factor_type,
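Each option in the benchmark dict above is a list because the script sweeps over combinations of all option values; the two added keys become new axes of that sweep. A minimal sketch of the expansion, with illustrative names (the real script builds an argparse-style namespace per combination):

```
from itertools import product

options = {
    'weight_quant_granularity': ['per_channel'],
    'act_quant_granularity': ['per_tensor'],
    'act_scale_computation_type': ['static'],
}
# One configuration per combination of option values
combinations = [
    dict(zip(options.keys(), values)) for values in product(*options.values())]
```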
91 changes: 67 additions & 24 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -29,6 +29,20 @@
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE
+ from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
+ from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloatMSE
+ from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
+ from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloatMSE
+ from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
+ from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloatMSE
+ from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
+ from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
+ from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
+ from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
+ from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
+ from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE
+ from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8Weight
+ from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8WeightMSE
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -96,12 +110,16 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'per_tensor': {
'sym': Int8WeightPerTensorFixedPoint},
'per_channel': {
- 'sym': Int8WeightPerChannelFixedPoint},},
+ 'sym': Int8WeightPerChannelFixedPoint},
+ 'per_group': {
+ 'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}},
'mse': {
'per_tensor': {
'sym': Int8WeightPerTensorFixedPointMSE},
'per_channel': {
- 'sym': Int8WeightPerChannelFixedPointMSE}},}},
+ 'sym': Int8WeightPerChannelFixedPointMSE},
+ 'per_group': {
+ 'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}},}},
'float': {
'float_scale': {
'stats': {
@@ -113,7 +131,26 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'per_tensor': {
'sym': Fp8e4m3WeightPerTensorFloatMSE},
'per_channel': {
- 'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}}
+ 'sym': Fp8e4m3WeightPerChannelFloatMSE}}}},
+ 'float_ocp': {
+ 'float_scale': {
+ 'stats': {
+ 'per_tensor': {
+ 'sym': Fp8e4m3OCPWeightPerTensorFloat},
+ 'per_channel': {
+ 'sym': Fp8e4m3OCPWeightPerChannelFloat}},
+ 'mse': {
+ 'per_tensor': {
+ 'sym': Fp8e4m3OCPWeightPerTensorFloatMSE},
+ 'per_channel': {
+ 'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}},
+ 'po2_scale': {
+ 'stats': {
+ 'per_group': {
+ 'sym': MXFloat8e4m3Weight}},
+ 'mse': {
+ 'per_group': {
+ 'sym': MXFloat8e4m3WeightMSE}}}}}

INPUT_QUANT_MAP = {
'int': {
@@ -139,7 +176,10 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'stats': {
'per_tensor': {
'sym': CNNInt8DynamicActPerTensorFloat,
- 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
+ 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}},
+ 'po2_scale': {
+ 'stats': {
+ 'per_group': MXInt8Act}}}},
'float': {
'static': {
'float_scale': {
@@ -148,7 +188,21 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'sym': Fp8e4m3ActPerTensorFloat}},
'mse': {
'per_tensor': {
- 'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}}
+ 'sym': Fp8e4m3ActPerTensorFloatMSE}}}}},
+ 'float_ocp': {
+ 'static': {
+ 'float_scale': {
+ 'stats': {
+ 'per_tensor': {
+ 'sym': Fp8e4m3OCPActPerTensorFloat}},
+ 'mse': {
+ 'per_tensor': {
+ 'sym': Fp8e4m3OCPActPerTensorFloatMSE}}}},
+ 'dynamic': {
+ 'po2_scale': {
+ 'stats': {
+ 'per_group': {
+ 'sym': MXFloat8e4m3Act}}}}}}


def quantize_model(
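The nested maps above resolve a quantizer class from the CLI-level choices: format, then (for activations) scale computation type, then scale type, parameter method, granularity, and sign. A minimal sketch of two lookups, assuming the weight-side map is named `WEIGHT_QUANT_MAP` in the full file (only `INPUT_QUANT_MAP` is visible by name in this diff):

```
# format -> scale type -> param method -> granularity -> sign
WEIGHT_QUANT_MAP['float_ocp']['po2_scale']['stats']['per_group']['sym']
# -> MXFloat8e4m3Weight

# format -> scale computation -> scale type -> param method -> granularity -> sign
INPUT_QUANT_MAP['float_ocp']['dynamic']['po2_scale']['stats']['per_group']['sym']
# -> MXFloat8e4m3Act (MX activations use dynamic po2 scales over groups)
```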
@@ -252,14 +306,14 @@ def layerwise_bit_width_fn_weight(module):
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width

- if quant_format == 'float' and backend == 'layerwise':
+ if 'float' in quant_format and backend == 'layerwise':
weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent
act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa
act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent
- elif quant_format == 'float' and backend != 'layerwise':
+ elif 'float' in quant_format and backend != 'layerwise':
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width
weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width
@@ -334,12 +388,12 @@ def kwargs_prefix(prefix, weight_kwargs):
return {prefix + k: v for k, v in weight_kwargs.items()}

weight_bit_width_dict = {'bit_width': weight_bit_width}
- if weight_quant_format == 'float':
+ if 'float' in weight_quant_format:
weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width
weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width

act_bit_width_dict = {'bit_width': act_bit_width}
- if act_quant_format == 'float':
+ if 'float' in act_quant_format:
act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width
act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width

@@ -355,16 +409,12 @@ def kwargs_prefix(prefix, weight_kwargs):
# Some activations in MHA should always be symmetric
sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
act_scale_type][act_param_method][act_quant_granularity]['sym']
- # Linear layers with 2d input should always be per tensor
- per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
- act_scale_type][act_param_method]['per_tensor'][act_quant_type]

act_quant = act_quant.let(**act_bit_width_dict)
sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
- per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
else:
act_quant = None
sym_act_quant = None
- per_tensor_act_quant = None

# Modify the weight quantizer based on the arguments passed in
weight_quant = weight_quant.let(
@@ -383,13 +433,6 @@ def kwargs_prefix(prefix, weight_kwargs):
sym_act_quant = sym_act_quant.let(
**{
'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
- if per_tensor_act_quant is not None:
- per_tensor_act_quant = per_tensor_act_quant.let(
- **{
- 'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
- if act_quant_type == 'asym' and act_quant_percentile is not None:
- per_tensor_act_quant = per_tensor_act_quant.let(
- **{'low_percentile_q': 100 - act_quant_percentile})

weight_quant_dict = {'weight_quant': weight_quant}

@@ -431,9 +474,9 @@ def kwargs_prefix(prefix, weight_kwargs):
unsigned_quant_act_kwargs['signed'] = False

# Layerwise is basic quant kwargs + input_quant
- layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant}
+ layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': act_quant}

- layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant}
+ layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': act_quant}

quant_layer_map = {
torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs),
@@ -526,7 +569,7 @@ def apply_gptq(calib_loader, model, act_order=False):
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
with torch.no_grad():
- with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq:
+ with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
gptq_model = gptq.model
for i in tqdm(range(gptq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
@@ -120,7 +120,12 @@ def parse_type(v, default_type):
parser.add_argument(
'--weight-quant-granularity',
default='per_tensor',
- choices=['per_tensor', 'per_channel'],
+ choices=['per_tensor', 'per_channel', 'per_group'],
help='Weight quantization type (default: per_tensor)')
+ parser.add_argument(
+ '--act-quant-granularity',
+ default='per_tensor',
+ choices=['per_tensor', 'per_group'],
+ help='Activation quantization type (default: per_tensor)')
parser.add_argument(
'--weight-quant-calibration-type',
@@ -168,11 +173,7 @@ def parse_type(v, default_type):
'--export-torch-qcdq',
action='store_true',
help='If true, export the model in torch qcdq format')
- add_bool_arg(
- parser,
- 'scaling-per-output-channel',
- default=True,
- help='Weight scaling per output channel (default: enabled)')

add_bool_arg(
parser, 'bias-corr', default=True, help='Bias correction after calibration (default: enabled)')
add_bool_arg(
@@ -189,7 +190,7 @@ def parse_type(v, default_type):
parser.add_argument(
'--quant-format',
default='int',
- choices=['int', 'float'],
+ choices=['int', 'float', 'float_ocp'],
help='Quantization format to use for weights and activations (default: int)')
parser.add_argument(
'--layerwise-first-last-mantissa-bit-width',
@@ -409,6 +410,7 @@ def main():
weight_narrow_range=args.weight_narrow_range,
weight_param_method=args.weight_quant_calibration_type,
weight_quant_granularity=args.weight_quant_granularity,
+ act_quant_granularity=args.act_quant_granularity,
weight_quant_type=args.weight_quant_type,
layerwise_first_last_bit_width=args.layerwise_first_last_bit_width,
act_bit_width=args.act_bit_width,