Fix (llm): small fixes to LLM #1035

Merged: 5 commits, Oct 8, 2024
5 changes: 5 additions & 0 deletions src/brevitas/core/function_wrapper/shape.py
@@ -165,6 +165,11 @@ def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None:

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
# This one is a bit tricky, but we can end up here in two cases:
# - Quantizing the zero point, which already has an expanded shape matching the scale (unpadded, but no padding is needed there)
# - Groupwise HQO quantization, where the weight has already been padded and expanded
if len(x.shape) == len(self.expanded_groupwise_shape):
return x
Collaborator:
Weird stuff / comments like this make me wonder if we need to re-think our implementation.

Collaborator:
(but let's not block this release)

Collaborator (author):
Agreed

y = torch.nn.functional.pad(
x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.)
y = y.view(self.expanded_groupwise_shape)
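As a side note, here is a minimal standalone sketch of the guard added above (the padding arithmetic and names are illustrative, not the Brevitas padding helper): if the input already has the expanded groupwise rank, it is returned untouched; otherwise it is zero-padded along the group dimension and viewed into the expanded shape.

import torch
import torch.nn.functional as F

def to_groupwise_view(x: torch.Tensor,
                      expanded_shape: tuple,
                      group_size: int,
                      group_dim: int) -> torch.Tensor:
    # Already expanded (e.g. a quantized zero point, or a weight that
    # groupwise HQO has already padded and expanded): return as-is.
    if len(x.shape) == len(expanded_shape):
        return x
    # Pad group_dim up to a multiple of group_size; F.pad expects the
    # pad spec for the last dimension first.
    pad = [0] * (2 * x.dim())
    rem = x.shape[group_dim] % group_size
    pad[2 * (x.dim() - 1 - group_dim) + 1] = (group_size - rem) % group_size
    y = F.pad(x, pad, mode='constant', value=0.)
    return y.view(expanded_shape)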
6 changes: 5 additions & 1 deletion src/brevitas/core/stats/stats_op.py
@@ -692,14 +692,18 @@ def parameter_search(self, xl, x):
self.set_local_loss_mode(False)
qt_value = self.input_view_shape_impl(quant_tensor.value)
qt_scale = self.input_view_shape_impl(quant_tensor.scale)
qt_int = self.input_view_shape_impl(quant_tensor.int())
qt_zp = self.input_view_shape_impl(quant_tensor.zero_point)
qt_int = qt_value / qt_scale + qt_zp
loss = torch.abs(qt_value - x).mean()
best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)

# Compared to the original formulation, the value we're looking for is:
# - scaled by qt_scale
# - of opposite sign
val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)

if self.stats_reduce_dim is None:
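For context, the deleted call to quant_tensor.int() is replaced by an arithmetic reconstruction: under the usual affine convention value = (q - zero_point) * scale, the integer representation follows directly from the already-viewed value, scale, and zero point. A minimal sketch of that identity (assuming this convention; not the full parameter_search loop):

import torch

def int_repr(value: torch.Tensor,
             scale: torch.Tensor,
             zero_point: torch.Tensor) -> torch.Tensor:
    # Inverse of affine dequantization value = (q - zero_point) * scale,
    # matching qt_int = qt_value / qt_scale + qt_zp above.
    return value / scale + zero_point

Presumably this also avoids running input_view_shape_impl over a tensor freshly materialized by quant_tensor.int(), whose shape may not match the already-viewed scale and zero point.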
53 changes: 36 additions & 17 deletions src/brevitas/quant/base.py
@@ -429,41 +429,31 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl


class MSESubInjectorBase(ExtendedInjector):

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output == ScalingPerOutputType.GROUP:
raise RuntimeError("Not implemented yet")

permute_dims = (this << 1).permute_dims


class MSESymmetricScaleSubInjector(MSESubInjectorBase):
class MSESymmetricScaleSubInjector(ExtendedInjector):
scaling_per_output = (this << 1).scaling_per_output
proxy_module = (this << 1).proxy_module
mse_init_op = AbsMax
stats_impl = MSE
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
type = (this << 1).type
permute_dims = (this << 1).permute_dims
inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl


class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
class MSEAsymmetricScaleSubInjector(ExtendedInjector):
scaling_per_output = (this << 1).scaling_per_output
proxy_module = (this << 1).proxy_module
mse_init_op = AbsMinMax
stats_impl = MSE
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
dtype = (this << 1).dtype
permute_dims = (this << 1).permute_dims
inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl


class MSEZeroPointSubInjector(MSESubInjectorBase):
class MSEZeroPointSubInjector(ExtendedInjector):
# zp is per channel when scaling is per channel
scaling_per_output = (this << 1).scaling_per_output
proxy_module = (this << 1).proxy_module
@@ -473,6 +463,8 @@ class MSEZeroPointSubInjector(MSESubInjectorBase):
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
dtype = (this << 1).dtype
permute_dims = (this << 1).permute_dims
inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl


class MSEAsymmetricScale(ExtendedInjector):
@@ -484,6 +476,15 @@ class MSEAsymmetricScale(ExtendedInjector):
scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
scaling_stats_input_view_shape_impl = nn.Identity()

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK

@value
def scaling_stats_impl():
return this.mse_scale.stats_impl
@@ -498,6 +499,15 @@ class MSESymmetricScale(ExtendedInjector):
scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
scaling_stats_input_view_shape_impl = nn.Identity()

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK

@value
def scaling_stats_impl():
return this.mse_scale.stats_impl
@@ -511,6 +521,15 @@ class MSEZeroPoint(ExtendedInjector):
mse_zero_point = MSEZeroPointSubInjector
zero_point_stats_input_view_shape_impl = nn.Identity()

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK

@value
def zero_point_stats_impl():
return this.mse_zero_point.stats_impl
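The refactor above drops the shared MSESubInjectorBase and instead defines the inner_stats_input_view_shape_impl resolver on each top-level quantizer, with the sub-injectors pulling it from the enclosing scope via (this << 1). A rough plain-Python analogy of that parent-scope lookup (illustrative only, not the ExtendedInjector machinery):

# Analogy only: the parent owns the resolver and the child forwards to
# it, which is what `(this << 1).inner_stats_input_view_shape_impl`
# expresses declaratively.
VIEW_BY_GRANULARITY = {
    'channel': 'OVER_OUTPUT_CHANNELS',
    'tensor': 'OVER_TENSOR',
    'group': 'OVER_SUBCHANNEL_BLOCK',  # previously raised RuntimeError
}

class ParentQuantizer:
    @staticmethod
    def inner_view_impl(scaling_per_output: str) -> str:
        return VIEW_BY_GRANULARITY[scaling_per_output]

class SubInjector:
    def __init__(self, parent: ParentQuantizer) -> None:
        # Resolve against the parent namespace, one level up.
        self.inner_view_impl = parent.inner_view_impl

assert SubInjector(ParentQuantizer()).inner_view_impl('group') == 'OVER_SUBCHANNEL_BLOCK'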
4 changes: 2 additions & 2 deletions src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -123,14 +123,14 @@ class MXInt8Act(MXActMixin, GroupwiseActProxyMixin, IntQuant, MaxStatsScaling, A
bit_width = 8


class MXInt8WeightMSE(MXInt8Weight, MSESymmetricScale):
class MXInt8WeightMSE(MSESymmetricScale, MXInt8Weight):
"""
MX Int signed weight quantizer with per-channel MSE-based scaling.
"""
pass


class ShiftedMXUInt8WeightMSE(ShiftedMXUInt8Weight, MSEAsymmetricScale):
class ShiftedMXUInt8WeightMSE(MSEAsymmetricScale, ShiftedMXUInt8Weight):
"""
MX Int unsigned weight quantizer with per-channel MSE-based scaling.
"""
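The only change in this file is the base-class order. Attribute lookup on these injector classes follows Python's MRO, so the MSE mixin must come first for its scaling-statistics attributes to take precedence over the defaults inherited from the base weight quantizer. A generic illustration (names are made up):

class WeightQuant:
    scaling_stats_impl = 'absmax'

class MSEMixin:
    scaling_stats_impl = 'mse'

# Mixin listed first: its attribute wins along the MRO.
class WeightQuantMSE(MSEMixin, WeightQuant):
    pass

# Mixin listed last: the base default shadows the mixin instead.
class WeightQuantMSEWrongOrder(WeightQuant, MSEMixin):
    pass

assert WeightQuantMSE.scaling_stats_impl == 'mse'
assert WeightQuantMSEWrongOrder.scaling_stats_impl == 'absmax'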
6 changes: 3 additions & 3 deletions src/brevitas/quant/shifted_scaled_int.py
@@ -5,7 +5,7 @@
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.quant.base import *
from brevitas.quant.base import HQOActZeroPoint
from brevitas.quant.base import HQOZeroPoint
from brevitas.quant.base import HQOWeightZeroPoint
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.weight import WeightQuantSolver

@@ -145,7 +145,7 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale,
pass


class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat):
class ShiftedUint8WeightPerTensorFloatHQO(HQOWeightZeroPoint, ShiftedUint8WeightPerTensorFloat):
"""
8-bit per-tensor unsigned int weight quantizer with floating-point per-tensor scale factor and integer
zero point. Zero-point is initialized from HQO local loss.
@@ -157,7 +157,7 @@ class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTen
quantize_zero_point = False


class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat):
class ShiftedUint8WeightPerChannelFloatHQO(HQOWeightZeroPoint, ShiftedUint8WeightPerChannelFloat):
"""
8-bit per-channel unsigned int weight quantizer with floating-point per-channel scale factor and integer
zero point. Zero-point is initialized from HQO local loss.
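As a usage note, these quantizers are passed to quantized layers as weight_quant in the usual Brevitas way; a minimal sketch (layer sizes are illustrative):

import torch
from brevitas.nn import QuantLinear
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO

# Illustrative only: a linear layer whose weight is quantized with the
# HQO-initialized zero-point quantizer renamed above.
layer = QuantLinear(64, 32, bias=True, weight_quant=ShiftedUint8WeightPerChannelFloatHQO)
out = layer(torch.randn(1, 64))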
16 changes: 10 additions & 6 deletions src/brevitas_examples/common/generative/quantize.py
@@ -102,9 +102,9 @@
'per_tensor': {
'sym': Int8WeightPerTensorFixedPointMSE},
'per_channel': {
'sym': Int8WeightPerChannelFixedPointMSE}},
'per_group': {
'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}},
'sym': Int8WeightPerChannelFixedPointMSE},
'per_group': {
'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}}},
'float': {
'float_scale': {
'stats': {
@@ -210,6 +210,7 @@ def generate_quantizers(
weight_group_size,
quantize_weight_zero_point,
weight_quant_format='int',
weight_group_dim=None,
input_bit_width=None,
input_quant_format='',
input_scale_precision=None,
@@ -276,6 +277,10 @@
'narrow_range': False,
'quantize_zero_point': quantize_weight_zero_point},
**weight_float_format)

if weight_group_dim is not None:
weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})

if dtype == torch.float16:
weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
if weight_kwargs is not None:
@@ -285,9 +290,8 @@
if weight_quant_granularity == 'per_group':
weight_quant = weight_quant.let(**{'group_size': weight_group_size})
# weight scale is converted to a standalone parameter
# This is done already by default in the per_group quantizer
if weight_quant_granularity != 'per_group':
weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')

weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')
# weight zero-point is converted to a standalone parameter
# This is done already by default in the per_group quantizer
if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group':
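The new weight_group_dim argument rides on the same conditional .let() override pattern already used for scaling_min_val: it is applied only when the caller passes a value, so the layer-dependent default baked into the quantizer is otherwise preserved. A sketch of the pattern (assuming .let() returns a derived quantizer with the given attributes rebound):

import torch

def apply_weight_overrides(weight_quant, weight_group_dim=None, dtype=None):
    # Override group_dim only on explicit request; otherwise keep the
    # quantizer's layer-dependent default.
    if weight_group_dim is not None:
        weight_quant = weight_quant.let(group_dim=weight_group_dim)
    if dtype == torch.float16:
        weight_quant = weight_quant.let(scaling_min_val=1e-4)
    return weight_quant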
13 changes: 11 additions & 2 deletions src/brevitas_examples/llm/README.md
@@ -17,11 +17,12 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Curren
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
[--weight-bit-width WEIGHT_BIT_WIDTH]
[--weight-param-method {stats,mse}]
[--weight-param-method {stats,mse,hqo}]
[--weight-scale-precision {float_scale,po2_scale}]
[--weight-quant-type {sym,asym}]
[--weight-quant-format WEIGHT_QUANT_FORMAT]
[--weight-quant-granularity {per_channel,per_tensor,per_group}]
[--weight-group-dim {1,0}]
[--weight-group-size WEIGHT_GROUP_SIZE]
[--quantize-weight-zero-point]
[--input-bit-width INPUT_BIT_WIDTH]
@@ -38,6 +39,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--weight-equalization]
[--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
[--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
[--export-prefix EXPORT_PREFIX]
[--checkpoint-name CHECKPOINT_NAME]

options:
@@ -51,7 +53,7 @@
Dataset to use for quantization (default: wikitext2)
--weight-bit-width WEIGHT_BIT_WIDTH
Weight bit width. Default: 8.
--weight-param-method {stats,mse}
--weight-param-method {stats,mse,hqo}
How scales/zero-point are determined. Default: stats.
--weight-scale-precision {float_scale,po2_scale}
Whether scale is a float value or a po2. Default: po2.
@@ -65,6 +67,9 @@
--weight-quant-granularity {per_channel,per_tensor,per_group}
Granularity for scales/zero-point of weights. Default:
per_group.
--weight-group-dim {1,0}
Override the default group_dim for groupwise quantization.
Default: layer-dependent.
--weight-group-size WEIGHT_GROUP_SIZE
Group size for per_group weight quantization. Default:
128.
@@ -119,6 +124,10 @@ options:
--load-awq LOAD_AWQ Load the awq search results.
--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}
Model export.
--export-prefix EXPORT_PREFIX
Path prefix to use for the various export flows. If
None, a path will be derived from the model name
(default: None)
--checkpoint-name CHECKPOINT_NAME
Filename to save checkpoint. If `None`, no checkpoint
is saved (default: None)
7 changes: 7 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -226,6 +226,7 @@ def main(args):
weight_quant_type=args.weight_quant_type,
weight_quant_granularity=args.weight_quant_granularity,
weight_group_size=args.weight_group_size,
weight_group_dim=args.weight_group_dim,
quantize_weight_zero_point=args.quantize_weight_zero_point,
weight_quant_format=args.weight_quant_format,
input_bit_width=args.input_bit_width,
@@ -358,6 +359,12 @@ def parse_args(args):
default='per_group',
choices=['per_channel', 'per_tensor', 'per_group'],
help='Granularity for scales/zero-point of weights. Default: per_group.')
parser.add_argument(
'--weight-group-dim',
type=int,
default=None,
choices=[1, 0],
help='Override the default group_dim for groupwise quantization. Default: layer-dependent.')
parser.add_argument(
'--weight-group-size',
type=int,
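Together with the README entry above, the new flag can be exercised end to end. A hypothetical invocation (the model name and all values are placeholders, not taken from this PR):

from brevitas_examples.llm.main import main, parse_args

# Hypothetical run: per-group weight quantization with an explicit
# group_dim override; every value here is illustrative.
args = parse_args([
    '--model', 'facebook/opt-125m',
    '--weight-bit-width', '4',
    '--weight-quant-granularity', 'per_group',
    '--weight-group-size', '128',
    '--weight-group-dim', '0',
])
main(args)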