Initial AXE support for imagenet
i-colbert committed Oct 18, 2024
1 parent b161fb8 commit ac735f3
Showing 4 changed files with 94 additions and 56 deletions.
@@ -103,16 +103,18 @@ def single_layer_update(self, percdamp=0.01):
weight = weight.flatten(1)

# TODO: add support for signed input activations
assert not self.quant_metadata.signed
if self.quant_metadata.signed:
raise NotImplementedError("Signed inputs not yet supported.")

n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size)

s = self.layer.weight_quant.scale()
P = torch.tensor(self.max_accumulator_bit_width)
N = self.quant_metadata.bit_width
# TODO: add support for two's complement accumulator representation
# NOTE: using sign-magnitude here, which is sufficient to support both
# sign-magnitude and 2s complement accumulators
A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)
B = (-pow(2, P - 1) - 1) / float(pow(2, N) - 1)
B = -A
Z = (pow(2, P) - 2) / float(pow(2, N) - 1)
# translating into the quantized range; need to pad to get these thresholds
wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view(
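The A, B, and Z thresholds above define the per-tile accumulation budget. Below is a minimal sketch of that arithmetic, assuming unsigned N-bit inputs and a sign-magnitude P-bit accumulator as stated in the NOTE comment; the concrete values of P and N are illustrative only.

```python
import torch

P = torch.tensor(16)  # example target accumulator bit width (illustrative)
N = 8                 # example unsigned input activation bit width (illustrative)

# Same expressions as in the hunk above, in the integer (weight / scale) domain.
A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)  # positive budget, ~128.5 here
B = -A                                          # negative budget, ~-128.5 here
Z = (pow(2, P) - 2) / float(pow(2, N) - 1)      # total magnitude budget, ~257.0 here

# Reading of the bound (an assumption, not stated in the diff): if every input
# takes its maximum value 2**N - 1, a tile whose positive integer weights sum to
# at most A and whose negative integer weights sum to at least B can never push
# the running dot product outside the signed P-bit accumulator range.
print(A.item(), B.item(), Z.item())
```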
@@ -203,7 +205,7 @@ def single_layer_update(self, percdamp=0.01):
q_groups = self.get_quant_weights(i, i1, permutation_list) # [Groups, OC/groups]
for group_index in range(self.groups):
perm = permutation_list[group_index]
q = q_groups[group_index] # [OC/groups]
q = q_groups[group_index].to(torch.float32) # [OC/groups]
w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups]
d = h_inv_block[group_index, i, i] # [1]
error = (w - q) / d # [OC/groups]
@@ -280,16 +282,18 @@ def single_layer_update(self, percdamp=0.01):
weight = weight.flatten(1)

# TODO: add support for signed input activations
assert not self.quant_metadata.signed
if self.quant_metadata.signed:
raise NotImplementedError("Signed inputs not yet supported.")

n_tiles = math.ceil(weight.shape[-1] / self.max_accumulator_tile_size)

s = self.layer.weight_quant.scale()
P = torch.tensor(self.max_accumulator_bit_width)
N = self.quant_metadata.bit_width
# TODO: add support for two's complement accumulator representation
# NOTE: using sign-magnitude here, which is sufficient to support both
# sign-magnitude and 2s complement accumulators
A = (pow(2, P - 1) - 1) / float(pow(2, N) - 1)
B = (-pow(2, P - 1) - 1) / float(pow(2, N) - 1)
B = -A
Z = (pow(2, P) - 2) / float(pow(2, N) - 1)
# translating into the quantized range; need to pad to get these thresholds
wT = pad_tensor_with_zeros(weight / s, self.max_accumulator_tile_size).view(
@@ -307,7 +311,7 @@ def single_layer_update(self, percdamp=0.01):
b = torch.zeros_like(T, device=dev) # neg

# stablize G with a dampening factor and then square root the matrix
norms = torch.zeros((self.groups, self.columns), device=dev, dtype=dtype)
norms = torch.zeros((self.groups, self.columns), device=dev, dtype=torch.float32)
self.H = self.H.to(dev)
diag = torch.arange(self.columns, device='cpu')
for i in range(self.groups):
@@ -342,14 +346,17 @@ def single_layer_update(self, percdamp=0.01):
permutation_list = self._get_permutation_list(weight)

U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev,
dtype=dtype) # [Groups, OC/groups, Samples]
weight.shape[0],
weight.shape[1],
self.float_input.shape[1],
device=dev,
dtype=torch.float32) # [Groups, OC/groups, Samples]

for t in range(weight.shape[-1]):
for group_index in range(self.groups):
i = permutation_list[group_index][t]
U[group_index] += torch.matmul(
weight[group_index, :, i].unsqueeze(1),
weight[group_index, :, i].unsqueeze(1).to(torch.float32),
self.float_input[group_index, :, i].unsqueeze(0))
norm = norms[group_index, i]
if norm > 0:
@@ -364,12 +371,12 @@ def single_layer_update(self, percdamp=0.01):
q_max = s[group_index] * torch.clamp_min(A - a[group_index, bx, :] - 0.5, 0.0)
q_min = s[group_index] * torch.clamp_max(B - b[group_index, bx, :] + 0.5, 0.0)
q_arg.clamp_(q_min, q_max)
weight[group_index, :, i] = q_arg
weight[group_index, :, i] = q_arg.to(dtype)
q_groups: Tensor = self.get_quant_weights(t, 0, permutation_list)
for group_index in range(self.groups):
i = permutation_list[group_index][t]
U[group_index] -= torch.matmul(
q_groups[group_index].unsqueeze(1),
q_groups[group_index].unsqueeze(1).to(torch.float32),
self.quant_input[group_index, :, i].unsqueeze(0))
bx = i // self.max_accumulator_tile_size # block index
q = q_groups[group_index] / s[group_index] # [OC/groups]
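For readability, here is a schematic scalar version of the budget tracking that the vectorized code above appears to perform per tile. The names A and B mirror the diff; the update of the running positive and negative sums is an assumption, since that part of the hunk is cut off here.

```python
def clamp_to_budget(w_over_s, pos_used, neg_used, A, B):
    """Clamp a candidate integer weight to the tile's remaining accumulator budget."""
    q_max = max(A - pos_used - 0.5, 0.0)  # remaining positive headroom
    q_min = min(B - neg_used + 0.5, 0.0)  # remaining negative headroom
    return min(max(w_over_s, q_min), q_max)

A, B = 128.5, -128.5           # from the P=16, N=8 example above (illustrative)
pos_used, neg_used = 0.0, 0.0  # running sums for one tile
for w_over_s in [40.2, -90.7, 75.0, -60.1]:  # toy weight/scale values
    q = round(clamp_to_budget(w_over_s, pos_used, neg_used, A, B))
    if q >= 0:
        pos_used += q
    else:
        neg_used += q
    print(q, pos_used, neg_used)  # the last weight is clipped to -37 by the budget
```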
54 changes: 41 additions & 13 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -1,11 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
from functools import partial
import math

import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm

from brevitas.core.function_wrapper.shape import OverBatchOverTensorView
@@ -16,6 +15,8 @@
from brevitas.graph.calibrate import norm_correction_mode
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.gpfq import gpfq_mode
from brevitas.graph.gpfq import GPFQv2
from brevitas.graph.gptq import GPTQ
from brevitas.graph.gptq import gptq_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.quantize import quantize
@@ -60,14 +61,15 @@
from brevitas.quant.scaled_int import Int32Bias
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFixedPoint
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.common.axe import A2GPFQ
from brevitas_examples.common.axe import A2GPTQ
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator
@@ -574,12 +576,32 @@ def apply_act_equalization(model, calib_loader, layerwise):
model(images)


def apply_gptq(calib_loader, model, act_order=False):
def apply_gptq(
calib_loader,
model,
act_order=False,
use_quant_activations=False,
create_weight_orig=False,
max_accumulator_bit_width=None,
max_accumulator_tile_size=128):
if max_accumulator_bit_width is not None:
# Use accumulator-aware extension (AXE) framework
print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...")
gptq_class = partial(
A2GPTQ,
max_accumulator_bit_width=max_accumulator_bit_width,
max_accumulator_tile_size=max_accumulator_tile_size)
else:
gptq_class = GPTQ
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
with torch.no_grad():
with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
with gptq_mode(model,
act_order=act_order,
use_quant_activations=use_quant_activations,
create_weight_orig=create_weight_orig,
gptq_class=gptq_class) as gptq:
gptq_model = gptq.model
for i in tqdm(range(gptq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
@@ -593,21 +615,27 @@ def apply_gpfq(
calib_loader,
model,
act_order,
p=1.0,
use_gpfa2q=False,
accumulator_bit_width=None,
compression_rate=0.0):
create_weight_orig=False,
max_accumulator_bit_width=None,
max_accumulator_tile_size=128):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
if max_accumulator_bit_width is not None:
# Use accumulator-aware extension (AXE) framework
print(f"Using AXE to target {max_accumulator_bit_width}-bit accumulation...")
gpfq_class = partial(
A2GPFQ,
max_accumulator_bit_width=max_accumulator_bit_width,
max_accumulator_tile_size=max_accumulator_tile_size)
else:
gpfq_class = GPFQv2
with torch.no_grad():
with gpfq_mode(model,
p=p,
create_weight_orig=create_weight_orig,
use_quant_activations=True,
act_order=act_order,
use_gpfa2q=use_gpfa2q,
accumulator_bit_width=accumulator_bit_width,
compression_rate=compression_rate) as gpfq:
gpfq_class=gpfq_class) as gpfq:
gpfq_model = gpfq.model
for i in tqdm(range(gpfq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
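For reference, a sketch of how the updated entry points could be called from a driver script; the keyword arguments are exactly the ones added above, while quant_model and calib_loader are placeholders for objects the caller already has (e.g. from quantize_model and the calibration DataLoader).

```python
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq

# Passing a max accumulator bit width switches both functions onto the AXE path.
apply_gptq(
    calib_loader,
    quant_model,
    act_order=True,
    use_quant_activations=False,
    create_weight_orig=True,
    max_accumulator_bit_width=16,
    max_accumulator_tile_size=128)

apply_gpfq(
    calib_loader,
    quant_model,
    act_order=True,
    create_weight_orig=True,
    max_accumulator_bit_width=16,
    max_accumulator_tile_size=128)
```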
58 changes: 31 additions & 27 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -228,10 +228,15 @@ def parse_type(v, default_type):
type=int,
help='Exponent bit width used with float quantization for activations (default: 3)')
parser.add_argument(
'--accumulator-bit-width',
'--gpxq-accumulator-bit-width',
default=None,
type=int,
help='Accumulator Bit Width for GPFA2Q (default: None)')
help='Accumulator Bit Width for GPxQ (default: None)')
parser.add_argument(
'--gpxq-accumulator-tile-size',
default=None,
type=int,
help='Accumulator tile size for GPxQ (default: None)')
parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version')
parser.add_argument(
'--channel-splitting-ratio',
@@ -240,17 +245,20 @@ def parse_type(v, default_type):
help=
'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)'
)
parser.add_argument(
'--compression-rate',
default=0.0,
type=float,
help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.'
)
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
add_bool_arg(
parser, 'gpxq-act-order', default=False, help='GPxQ Act order heuristic (default: disabled)')
add_bool_arg(
parser,
'gptq-use-quant-activations',
default=False,
help='Use quant activations for GPTQ (default: disabled)')
add_bool_arg(
parser,
'gpxq-create-weight-orig',
default=False,
help='Maintain original weights for non-quant forward pass (default: disabled)')
add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)')
add_bool_arg(
@@ -265,7 +273,7 @@ def parse_type(v, default_type):
help='Merge BN layers before quantizing the model (default: enabled)')
add_bool_arg(
parser,
'uint_sym_act_for_unsigned_values',
'uint-sym-act-for-unsigned-values',
default=True,
help='Use unsigned act quant when possible (default: enabled)')
add_bool_arg(parser, 'compile', default=False, help='Use torch.compile (default: disabled)')
@@ -306,7 +314,6 @@ def main():
f"w{args.weight_bit_width}_"
f"{'gptq_' if args.gptq else ''}"
f"{'gpfq_' if args.gpfq else ''}"
f"{'gpfa2q_' if args.gpfa2q else ''}"
f"{'gpxq_act_order_' if args.gpxq_act_order else ''}"
f"{'learned_round_' if args.learned_round else ''}"
f"{'weight_narrow_range_' if args.weight_narrow_range else ''}"
@@ -329,10 +336,8 @@ def main():
f"Weight bit width: {args.weight_bit_width} - "
f"GPTQ: {args.gptq} - "
f"GPFQ: {args.gpfq} - "
f"GPFA2Q: {args.gpfa2q} - "
f"GPFQ P: {args.gpfq_p} - "
f"GPxQ Act Order: {args.gpxq_act_order} - "
f"GPFA2Q Accumulator Bit Width: {args.accumulator_bit_width} - "
f"GPxQ Accumulator Bit Width: {args.gpxq_accumulator_bit_width} - "
f"Learned Round: {args.learned_round} - "
f"Weight narrow range: {args.weight_narrow_range} - "
f"Bias bit width: {args.bias_bit_width} - "
@@ -406,7 +411,9 @@ def main():
if args.act_equalization is not None:
print("Applying activation equalization:")
apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise')

device = next(iter(model.parameters())).device

# Define the quantized model
quant_model = quantize_model(
model,
@@ -446,24 +453,21 @@ def main():
apply_gpfq(
calib_loader,
quant_model,
p=args.gpfq_p,
act_order=args.gpxq_act_order,
compression_rate=args.compression_rate)
create_weight_orig=args.gpxq_create_weight_orig,
max_accumulator_bit_width=args.gpxq_accumulator_bit_width,
max_accumulator_tile_size=args.gpxq_accumulator_tile_size)

if args.gpfa2q:
print("Performing GPFA2Q:")
apply_gpfq(
if args.gptq:
print("Performing GPTQ:")
apply_gptq(
calib_loader,
quant_model,
p=args.gpfq_p,
act_order=args.gpxq_act_order,
use_gpfa2q=args.gpfa2q,
accumulator_bit_width=args.accumulator_bit_width,
compression_rate=args.compression_rate)

if args.gptq:
print("Performing GPTQ:")
apply_gptq(calib_loader, quant_model, act_order=args.gpxq_act_order)
use_quant_activations=args.gptq_use_quant_activations,
create_weight_orig=args.gpxq_create_weight_orig,
max_accumulator_bit_width=args.gpxq_accumulator_bit_width,
max_accumulator_tile_size=args.gpxq_accumulator_tile_size)

if args.learned_round:
print("Applying Learned Round:")
5 changes: 2 additions & 3 deletions src/brevitas_examples/llm/llm_quant/gpxq.py
@@ -17,9 +17,8 @@
from brevitas.graph.gptq import gptq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.utils.python_utils import recurse_getattr

from .axe import A2GPFQ
from .axe import A2GPTQ
from brevitas_examples.common.axe import A2GPFQ
from brevitas_examples.common.axe import A2GPTQ


@torch.no_grad()