Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (export): qonnx minifloat export #1070

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/brevitas/export/onnx/qonnx/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,55 @@ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding
return y


class BrevitasFloatQuantFn(Function):
    """Autograd Function that exports Brevitas minifloat quantization as a
    QONNX ``FloatQuant`` custom op.

    During ONNX export, :meth:`symbolic` emits the custom node into the graph,
    while :meth:`forward` is an identity so tracing proceeds with the
    unquantized tensor values.
    """

    @staticmethod
    def symbolic(
            g,
            x,
            scale,
            exponent_bit_width,
            mantissa_bit_width,
            exponent_bias,
            has_inf,
            has_nan,
            saturating,
            has_subnormal,
            rounding_mode,
            max_val):
        # Tensor operands are passed positionally to the node; boolean flags are
        # encoded as integer attributes (`_i`) and the rounding mode as a string
        # attribute (`_s`), per ONNX attribute-naming conventions.
        ret = g.op(
            f'{DOMAIN_STRING}::FloatQuant',
            x,
            scale,
            exponent_bit_width,
            mantissa_bit_width,
            exponent_bias,
            max_val,
            has_inf_i=int(has_inf),
            has_nan_i=int(has_nan),
            has_subnormal_i=int(has_subnormal),
            rounding_mode_s=rounding_mode,
            saturation_i=saturating)
        # Propagate the input's type/shape information to the new node's output.
        ret.setType(x.type())
        return ret

    @staticmethod
    def forward(
            ctx,  # autograd context (was misleadingly named `g`, the symbolic-graph name)
            x,
            scale,
            exponent_bit_width,
            mantissa_bit_width,
            exponent_bias,
            has_inf,
            has_nan,
            saturating,
            has_subnormal,
            rounding_mode,
            max_val):
        # Identity passthrough: the quantization itself is realized by the
        # QONNX runtime when the exported FloatQuant node is executed.
        return x


class BrevitasTruncFn(Function):

@staticmethod
Expand Down
102 changes: 102 additions & 0 deletions src/brevitas/export/onnx/qonnx/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,116 @@
from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
from brevitas.proxy import WeightQuantProxyFromInjector
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector

from .function import BrevitasBinaryQuantFn
from .function import BrevitasFloatQuantFn
from .function import BrevitasQuantFn
from .function import BrevitasQuantLSTMCellFn
from .function import BrevitasTruncFn


class BrevitasFloatQuantProxyHandler(ONNXBaseHandler, ABC):
    """Base export handler for minifloat (float) quantization proxies.

    Collects the proxy's quantization metadata into ``symbolic_kwargs`` (the
    positional inputs/attributes of the QONNX ``FloatQuant`` op) and
    ``return_args`` (the tuple the proxy returns alongside the tensor).
    """

    def validate(self, module):
        assert not module.is_groupwise, "Export with Per Group quantization not supported"

    def prepare_for_export(self, module):
        # Nothing to record when quantization is disabled on this proxy.
        if not module.is_quant_enabled:
            return
        self.validate(module)
        # NOTE: dict insertion order matters — values are unpacked positionally
        # into BrevitasFloatQuantFn.apply in symbolic_execution.
        self.symbolic_kwargs = {
            'scale': module.scale(),
            'exponent_bit_width': module.exponent_bit_width(),
            'mantissa_bit_width': module.mantissa_bit_width(),
            'exponent_bias': module.exponent_bias(),
            'has_inf': module.inf_values() is not None,
            'has_nan': module.nan_values() is not None,
            'saturating': module.is_saturating(),
            'has_subnormal': True,  # Currently we only support subnormal
            'rounding_mode': module.rounding_mode,
            'max_float': torch.tensor(module.quant_injector.max_available_float).type_as(
                module.scale()),
        }
        self.return_args = {
            'scale': module.scale(),
            # Minifloat quant carries no zero-point; export an all-zero tensor.
            'zero_point': torch.zeros_like(module.scale()),
            'exponent_bit_width': module.exponent_bit_width(),
            'mantissa_bit_width': module.mantissa_bit_width(),
            'exponent_bias': module.exponent_bias(),
            'saturating': module.is_saturating(),
            'inf_values': module.inf_values(),
            'nan_values': module.nan_values(),
        }

    def symbolic_execution(self, x: Tensor):
        quantized = BrevitasFloatQuantFn.apply(x, *self.symbolic_kwargs.values())
        return (quantized, *self.return_args.values())


class BrevitasWeightFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler):
    """Export handler for minifloat (float) weight quantization proxies.

    Reads quantization metadata from the first tracked quantized weight and
    packages it for the QONNX ``FloatQuant`` symbolic function; the actual
    symbolic execution is inherited unchanged from the base handler.
    """
    handled_layer = WeightFloatQuantProxyFromInjector

    def __init__(self):
        super().__init__()
        # Placeholder for cached quantized weights; not populated here.
        self.quant_weights = None

    def validate(self, zero_point):
        # Minifloat export has no zero-point input, so only zero is representable.
        assert zero_point == 0, "Zero-point not supported for minifloat quant."

    # Fixed annotation: this handler targets the float weight proxy, not the
    # integer WeightQuantProxyFromInjector it was previously annotated with.
    def prepare_for_export(self, module: WeightFloatQuantProxyFromInjector):
        if module.is_quant_enabled:
            # All tracked modules share one quant configuration; use the first
            # tracked module's quantized weight as the reference.
            first_qweight = module.tracked_module_list[0].quant_weight()
            self.validate(first_qweight.zero_point)
            # NOTE: dict insertion order matters — values are unpacked
            # positionally into BrevitasFloatQuantFn.apply.
            self.symbolic_kwargs = {
                'scale': first_qweight.scale,
                'exponent_bit_width': first_qweight.exponent_bit_width,
                'mantissa_bit_width': first_qweight.mantissa_bit_width,
                'exponent_bias': first_qweight.exponent_bias,
                'has_inf': first_qweight.inf_values is not None,
                'has_nan': first_qweight.nan_values is not None,
                'saturating': first_qweight.saturating,
                'has_subnormal': True,  # Currently we only support subnormal
                'rounding_mode': module.rounding_mode,
                'max_float': torch.tensor(module.quant_injector.max_available_float).type_as(
                    first_qweight.scale),
            }
            self.return_args = {
                'scale': first_qweight.scale,
                'zero_point': torch.zeros_like(first_qweight.scale),
                'exponent_bit_width': first_qweight.exponent_bit_width,
                'mantissa_bit_width': first_qweight.mantissa_bit_width,
                'exponent_bias': first_qweight.exponent_bias,
                'saturating': first_qweight.saturating,
                'inf_values': first_qweight.inf_values,
                'nan_values': first_qweight.nan_values,
            }

    # The redundant symbolic_execution override (a pure `super()` delegation)
    # was removed; the inherited implementation is identical.


# Export handler for minifloat activation quantization proxies; reuses the base
# handler's prepare_for_export/symbolic_execution logic unchanged, only binding
# the proxy class it handles.
class BrevitasActFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler):
    handled_layer = ActFloatQuantProxyFromInjector


class BrevitasQuantProxyHandler(ONNXBaseHandler, ABC):

def validate(self, module):
Expand Down
6 changes: 5 additions & 1 deletion src/brevitas/export/onnx/qonnx/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from .function import BrevitasQuantFn
from .function import BrevitasQuantLSTMCellFn
from .function import BrevitasTruncFn
from .handler import BrevitasActFloatQuantProxyHandler
from .handler import BrevitasActQuantProxyHandler
from .handler import BrevitasBiasQuantProxyHandler
from .handler import BrevitasDecoupledWeightQuantProxyHandler
from .handler import BrevitasDecoupledWeightQuantWithInputProxyHandler
from .handler import BrevitasQuantLSTMLayerHandler
from .handler import BrevitasTruncQuantProxyHandler
from .handler import BrevitasWeightFloatQuantProxyHandler
from .handler import BrevitasWeightQuantProxyHandler


Expand All @@ -42,7 +44,9 @@ class QONNXManager(ONNXBaseManager):
BrevitasDecoupledWeightQuantProxyHandler,
BrevitasDecoupledWeightQuantWithInputProxyHandler,
BrevitasTruncQuantProxyHandler,
BrevitasQuantLSTMLayerHandler]
BrevitasQuantLSTMLayerHandler,
BrevitasWeightFloatQuantProxyHandler,
BrevitasActFloatQuantProxyHandler]

custom_fns = [
DebugMarkerFunction,
Expand Down
12 changes: 12 additions & 0 deletions tests/brevitas/export/test_onnx_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from brevitas import torch_version
from brevitas.export import export_onnx_qcdq
from brevitas.export import export_qonnx
import brevitas.nn as qnn
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
Expand All @@ -23,6 +24,17 @@ def test_simple_fp8_export():
assert True


@jit_disabled_for_export()
def test_qonnx_simple_fp8_export():
    """Smoke test: QONNX export of a QuantLinear with OCP FP8 weight/act quant."""
    # OCP FP8 dtypes require torch >= 2.1; skip on older versions.
    if torch_version < version.parse('2.1.0'):
        pytest.skip(f"OCP FP8 types not supported by {torch_version}")

    model = qnn.QuantLinear(
        3,
        16,
        weight_quant=Fp8e4m3OCPWeightPerTensorFloat,
        input_quant=Fp8e4m3OCPActPerTensorFloat)
    sample_input = torch.randn(1, 3)
    # Passes if export completes without raising.
    export_qonnx(model, sample_input, 'qonnx_act_weight_fp8.onnx')
    assert True


@jit_disabled_for_export()
def test_fp8_export_activation():
if torch_version < version.parse('2.1.0'):
Expand Down
Loading