From d3b4d5fb2b10e1c335fa081bc1c2011d03ef7b4a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 23 Sep 2024 17:13:58 +0100 Subject: [PATCH 01/12] Feat (calibrate/activation_calibration): speed-up by skipping quantization --- src/brevitas/core/quant/float.py | 18 +++++++++++------- src/brevitas/core/quant/int.py | 17 ++++++++++++++--- src/brevitas/function/ops.py | 16 ++++++++-------- src/brevitas/graph/calibrate.py | 11 ++++++----- 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index f4fd79f1a..20a513907 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -64,11 +64,10 @@ def __init__( if dtype is None: dtype = torch.get_default_dtype() self.eps = torch.finfo(dtype).tiny + self.observer_only = brevitas.jit.Attribute(False, bool) @brevitas.jit.script_method - def quantize(self, x: torch.Tensor): - scale = self.scaling_impl(x) - + def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor]: if self.float_scaling_impl is not None: float_scaling_impl_value = self.float_scaling_impl( self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) @@ -86,10 +85,15 @@ def dequantize(self, y, scale): @brevitas.jit.script_method def forward(self, x): - y, scale = self.quantize(x) - # after quantizing, clamp to special cases like NaN/inf if they are set - y, saturating, inf_values, nan_values = self.float_clamp_impl( - y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) + scale = self.scaling_impl(x) + if self.observer_only: + y = x + saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values + else: + y, scale = self.quantize(x, scale) + # after quantizing, clamp to special cases like NaN/inf if they are set + y, saturating, inf_values, nan_values = self.float_clamp_impl( + y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cdb75df74..e1cc271d8 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -145,6 +145,7 @@ def __init__( self.int_scaling_impl = int_scaling_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl + self.observer_only = brevitas.jit.Attribute(False, bool) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -153,7 +154,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: int_threshold = self.int_scaling_impl(bit_width) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) - y = self.int_quant(scale, zero_point, bit_width, x) + if self.observer_only: + y = x + else: + y = self.int_quant(scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width @@ -176,6 +180,7 @@ def __init__( self.pre_zero_point_impl = pre_zero_point_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl + self.observer_only = brevitas.jit.Attribute(False, bool) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -187,7 
+192,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te threshold = self.scaling_impl(x) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) - y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x) + if self.observer_only: + y = x + else: + y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width, pre_scale, pre_zero_point @@ -253,5 +261,8 @@ def forward(self, x: Tensor, input_bit_width: Tensor, threshold = self.scaling_impl(x) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) - y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x) + if self.observer_only: + y = x + else: + y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width, pre_scale, pre_zero_point diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 74da08e19..67b57df6a 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -189,16 +189,16 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: return value -@brevitas.jit.ignore +def max_mantissa_func(val): + return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.))) + + +MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} + + def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias - max_mantissa = torch.sum(( - 2. ** torch.arange( - 0, - -1. * mantissa_bit_width - 1., - -1., - dtype=mantissa_bit_width.dtype, - device=mantissa_bit_width.device))) + max_mantissa = MAX_MANTISSA_DICT[mantissa_bit_width.item()] max_val = max_mantissa * (2 ** max_exponent) return max_val diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 2b1f6833e..6335d6d45 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -201,8 +201,9 @@ def disable_act_quantization(self, model, is_training): if isinstance(module, ActQuantProxyFromInjectorBase): module.train(is_training) if self.call_act_quantizer_impl: - hook = module.register_forward_hook(self.disable_act_quant_hook) - self.disable_act_quant_hooks.append(hook) + for m in module.modules(): + if hasattr(m, 'observer_only'): + m.observer_only = True else: module.disable_quant = True elif isinstance(module, _ACC_PROXIES): @@ -229,9 +230,9 @@ def enable_act_quantization(self, model, is_training): elif isinstance(module, ActQuantProxyFromInjectorBase): module.disable_quant = False module.train(is_training) - for hook in self.disable_act_quant_hooks: - hook.remove() - self.disable_act_quant_hooks = [] + for m in module.modules(): + if hasattr(m, 'observer_only'): + m.observer_only = False def enable_param_quantization(self, model, is_training): for module in model.modules(): From 5e5d9e78121166ceded75ccc3cd8ad237c975cdb Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 23 Sep 2024 17:49:28 +0100 Subject: [PATCH 02/12] fix tests --- src/brevitas/core/stats/stats_op.py | 20 ++++++++++++++++++++ tests/brevitas/core/test_float_quant.py | 7 ++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 461aeb3e6..3cd6172d7 100644 --- a/src/brevitas/core/stats/stats_op.py +++ 
b/src/brevitas/core/stats/stats_op.py @@ -442,6 +442,19 @@ def _set_local_loss_mode(module, enabled): m.local_loss_mode = enabled +def _set_observer_mode(module, enabled, previous_observer_mode): + for m in module.modules(): + if hasattr(m, 'observer_only'): + previous_observer_mode[m] = m.observer_only + m.observer_only = enabled + + +def _restore_observer_mode(module, previous_observer_mode): + for m in module.modules(): + if hasattr(m, 'observer_only'): + m.observer_only = previous_observer_mode[m] + + class MSE(torch.nn.Module): # References: # https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py @@ -459,7 +472,12 @@ def __init__( self.mse_init_op = mse_init_op self.input_view_shape_impl = inner_stats_input_view_shape_impl self.proxy_forward = proxy_module.forward + self.previous_observer_mode = dict() self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled) + self.set_observer_mode = lambda enabled: _set_observer_mode( + proxy_module, enabled, self.previous_observer_mode) + self.restore_observer_mode = lambda: _restore_observer_mode( + proxy_module, self.previous_observer_mode) self.internal_candidate = None self.num = mse_iters self.search_method = mse_search_method @@ -480,10 +498,12 @@ def evaluate_loss(self, x, candidate): self.internal_candidate = candidate # Set to local_loss_mode before calling the proxy self.set_local_loss_mode(True) + self.set_observer_mode(False) quant_value = self.proxy_forward(x) quant_value = _unpack_quant_tensor(quant_value) loss = self.mse_loss_fn(x, quant_value) self.set_local_loss_mode(False) + self.restore_observer_mode() return loss def mse_grid_search(self, xl, x): diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 16b8a4b5f..52352c38b 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -98,8 +98,8 @@ def test_float_to_quant_float(inp, minifloat_format): signed=signed, float_clamp_impl=float_clamp) expected_out, *_ = float_quant(inp) - - out_quant, scale = float_quant.quantize(inp) + scale = float_quant.scaling_impl(inp) + out_quant = float_quant.quantize(inp, scale) exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float) out_quant, *_ = float_quant.float_clamp_impl( out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias) @@ -142,7 +142,8 @@ def test_scaling_impls_called_once(inp, minifloat_format): scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp) - _ = float_quant.quantize(inp) + scale = float_quant.scaling_impl(inp) + _ = float_quant.quantize(inp, scale) # scaling implementations should be called exaclty once on the input float_scaling_impl.assert_called_once_with( torch.tensor(exponent_bit_width), From 7ddf2f6d4cfdc94ca555931af7f76d4a796453be Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 23 Sep 2024 17:55:55 +0100 Subject: [PATCH 03/12] restore change --- src/brevitas/function/ops.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 67b57df6a..dddb60b83 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -193,12 +193,16 @@ def max_mantissa_func(val): return torch.sum((2. ** torch.arange(0, -1. 
* val - 1., -1.))) -MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} - - +@brevitas.jit.ignore def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias - max_mantissa = MAX_MANTISSA_DICT[mantissa_bit_width.item()] + max_mantissa = torch.sum(( + 2. ** torch.arange( + 0, + -1. * mantissa_bit_width - 1., + -1., + dtype=mantissa_bit_width.dtype, + device=mantissa_bit_width.device))) max_val = max_mantissa * (2 ** max_exponent) return max_val From 50e68938f1f0d2d710178785663b2e8c0c1bd05e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 23 Sep 2024 18:32:48 +0100 Subject: [PATCH 04/12] typing --- src/brevitas/core/quant/float.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 20a513907..260b11110 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -67,7 +67,7 @@ def __init__( self.observer_only = brevitas.jit.Attribute(False, bool) @brevitas.jit.script_method - def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor]: + def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.float_scaling_impl is not None: float_scaling_impl_value = self.float_scaling_impl( self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) From 5902061d89a4a73af1ac30c4c4adc723dedfe38a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 23 Sep 2024 18:35:51 +0100 Subject: [PATCH 05/12] fix --- tests/brevitas/core/test_float_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 52352c38b..552472717 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -99,7 +99,7 @@ def test_float_to_quant_float(inp, minifloat_format): float_clamp_impl=float_clamp) expected_out, *_ = float_quant(inp) scale = float_quant.scaling_impl(inp) - out_quant = float_quant.quantize(inp, scale) + out_quant, scale = float_quant.quantize(inp, scale) exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float) out_quant, *_ = float_quant.float_clamp_impl( out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias) From 5d5dfcec8e8b25e6af9875767072eb5aabc5c52e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 07:30:04 +0100 Subject: [PATCH 06/12] Fix --- src/brevitas/core/quant/float.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 260b11110..65f56a134 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -94,6 +94,6 @@ def forward(self, x): # after quantizing, clamp to special cases like NaN/inf if they are set y, saturating, inf_values, nan_values = self.float_clamp_impl( y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) - y = self.dequantize(y, scale) + y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values From 2b8d1f2f92abb26b4c807ac1333f43a10a2ba896 
Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 10:07:50 +0100 Subject: [PATCH 07/12] calibration with reference values --- tests/brevitas/graph/test_calibration.py | 42 ++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 86ef58b77..1da67ff3a 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -4,6 +4,7 @@ import math from hypothesis import given +import pytest_cases from pytest_cases import fixture import torch import torch.nn as nn @@ -13,6 +14,7 @@ from brevitas.graph.calibrate import load_quant_model_mode import brevitas.nn as qnn from brevitas.quant import Int8ActPerTensorFixedPoint +from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloat # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations from brevitas.utils.torch_utils import kthvalue @@ -21,6 +23,10 @@ IN_CH = 8 OUT_CH = 16 BATCH = 1 +REFERENCE_SCALES = { + 'int_quant': (0.00935234408825635910, 0.00859776325523853302), + 'fp_quant': (0.00249395845457911491, 0.00190271728206425905)} +REFERNECE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) def compute_quantile(x, q): @@ -65,6 +71,42 @@ def forward(self, x): assert torch.allclose(expected_scale, scale) +QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat} + + +@pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys()) +def test_scale_factors_ptq_calibration_reference(act_quant): + + reference, act_quant = act_quant + + class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.act = qnn.QuantReLU(act_quant=act_quant) + self.linear = qnn.QuantLinear(3, 8) + self.act_1 = qnn.QuantIdentity(act_quant=act_quant) + + def forward(self, x): + o = self.act(x) + o = self.linear(o) + return self.act_1(o) + + # Reference input + inp = REFERNECE_INP + model = TestModel() + model.eval() + with torch.no_grad(): + with calibration_mode(model): + model(inp) + + computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale() + reference_values = REFERENCE_SCALES[reference] + assert all([ + torch.allclose(comp, torch.tensor(ref)) for comp, + ref in zip(computed_scale, reference_values)]) + + def test_calibration_training_state(): class TestModel(nn.Module): From 347a381fcbc5031981525406b43597313acdd2c0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 10:55:00 +0100 Subject: [PATCH 08/12] Fast --- src/brevitas/core/function_wrapper/clamp.py | 11 +++++++---- src/brevitas/core/quant/float.py | 1 + src/brevitas/core/scaling/float_scaling.py | 2 +- src/brevitas/export/inference/handler.py | 6 ++---- src/brevitas/function/ops.py | 14 ++------------ src/brevitas/utils/quant_utils.py | 9 ++++++++- tests/brevitas/core/test_clamp.py | 5 +++-- tests/brevitas/core/test_float_quant.py | 3 ++- tests/brevitas/graph/test_calibration.py | 3 ++- 9 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 163e63a22..b77164d90 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -14,6 +14,7 @@ from brevitas.core.utils import StatelessBuffer from brevitas.function import tensor_clamp from brevitas.function.ops import max_float +from brevitas.utils.quant_utils import 
MAX_MANTISSA_DICT class TensorClamp(brevitas.jit.ScriptModule): @@ -106,6 +107,7 @@ def __init__( self.inf_values = inf_values self.nan_values = nan_values self.signed = signed + self.max_mantissa_dict = MAX_MANTISSA_DICT if max_available_float: max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype) @@ -144,15 +146,16 @@ def forward( mantissa_bit_width: Tensor, exponent_bias: Tensor): - max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias) + max_value = max_float(exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) min_value = torch.tensor(0.) if not self.signed else -max_value # Compute masks - inf_mask = x.isinf() - p_max_val_mask = x > max_value - n_max_val_mask = -x > max_value + if not self.saturating: + inf_mask = x.isinf() + p_max_val_mask = x > max_value + n_max_val_mask = -x > max_value # first clamp everything to +- max_value, basically the saturating case x = self.saturating_clamp(x, max_value, min_value) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 65f56a134..049a52c9e 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -11,6 +11,7 @@ from brevitas.core.scaling import ConstScaling from brevitas.core.utils import StatelessBuffer from brevitas.utils.torch_utils import float_internal_scale +import time class FloatQuant(brevitas.jit.ScriptModule): diff --git a/src/brevitas/core/scaling/float_scaling.py b/src/brevitas/core/scaling/float_scaling.py index e082589a0..7cf99d73d 100644 --- a/src/brevitas/core/scaling/float_scaling.py +++ b/src/brevitas/core/scaling/float_scaling.py @@ -36,7 +36,7 @@ def __init__( def forward( self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor) -> Tensor: - max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias) + max_value = max_float(exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) return max_value diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 1416014ec..cb9347a09 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -16,6 +16,7 @@ from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.utils.quant_utils import MAX_MANTISSA_DICT from brevitas.utils.torch_utils import float_internal_scale @@ -101,12 +102,9 @@ def prepare_for_export(self, module): self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl - self.max_clamp = max_float( - self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) - self.min_clamp = -self.max_clamp self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width self.max_value = max_float( - self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias) + self.exponent_bit_width, MAX_MANTISSA_DICT[self.mantissa_bit_width.item()], self.exponent_bias) self.min_value = torch.tensor(0.) 
if not module.is_signed else -self.max_value def quantize(self, x): diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index dddb60b83..f2e94e1cc 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -10,6 +10,7 @@ from torch import Tensor import brevitas +from brevitas.utils.quant_utils import MAX_MANTISSA_DICT @brevitas.jit.script @@ -189,20 +190,9 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: return value -def max_mantissa_func(val): - return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.))) - -@brevitas.jit.ignore -def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): +def max_float(exponent_bit_width: Tensor, max_mantissa: Tensor, exponent_bias: Tensor): max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias - max_mantissa = torch.sum(( - 2. ** torch.arange( - 0, - -1. * mantissa_bit_width - 1., - -1., - dtype=mantissa_bit_width.dtype, - device=mantissa_bit_width.device))) max_val = max_mantissa * (2 ** max_exponent) return max_val diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 6fd519b41..c6810dfbf 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -9,7 +9,7 @@ from brevitas.quant_tensor import GroupwiseFloatQuantTensor from brevitas.quant_tensor import GroupwiseIntQuantTensor from brevitas.quant_tensor import IntQuantTensor - +import torch class _CachedIO: @@ -221,3 +221,10 @@ def float_to_int_impl_to_enum(module): return FloatToIntImplType.STOCHASTIC_ROUND else: return None + + + +def max_mantissa_func(val): + return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.))) + +MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} \ No newline at end of file diff --git a/tests/brevitas/core/test_clamp.py b/tests/brevitas/core/test_clamp.py index 96a999494..8d4a6c117 100644 --- a/tests/brevitas/core/test_clamp.py +++ b/tests/brevitas/core/test_clamp.py @@ -14,6 +14,7 @@ from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight from brevitas.utils.float_quant_utils import get_max_available_float from brevitas.utils.float_quant_utils import get_min_available_float +from brevitas.utils.quant_utils import MAX_MANTISSA_DICT from tests.brevitas.hyp_helper import float_tensor_random_shape_st from .minifloat_fixtures import * @@ -51,7 +52,7 @@ def test_max_value(minifloat, expected_max_val): max_val = max_float( torch.tensor(minifloat.exponent_bit_width, dtype=torch.float32), - torch.tensor(minifloat.mantissa_bit_width, dtype=torch.float32), + MAX_MANTISSA_DICT[minifloat.mantissa_bit_width], torch.tensor(minifloat.exponent_bias, dtype=torch.float32)) max_available_float = get_max_available_float( minifloat.exponent_bit_width, @@ -84,7 +85,7 @@ def test_float_clamp(inp, fp8_clamp): max_val = max_float( torch.tensor(fp8_clamp.exponent_bit_width, dtype=torch.float32), - torch.tensor(fp8_clamp.mantissa_bit_width, dtype=torch.float32), + MAX_MANTISSA_DICT[fp8_clamp.mantissa_bit_width], torch.tensor(fp8_clamp.exponent_bias, dtype=torch.float32)) max_available_float = get_max_available_float( fp8_clamp.exponent_bit_width, diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 552472717..b088fd036 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -14,6 +14,7 @@ from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling from 
brevitas.function.ops import max_float +from brevitas.utils.quant_utils import MAX_MANTISSA_DICT from brevitas.utils.torch_utils import float_internal_scale from tests.brevitas.hyp_helper import float_st from tests.brevitas.hyp_helper import float_tensor_random_shape_st @@ -197,7 +198,7 @@ def test_inner_scale(inp, minifloat_format, scale): scaled_inp = inp / scale max_val = max_float( torch.tensor(exponent_bit_width), - torch.tensor(mantissa_bit_width), + MAX_MANTISSA_DICT[mantissa_bit_width], torch.tensor(exponent_bias)) max_available_float = float_clamp.max_available_float max_value = max_val if max_available_float is None else torch.min( diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 1da67ff3a..df1760eb5 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -15,6 +15,7 @@ import brevitas.nn as qnn from brevitas.quant import Int8ActPerTensorFixedPoint from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloat # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations from brevitas.utils.torch_utils import kthvalue @@ -71,7 +72,7 @@ def forward(self, x): assert torch.allclose(expected_scale, scale) -QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat} +QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3OCPActPerTensorFloat} @pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys()) From 33257923f00a8b5019541a2b9a44a59d13fc8e85 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 10:58:24 +0100 Subject: [PATCH 09/12] missing fix --- src/brevitas/core/scaling/float_scaling.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/scaling/float_scaling.py b/src/brevitas/core/scaling/float_scaling.py index 7cf99d73d..86451523b 100644 --- a/src/brevitas/core/scaling/float_scaling.py +++ b/src/brevitas/core/scaling/float_scaling.py @@ -9,6 +9,7 @@ import brevitas from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_float +from brevitas.utils.quant_utils import MAX_MANTISSA_DICT class FloatScaling(brevitas.jit.ScriptModule): @@ -25,6 +26,7 @@ def __init__( self.inf_values = inf_values self.nan_values = nan_values self.saturating = saturating + self.max_mantissa_dict = MAX_MANTISSA_DICT if max_available_float: max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype) @@ -36,7 +38,8 @@ def __init__( def forward( self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor) -> Tensor: - max_value = max_float(exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias) + max_value = max_float( + exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) return max_value From a2ef7bc7b0394850ae167d4af2982133f87b0beb Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 11:08:11 +0100 Subject: [PATCH 10/12] fix --- src/brevitas/core/function_wrapper/clamp.py | 3 ++- src/brevitas/core/quant/float.py | 2 +- src/brevitas/export/inference/handler.py | 4 +++- src/brevitas/function/ops.py | 1 - src/brevitas/utils/quant_utils.py | 7 ++++--- 5 files 
changed, 10 insertions(+), 7 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index b77164d90..1d6391d21 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -146,7 +146,8 @@ def forward( mantissa_bit_width: Tensor, exponent_bias: Tensor): - max_value = max_float(exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias) + max_value = max_float( + exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias) max_value = max_value if self.max_available_float is None else torch.min( max_value, self.max_available_float()) min_value = torch.tensor(0.) if not self.signed else -max_value diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 049a52c9e..50ae15b80 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import time from typing import Optional, Tuple import torch @@ -11,7 +12,6 @@ from brevitas.core.scaling import ConstScaling from brevitas.core.utils import StatelessBuffer from brevitas.utils.torch_utils import float_internal_scale -import time class FloatQuant(brevitas.jit.ScriptModule): diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index cb9347a09..5945d0a8a 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -104,7 +104,9 @@ def prepare_for_export(self, module): self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width self.max_value = max_float( - self.exponent_bit_width, MAX_MANTISSA_DICT[self.mantissa_bit_width.item()], self.exponent_bias) + self.exponent_bit_width, + MAX_MANTISSA_DICT[self.mantissa_bit_width.item()], + self.exponent_bias) self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value def quantize(self, x): diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index f2e94e1cc..2aa4c609d 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -190,7 +190,6 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: return value - def max_float(exponent_bit_width: Tensor, max_mantissa: Tensor, exponent_bias: Tensor): max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias max_val = max_mantissa * (2 ** max_exponent) diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index c6810dfbf..be2847b32 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -9,7 +9,7 @@ from brevitas.quant_tensor import GroupwiseFloatQuantTensor from brevitas.quant_tensor import GroupwiseIntQuantTensor from brevitas.quant_tensor import IntQuantTensor -import torch + class _CachedIO: @@ -223,8 +223,9 @@ def float_to_int_impl_to_enum(module): return None - def max_mantissa_func(val): + import torch return torch.sum((2. ** torch.arange(0, -1. 
* val - 1., -1.))) -MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} \ No newline at end of file + +MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} From 7992becf1b3b07323efce5776e64a450fe5b3391 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 11:13:00 +0100 Subject: [PATCH 11/12] removed unused import --- src/brevitas/function/ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 2aa4c609d..1e814d3d8 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -10,7 +10,6 @@ from torch import Tensor import brevitas -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT @brevitas.jit.script From 9bb9d5f9b120c4d62b5a510367268d1810bcf464 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 11:20:54 +0100 Subject: [PATCH 12/12] fix import --- src/brevitas/core/function_wrapper/clamp.py | 2 +- src/brevitas/core/scaling/float_scaling.py | 2 +- src/brevitas/export/inference/handler.py | 2 +- src/brevitas/utils/quant_utils.py | 8 -------- src/brevitas/utils/torch_utils.py | 8 ++++++++ tests/brevitas/core/test_clamp.py | 2 +- tests/brevitas/core/test_float_quant.py | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index 1d6391d21..7cd9f047f 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -14,7 +14,7 @@ from brevitas.core.utils import StatelessBuffer from brevitas.function import tensor_clamp from brevitas.function.ops import max_float -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT class TensorClamp(brevitas.jit.ScriptModule): diff --git a/src/brevitas/core/scaling/float_scaling.py b/src/brevitas/core/scaling/float_scaling.py index 86451523b..fc1721d91 100644 --- a/src/brevitas/core/scaling/float_scaling.py +++ b/src/brevitas/core/scaling/float_scaling.py @@ -9,7 +9,7 @@ import brevitas from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_float -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT class FloatScaling(brevitas.jit.ScriptModule): diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 5945d0a8a..bafeb67ef 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -16,8 +16,8 @@ from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT from brevitas.utils.torch_utils import float_internal_scale +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT class InferenceHandler(torch.nn.Module, ABC): diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index be2847b32..6fd519b41 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -221,11 +221,3 @@ def float_to_int_impl_to_enum(module): return FloatToIntImplType.STOCHASTIC_ROUND else: return None - - -def max_mantissa_func(val): - import torch - return torch.sum((2. ** torch.arange(0, -1. 
* val - 1., -1.))) - - -MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 2f0d34fba..ea4be5047 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -113,3 +113,11 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]: padding[2 * group_dim] = group_size - size[group_dim] % group_size padding = list(reversed(padding)) return padding + + +def max_mantissa_func(val): + import torch + return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.))) + + +MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} diff --git a/tests/brevitas/core/test_clamp.py b/tests/brevitas/core/test_clamp.py index 8d4a6c117..e5430e140 100644 --- a/tests/brevitas/core/test_clamp.py +++ b/tests/brevitas/core/test_clamp.py @@ -14,7 +14,7 @@ from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight from brevitas.utils.float_quant_utils import get_max_available_float from brevitas.utils.float_quant_utils import get_min_available_float -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT from tests.brevitas.hyp_helper import float_tensor_random_shape_st from .minifloat_fixtures import * diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index b088fd036..a471f7bbf 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -14,8 +14,8 @@ from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling from brevitas.function.ops import max_float -from brevitas.utils.quant_utils import MAX_MANTISSA_DICT from brevitas.utils.torch_utils import float_internal_scale +from brevitas.utils.torch_utils import MAX_MANTISSA_DICT from tests.brevitas.hyp_helper import float_st from tests.brevitas.hyp_helper import float_tensor_random_shape_st from tests.brevitas.hyp_helper import random_minifloat_format
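
Notes on the techniques above: three self-contained Python sketches follow. None of them is Brevitas API; they restate the patterns this series introduces under simplified, clearly labeled assumptions.

The core idea of PATCH 01: during calibration a quantizer only needs to behave as an observer that keeps its scale statistics up to date, so the quantize/dequantize round-trip can be skipped while still returning a usable output and a valid scale. A minimal sketch, assuming a toy absmax observer (`ToyIntQuant` and its `running_absmax` buffer are hypothetical names, not Brevitas classes):

import torch
import torch.nn as nn


class ToyIntQuant(nn.Module):
    # Hypothetical stand-in for RescalingIntQuant: the absmax observer always
    # runs, and an `observer_only` flag, when True, skips the fake-quantization
    # round-trip and returns x unchanged, mirroring the
    # `if self.observer_only: y = x` branches added throughout this series.

    def __init__(self, bit_width: int = 8):
        super().__init__()
        self.bit_width = bit_width
        self.observer_only = False
        self.register_buffer('running_absmax', torch.tensor(0.))

    def forward(self, x: torch.Tensor):
        # Observer step: statistics (and hence the scale) are updated on
        # every forward, so calibration still sees all of the data.
        self.running_absmax = torch.maximum(self.running_absmax,
                                            x.detach().abs().max())
        int_threshold = 2. ** (self.bit_width - 1) - 1.
        scale = (self.running_absmax / int_threshold).clamp_min(
            torch.finfo(x.dtype).tiny)
        if self.observer_only:
            y = x  # calibration fast path: no rounding, no clamping
        else:
            y = torch.clamp(torch.round(x / scale),
                            -int_threshold - 1., int_threshold) * scale
        return y, scale


quant = ToyIntQuant()
quant.observer_only = True
with torch.no_grad():
    _ = quant(torch.randn(4, 16))      # cheap calibration pass, stats only
quant.observer_only = False
y, scale = quant(torch.randn(4, 16))   # real fake-quantization

Flipping this flag per module is what lets disable_act_quantization in calibrate.py stop registering and removing forward hooks.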
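
The max_float speed-up (introduced in PATCH 03 and finalized in PATCH 08 onward) rests on the fact that the largest mantissa value for m mantissa bits is the geometric sum 1 + 2**-1 + ... + 2**-m = 2 - 2**-m, which depends only on a small integer bit width. It can therefore be precomputed once per width instead of materializing a torch.arange on every forward call, which also lets the series drop the @brevitas.jit.ignore decorator from the slimmed-down max_float. A worked check using the same functions the patch introduces; note the formula gives the largest magnitude before special-encoding clamps (OCP e4m3, for instance, is further capped at 448 via max_available_float because the all-ones encoding is reserved for NaN):

import torch


def max_mantissa_func(val):
    # 1 + 1/2 + ... + 2**-val; closed form: 2 - 2**-val
    return torch.sum(2. ** torch.arange(0, -1. * val - 1., -1.))


MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)}


def max_float(exponent_bit_width, max_mantissa, exponent_bias):
    # Largest representable magnitude, ignoring reserved special encodings.
    max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
    return max_mantissa * (2 ** max_exponent)


assert torch.isclose(MAX_MANTISSA_DICT[3], torch.tensor(2. - 2. ** -3))
# e4m3 with bias 7: (2 - 2**-3) * 2**8 = 480 before the NaN-encoding clamp.
mx = max_float(torch.tensor(4.), MAX_MANTISSA_DICT[3], torch.tensor(7.))
assert torch.isclose(mx, torch.tensor(480.))

One trade-off of the dictionary: it only covers mantissa widths 0 through 15, so a wider format would raise a KeyError; for minifloat formats that range is more than enough.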
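
PATCH 02 threads a subtlety through the MSE-based scale search: evaluate_loss calls the proxy forward and needs genuinely quantized outputs, so observer mode must be forced off for the duration of the loss evaluation and then restored to whatever per-module state was active beforehand, since different submodules may hold different states. A context manager is a compact way to express the same save/restore discipline; observer_mode below is a hypothetical helper, equivalent in spirit to the _set_observer_mode/_restore_observer_mode pair the patch adds to stats_op.py, not an existing Brevitas utility:

from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def observer_mode(module: nn.Module, enabled: bool):
    # Force `observer_only` on every submodule that defines it, remembering
    # each module's previous value so heterogeneous states are restored.
    previous = {}
    for m in module.modules():
        if hasattr(m, 'observer_only'):
            previous[m] = m.observer_only
            m.observer_only = enabled
    try:
        yield module
    finally:
        for m, state in previous.items():
            m.observer_only = state


# Inside an MSE-style search this guarantees the loss is computed on real
# quantized values even while the surrounding calibration pass keeps the
# observer-only fast path enabled:
#
#     with observer_mode(proxy_module, False):
#         quant_value = proxy_module(x)
#         loss = mse_loss_fn(x, quant_value)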