diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
index 00f8e472e..7a5a283ea 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
@@ -32,7 +32,9 @@
 from brevitas import config
 from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
+from brevitas.graph.calibrate import disable_return_quant_tensor
 from brevitas.graph.calibrate import DisableEnableQuantization
+from brevitas.graph.calibrate import restore_return_quant_tensor
 from brevitas.inject.enum import FloatToIntImplType
 from brevitas.inject.enum import LearnedRoundImplType
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
@@ -185,6 +187,7 @@ def save_inp_out_data(
         disable_quant_class = DisableEnableQuantization()
         disable_quant_class.disable_act_quantization(model, False)
         disable_quant_class.disable_param_quantization(model, False)
+        return_quant_tensor_state = disable_return_quant_tensor(model)
     device = next(model.parameters()).device
     data_saver = DataSaverHook(store_output=store_out)
     handle = module.register_forward_hook(data_saver)
@@ -213,4 +216,5 @@ def save_inp_out_data(
     if disable_quant:
         disable_quant_class.enable_act_quantization(model, False)
         disable_quant_class.enable_param_quantization(model, False)
+        restore_return_quant_tensor(model, return_quant_tensor_state)
     return cached
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index fd5e5c386..8a70e29ba 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -47,6 +47,11 @@ def parse_type(v, default_type):
     return default_type(v)
 
 
+def validate_args(args):
+    if args.learned_round:
+        assert args.target_backend == "layerwise", "Currently, learned round is only supported with target-backend=layerwise"
+
+
 model_names = sorted(
     name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and
     callable(torchvision.models.__dict__[name]) and not name.startswith("get_"))
@@ -280,6 +285,7 @@ def generate_ref_input(args, device, dtype):
 
 
 def main():
     args = parser.parse_args()
+    validate_args(args)
     dtype = getattr(torch, args.dtype)
     random.seed(SEED)
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 10d8f7e7c..fbfc76842 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -2,8 +2,10 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import math
+from typing import Union
 
 from hypothesis import given
+import pytest
 import pytest_cases
 from pytest_cases import fixture
 import torch
@@ -11,12 +13,16 @@
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
+from brevitas.graph.calibrate import disable_return_quant_tensor
+from brevitas.graph.calibrate import DisableEnableQuantization
 from brevitas.graph.calibrate import load_quant_model_mode
+from brevitas.graph.calibrate import restore_return_quant_tensor
 from brevitas.inject.enum import RestrictValueType
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
+from brevitas.quant_tensor import QuantTensor
 # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
 from brevitas.utils.torch_utils import kthvalue
 from tests.brevitas.hyp_helper import float_tensor_random_size_st
@@ -307,3 +313,54 @@ def forward(self, inp):
     for m in model.modules():
         if isinstance(m, qnn.QuantLinear):
             assert m.bias is None
+
+
+class TestDisableEnableQuantization:
+
+    @fixture
+    def model(self):
+
+        class TestQuantModel(nn.Module):
+
+            def __init__(self) -> None:
+                super().__init__()
+                # Note that return_quant_tensor=True, so the module outputs a QuantTensor
+                self.act = qnn.QuantIdentity(return_quant_tensor=True)
+
+            def forward(self, x: Union[torch.Tensor,
+                                       QuantTensor]) -> Union[torch.Tensor, QuantTensor]:
+                return self.act(x)
+
+        model = TestQuantModel()
+        model.eval()
+        return model
+
+    def test_disable_enable_quantization(self, model):
+        disable_quant_class = DisableEnableQuantization()
+        # Sample input; its values are not relevant to the test
+        sample_input = torch.rand(size=(2, 3))
+
+        # (1) Verify that a valid QuantTensor is returned
+        quant_out = model(sample_input)
+        assert isinstance(quant_out, QuantTensor) and quant_out.is_valid
+
+        # (2) Disable activation quantization
+        disable_quant_class.disable_act_quantization(model, is_training=False)
+        # Verify that an error is raised when return_quant_tensor=True and
+        # disable_return_quant_tensor is not applied
+        with pytest.raises(
+                AssertionError,
+                match="QuantLayer is not correctly configured, check if warnings were raised"):
+            model(sample_input)
+
+        # (3) Disable return_quant_tensor and verify that no error is raised
+        return_quant_tensor_state = disable_return_quant_tensor(model)
+        fp_out = model(sample_input)
+        assert isinstance(fp_out, torch.Tensor)
+
+        # (4) Re-enable activation quantization and check that a QuantTensor
+        # is returned
+        restore_return_quant_tensor(model, return_quant_tensor_state)
+        disable_quant_class.enable_act_quantization(model, is_training=False)
+        quant_out = model(sample_input)
+        assert isinstance(quant_out, QuantTensor) and quant_out.is_valid
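Reviewer note: below is a minimal, self-contained sketch of the disable/restore pattern this diff introduces, for trying it outside the patched files. The toy QuantIdentity model and the random sample input are illustrative only; the brevitas calls mirror the ones used in learned_round_utils.py above.

import torch
import brevitas.nn as qnn
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor

# Toy stand-in for a quantized model (illustrative only)
model = qnn.QuantIdentity(return_quant_tensor=True)
model.eval()
sample_input = torch.rand(2, 3)

disable_quant_class = DisableEnableQuantization()
disable_quant_class.disable_act_quantization(model, False)
disable_quant_class.disable_param_quantization(model, False)
# Without this call, a forward pass through a layer configured with
# return_quant_tensor=True raises once its quantizers are disabled
# (this is exactly what the new test asserts).
return_quant_tensor_state = disable_return_quant_tensor(model)

fp_out = model(sample_input)  # plain torch.Tensor while quantization is off

# Put everything back before the next quantized forward pass
restore_return_quant_tensor(model, return_quant_tensor_state)
disable_quant_class.enable_act_quantization(model, False)
disable_quant_class.enable_param_quantization(model, False)

The state object returned by disable_return_quant_tensor is treated as opaque here; the diff only shows that it must be handed back to restore_return_quant_tensor, which is why save_inp_out_data keeps it alive across the cached forward passes.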