diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py
index 163e63a22..7cd9f047f 100644
--- a/src/brevitas/core/function_wrapper/clamp.py
+++ b/src/brevitas/core/function_wrapper/clamp.py
@@ -14,6 +14,7 @@
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function import tensor_clamp
 from brevitas.function.ops import max_float
+from brevitas.utils.torch_utils import MAX_MANTISSA_DICT
 
 
 class TensorClamp(brevitas.jit.ScriptModule):
@@ -106,6 +107,7 @@ def __init__(
         self.inf_values = inf_values
         self.nan_values = nan_values
         self.signed = signed
+        self.max_mantissa_dict = MAX_MANTISSA_DICT
 
         if max_available_float:
             max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
@@ -144,15 +146,17 @@ def forward(
             mantissa_bit_width: Tensor,
             exponent_bias: Tensor):
 
-        max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
+        max_value = max_float(
+            exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias)
         max_value = max_value if self.max_available_float is None else torch.min(
             max_value, self.max_available_float())
         min_value = torch.tensor(0.) if not self.signed else -max_value
 
         # Compute masks
-        inf_mask = x.isinf()
-        p_max_val_mask = x > max_value
-        n_max_val_mask = -x > max_value
+        if not self.saturating:
+            inf_mask = x.isinf()
+            p_max_val_mask = x > max_value
+            n_max_val_mask = -x > max_value
 
         # first clamp everything to +- max_value, basically the saturating case
         x = self.saturating_clamp(x, max_value, min_value)
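In `FloatClamp.forward` above, the inf/NaN masks are now only computed on the non-saturating path, since saturating mode clamps everything to +-`max_value` and never consumes them. A toy illustration of the two behaviours (plain tensors, not the Brevitas API):

```python
import torch

# Saturating: out-of-range values pin to +-max_value; the masks are unused.
# Non-saturating: overflow must be tracked so it can map to inf/NaN later.
x = torch.tensor([0.5, 3.0, float('inf')])
max_value = torch.tensor(2.0)

saturating = torch.clamp(x, -max_value, max_value)       # tensor([0.5, 2.0, 2.0])
p_max_val_mask = x > max_value                           # the mask the patch gates
non_saturating = torch.where(p_max_val_mask, torch.tensor(float('inf')), x)
print(saturating, non_saturating)                        # [0.5, 2., 2.] vs [0.5, inf, inf]
```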
diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index f4fd79f1a..50ae15b80 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import time
 from typing import Optional, Tuple
 
 import torch
@@ -64,11 +65,10 @@ def __init__(
         if dtype is None:
             dtype = torch.get_default_dtype()
         self.eps = torch.finfo(dtype).tiny
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
-    def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x)
-
+    def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.float_scaling_impl is not None:
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
@@ -86,10 +86,15 @@ def dequantize(self, y, scale):
 
     @brevitas.jit.script_method
     def forward(self, x):
-        y, scale = self.quantize(x)
-        # after quantizing, clamp to special cases like NaN/inf if they are set
-        y, saturating, inf_values, nan_values = self.float_clamp_impl(
-            y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-        y = self.dequantize(y, scale)
+        scale = self.scaling_impl(x)
+        if self.observer_only:
+            y = x
+            saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
+        else:
+            y, scale = self.quantize(x, scale)
+            # after quantizing, clamp to special cases like NaN/inf if they are set
+            y, saturating, inf_values, nan_values = self.float_clamp_impl(
+                y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+            y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
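`FloatQuant.forward` now always runs the scaling implementation (so observers keep collecting statistics) but skips the quantize/clamp/dequantize round trip when `observer_only` is set. A self-contained reduction of that control flow, with a hypothetical `ToyQuantizer` standing in for the Brevitas module:

```python
import torch

class ToyQuantizer(torch.nn.Module):
    """Toy int8-style fake-quantizer; only the observer_only flow mirrors the patch."""

    def __init__(self):
        super().__init__()
        self.observer_only = False
        self.seen_max = torch.tensor(1e-6)  # stand-in for a scale observer

    def forward(self, x):
        # Statistics always update, analogous to scale = self.scaling_impl(x) above...
        self.seen_max = torch.maximum(self.seen_max, x.abs().max())
        scale = self.seen_max / 127.
        if self.observer_only:
            return x, scale  # ...but calibration skips the fake-quant round trip
        y = torch.round(x / scale).clamp(-128, 127) * scale
        return y, scale

q = ToyQuantizer()
q.observer_only = True
_ = q(torch.randn(4))         # observe only
q.observer_only = False
y, scale = q(torch.randn(4))  # quantize with the calibrated scale
```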
diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py
index cdb75df74..e1cc271d8 100644
--- a/src/brevitas/core/quant/int.py
+++ b/src/brevitas/core/quant/int.py
@@ -145,6 +145,7 @@ def __init__(
         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -153,7 +154,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         int_threshold = self.int_scaling_impl(bit_width)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.int_quant(scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.int_quant(scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width
 
 
@@ -176,6 +180,7 @@ def __init__(
         self.pre_zero_point_impl = pre_zero_point_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
@@ -187,7 +192,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
 
 
@@ -253,5 +261,8 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
diff --git a/src/brevitas/core/scaling/float_scaling.py b/src/brevitas/core/scaling/float_scaling.py
index e082589a0..fc1721d91 100644
--- a/src/brevitas/core/scaling/float_scaling.py
+++ b/src/brevitas/core/scaling/float_scaling.py
@@ -9,6 +9,7 @@
 import brevitas
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops import max_float
+from brevitas.utils.torch_utils import MAX_MANTISSA_DICT
 
 
 class FloatScaling(brevitas.jit.ScriptModule):
@@ -25,6 +26,7 @@ def __init__(
         self.inf_values = inf_values
         self.nan_values = nan_values
         self.saturating = saturating
+        self.max_mantissa_dict = MAX_MANTISSA_DICT
 
         if max_available_float:
             max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
@@ -36,7 +38,8 @@
     def forward(
             self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor,
             exponent_bias: Tensor) -> Tensor:
-        max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
+        max_value = max_float(
+            exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias)
         max_value = max_value if self.max_available_float is None else torch.min(
             max_value, self.max_available_float())
         return max_value
diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py
index 461aeb3e6..3cd6172d7 100644
--- a/src/brevitas/core/stats/stats_op.py
+++ b/src/brevitas/core/stats/stats_op.py
@@ -442,6 +442,19 @@ def _set_local_loss_mode(module, enabled):
             m.local_loss_mode = enabled
 
 
+def _set_observer_mode(module, enabled, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            previous_observer_mode[m] = m.observer_only
+            m.observer_only = enabled
+
+
+def _restore_observer_mode(module, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            m.observer_only = previous_observer_mode[m]
+
+
 class MSE(torch.nn.Module):
     # References:
     # https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py
@@ -459,7 +472,12 @@ def __init__(
         self.mse_init_op = mse_init_op
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.num = mse_iters
         self.search_method = mse_search_method
@@ -480,10 +498,12 @@ def evaluate_loss(self, x, candidate):
         self.internal_candidate = candidate
         # Set to local_loss_mode before calling the proxy
         self.set_local_loss_mode(True)
+        self.set_observer_mode(False)
         quant_value = self.proxy_forward(x)
         quant_value = _unpack_quant_tensor(quant_value)
         loss = self.mse_loss_fn(x, quant_value)
         self.set_local_loss_mode(False)
+        self.restore_observer_mode()
         return loss
 
     def mse_grid_search(self, xl, x):
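The MSE estimator needs genuinely quantized outputs even while the surrounding calibration pass has every quantizer in observer-only mode, hence the save/flip/restore bookkeeping around `proxy_forward`. The pattern in isolation (toy `Flagged` module; names are illustrative):

```python
import torch.nn as nn

class Flagged(nn.Module):

    def __init__(self):
        super().__init__()
        self.observer_only = True  # e.g. set by the calibration pass

model = nn.Sequential(Flagged(), Flagged())

# Save and flip, as _set_observer_mode does before evaluate_loss...
previous = {}
for m in model.modules():
    if hasattr(m, 'observer_only'):
        previous[m] = m.observer_only
        m.observer_only = False  # loss evaluation needs quantized values

# ... proxy_forward(x) and mse_loss_fn(x, quant_value) would run here ...

# ...then restore each flag exactly as found, as _restore_observer_mode does.
for m in model.modules():
    if hasattr(m, 'observer_only'):
        m.observer_only = previous[m]
```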
diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py
index 1416014ec..bafeb67ef 100644
--- a/src/brevitas/export/inference/handler.py
+++ b/src/brevitas/export/inference/handler.py
@@ -17,6 +17,7 @@
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
 from brevitas.utils.torch_utils import float_internal_scale
+from brevitas.utils.torch_utils import MAX_MANTISSA_DICT
 
 
 class InferenceHandler(torch.nn.Module, ABC):
@@ -101,12 +102,11 @@ def prepare_for_export(self, module):
            self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.float_to_int_impl
            self.float_clamp_impl = module.fused_activation_quant_proxy.tensor_quant.float_clamp_impl
 
-            self.max_clamp = max_float(
-                self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
-            self.min_clamp = -self.max_clamp
             self.fp_internal_scale_min = 1. - self.exponent_bias - self.mantissa_bit_width
             self.max_value = max_float(
-                self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
+                self.exponent_bit_width,
+                MAX_MANTISSA_DICT[self.mantissa_bit_width.item()],
+                self.exponent_bias)
             self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value
 
     def quantize(self, x):
diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py
index 74da08e19..1e814d3d8 100644
--- a/src/brevitas/function/ops.py
+++ b/src/brevitas/function/ops.py
@@ -189,16 +189,8 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
     return value
 
 
-@brevitas.jit.ignore
-def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
+def max_float(exponent_bit_width: Tensor, max_mantissa: Tensor, exponent_bias: Tensor):
     max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
-    max_mantissa = torch.sum((
-        2. ** torch.arange(
-            0,
-            -1. * mantissa_bit_width - 1.,
-            -1.,
-            dtype=mantissa_bit_width.dtype,
-            device=mantissa_bit_width.device)))
     max_val = max_mantissa * (2 ** max_exponent)
     return max_val
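With the mantissa sum hoisted out, `max_float` reduces to `max_mantissa * 2 ** max_exponent`. A quick sanity check, assuming OCP FP8 E4M3 parameters (4 exponent bits, 3 mantissa bits, bias 7):

```python
import torch

# MAX_MANTISSA_DICT[3]: 1 + 1/2 + 1/4 + 1/8 = 1.875
max_mantissa = torch.sum(2. ** torch.arange(0, -3. - 1., -1.))
max_exponent = (2. ** 4) - 1. - 7.                # = 8.0
print(max_mantissa * 2 ** max_exponent)           # tensor(480.)
# OCP E4M3 tops out at 448 because its highest code is reserved for NaN;
# that limit is applied separately through max_available_float.
```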
diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 2b1f6833e..6335d6d45 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -201,8 +201,9 @@ def disable_act_quantization(self, model, is_training):
             if isinstance(module, ActQuantProxyFromInjectorBase):
                 module.train(is_training)
                 if self.call_act_quantizer_impl:
-                    hook = module.register_forward_hook(self.disable_act_quant_hook)
-                    self.disable_act_quant_hooks.append(hook)
+                    for m in module.modules():
+                        if hasattr(m, 'observer_only'):
+                            m.observer_only = True
                 else:
                     module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
@@ -229,9 +230,9 @@ def enable_act_quantization(self, model, is_training):
             elif isinstance(module, ActQuantProxyFromInjectorBase):
                 module.disable_quant = False
                 module.train(is_training)
-                for hook in self.disable_act_quant_hooks:
-                    hook.remove()
-                self.disable_act_quant_hooks = []
+                for m in module.modules():
+                    if hasattr(m, 'observer_only'):
+                        m.observer_only = False
 
     def enable_param_quantization(self, model, is_training):
         for module in model.modules():
diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py
index 2f0d34fba..ea4be5047 100644
--- a/src/brevitas/utils/torch_utils.py
+++ b/src/brevitas/utils/torch_utils.py
@@ -113,3 +113,11 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
         padding[2 * group_dim] = group_size - size[group_dim] % group_size
     padding = list(reversed(padding))
     return padding
+
+
+def max_mantissa_func(val):
+    import torch
+    return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.)))
+
+
+MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)}
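`MAX_MANTISSA_DICT` builds the table once at import time for mantissa widths 0 through 15; callers index it with `mantissa_bit_width.item()` since a plain Python int is needed as the key. Re-deriving the table with the patch's own formula:

```python
import torch

def max_mantissa_func(val):
    # Sum of 2**0 + 2**-1 + ... + 2**-val, i.e. an all-ones mantissa.
    return torch.sum(2. ** torch.arange(0, -1. * val - 1., -1.))

MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)}
print(MAX_MANTISSA_DICT[2])  # tensor(1.7500), the E5M2 mantissa term
print(MAX_MANTISSA_DICT[3])  # tensor(1.8750), the E4M3 mantissa term
```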
diff --git a/tests/brevitas/core/test_clamp.py b/tests/brevitas/core/test_clamp.py
index 96a999494..e5430e140 100644
--- a/tests/brevitas/core/test_clamp.py
+++ b/tests/brevitas/core/test_clamp.py
@@ -14,6 +14,7 @@
 from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeight
 from brevitas.utils.float_quant_utils import get_max_available_float
 from brevitas.utils.float_quant_utils import get_min_available_float
+from brevitas.utils.torch_utils import MAX_MANTISSA_DICT
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 
 from .minifloat_fixtures import *
@@ -51,7 +52,7 @@ def test_max_value(minifloat, expected_max_val):
     max_val = max_float(
         torch.tensor(minifloat.exponent_bit_width, dtype=torch.float32),
-        torch.tensor(minifloat.mantissa_bit_width, dtype=torch.float32),
+        MAX_MANTISSA_DICT[minifloat.mantissa_bit_width],
         torch.tensor(minifloat.exponent_bias, dtype=torch.float32))
     max_available_float = get_max_available_float(
         minifloat.exponent_bit_width,
         minifloat.mantissa_bit_width,
@@ -84,7 +85,7 @@ def test_float_clamp(inp, fp8_clamp):
     max_val = max_float(
         torch.tensor(fp8_clamp.exponent_bit_width, dtype=torch.float32),
-        torch.tensor(fp8_clamp.mantissa_bit_width, dtype=torch.float32),
+        MAX_MANTISSA_DICT[fp8_clamp.mantissa_bit_width],
         torch.tensor(fp8_clamp.exponent_bias, dtype=torch.float32))
     max_available_float = get_max_available_float(
         fp8_clamp.exponent_bit_width,
         fp8_clamp.mantissa_bit_width,
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 16b8a4b5f..a471f7bbf 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -15,6 +15,7 @@
 from brevitas.core.scaling import FloatScaling
 from brevitas.function.ops import max_float
 from brevitas.utils.torch_utils import float_internal_scale
+from brevitas.utils.torch_utils import MAX_MANTISSA_DICT
 from tests.brevitas.hyp_helper import float_st
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 from tests.brevitas.hyp_helper import random_minifloat_format
@@ -98,8 +99,8 @@ def test_float_to_quant_float(inp, minifloat_format):
         signed=signed,
         float_clamp_impl=float_clamp)
     expected_out, *_ = float_quant(inp)
-
-    out_quant, scale = float_quant.quantize(inp)
+    scale = float_quant.scaling_impl(inp)
+    out_quant, scale = float_quant.quantize(inp, scale)
     exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float)
     out_quant, *_ = float_quant.float_clamp_impl(
         out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias)
@@ -142,7 +143,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
         scaling_impl=scaling_impl,
         float_scaling_impl=float_scaling_impl,
         float_clamp_impl=float_clamp)
-    _ = float_quant.quantize(inp)
+    scale = float_quant.scaling_impl(inp)
+    _ = float_quant.quantize(inp, scale)
     # scaling implementations should be called exactly once on the input
     float_scaling_impl.assert_called_once_with(
         torch.tensor(exponent_bit_width),
@@ -196,7 +198,7 @@ def test_inner_scale(inp, minifloat_format, scale):
     scaled_inp = inp / scale
     max_val = max_float(
         torch.tensor(exponent_bit_width),
-        torch.tensor(mantissa_bit_width),
+        MAX_MANTISSA_DICT[mantissa_bit_width],
         torch.tensor(exponent_bias))
     max_available_float = float_clamp.max_available_float
     max_value = max_val if max_available_float is None else torch.min(
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 86ef58b77..df1760eb5 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -4,6 +4,7 @@
 import math
 
 from hypothesis import given
+import pytest_cases
 from pytest_cases import fixture
 import torch
 import torch.nn as nn
@@ -13,6 +14,8 @@
 from brevitas.graph.calibrate import load_quant_model_mode
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
+from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
 from brevitas.utils.torch_utils import kthvalue
@@ -21,6 +24,10 @@
 IN_CH = 8
 OUT_CH = 16
 BATCH = 1
+REFERENCE_SCALES = {
+    'int_quant': (0.00935234408825635910, 0.00859776325523853302),
+    'fp_quant': (0.00249395845457911491, 0.00190271728206425905)}
+REFERNECE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
 
 
 def compute_quantile(x, q):
@@ -65,6 +72,42 @@ def forward(self, x):
     assert torch.allclose(expected_scale, scale)
 
 
+QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3OCPActPerTensorFloat}
+
+
+@pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
+def test_scale_factors_ptq_calibration_reference(act_quant):
+
+    reference, act_quant = act_quant
+
+    class TestModel(nn.Module):
+
+        def __init__(self):
+            super(TestModel, self).__init__()
+            self.act = qnn.QuantReLU(act_quant=act_quant)
+            self.linear = qnn.QuantLinear(3, 8)
+            self.act_1 = qnn.QuantIdentity(act_quant=act_quant)
+
+        def forward(self, x):
+            o = self.act(x)
+            o = self.linear(o)
+            return self.act_1(o)
+
+    # Reference input
+    inp = REFERNECE_INP
+    model = TestModel()
+    model.eval()
+    with torch.no_grad():
+        with calibration_mode(model):
+            model(inp)
+
+    computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale()
+    reference_values = REFERENCE_SCALES[reference]
+    assert all([
+        torch.allclose(comp, torch.tensor(ref)) for comp,
+        ref in zip(computed_scale, reference_values)])
+
+
 def test_calibration_training_state():
 
     class TestModel(nn.Module):
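End to end, the new reference test leans on `calibration_mode` flipping `observer_only` inside the activation proxies. A minimal usage sketch with a default-quantized layer (shapes and the single calibration batch are illustrative):

```python
import torch
import brevitas.nn as qnn
from brevitas.graph.calibrate import calibration_mode

model = torch.nn.Sequential(qnn.QuantIdentity())
model.eval()
with torch.no_grad(), calibration_mode(model):
    model(torch.randn(8, 16))    # observer_only=True: statistics only, output unquantized
out = model(torch.randn(8, 16))  # quantized with the calibrated scale
```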