
Fix (learned_round): disable return QuantTensor during float inference (
pablomlago authored Oct 21, 2024
1 parent 59f8df7 commit 7307942
Showing 3 changed files with 67 additions and 0 deletions.
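
For context, a minimal sketch of the pattern this commit introduces (adapted from the new test added below, not code taken verbatim from the changed files): when activation quantization is disabled for float inference, a layer configured with return_quant_tensor=True can no longer build a valid QuantTensor, so the return-quant-tensor flags must be disabled as well and restored afterwards. The helper names come from brevitas.graph.calibrate as imported in the diff.

import torch

import brevitas.nn as qnn
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor

model = qnn.QuantIdentity(return_quant_tensor=True)
model.eval()
sample_input = torch.rand(2, 3)

disable_quant = DisableEnableQuantization()
disable_quant.disable_act_quantization(model, is_training=False)

# Capture and disable the per-module return_quant_tensor flags before float inference.
return_quant_tensor_state = disable_return_quant_tensor(model)
fp_out = model(sample_input)  # plain torch.Tensor during float inference

# Restore the original configuration once float inference is done.
restore_return_quant_tensor(model, return_quant_tensor_state)
disable_quant.enable_act_quantization(model, is_training=False)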
@@ -32,7 +32,9 @@

from brevitas import config
from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.inject.enum import FloatToIntImplType
from brevitas.inject.enum import LearnedRoundImplType
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
@@ -185,6 +187,7 @@ def save_inp_out_data(
    disable_quant_class = DisableEnableQuantization()
    disable_quant_class.disable_act_quantization(model, False)
    disable_quant_class.disable_param_quantization(model, False)
    return_quant_tensor_state = disable_return_quant_tensor(model)
    device = next(model.parameters()).device
    data_saver = DataSaverHook(store_output=store_out)
    handle = module.register_forward_hook(data_saver)
@@ -213,4 +216,5 @@ def save_inp_out_data(
    if disable_quant:
        disable_quant_class.enable_act_quantization(model, False)
        disable_quant_class.enable_param_quantization(model, False)
        restore_return_quant_tensor(model, return_quant_tensor_state)
    return cached
@@ -47,6 +47,11 @@ def parse_type(v, default_type):
    return default_type(v)


def validate_args(args):
    if args.learned_round:
        assert args.target_backend == "layerwise", "Currently, learned round is only supported with target-backend=layerwise"


model_names = sorted(
    name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and
    callable(torchvision.models.__dict__[name]) and not name.startswith("get_"))
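
For illustration only (hypothetical argument values, not from the commit), the validate_args guard added above accepts learned-round runs only when the layerwise backend is targeted:

from argparse import Namespace

validate_args(Namespace(learned_round=True, target_backend="layerwise"))  # passes

try:
    # Any non-layerwise backend ("fx" here is just a stand-in value) trips the assertion.
    validate_args(Namespace(learned_round=True, target_backend="fx"))
except AssertionError as exc:
    print(exc)  # Currently, learned round is only supported with target-backend=layerwise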
@@ -280,6 +285,7 @@ def generate_ref_input(args, device, dtype):

def main():
    args = parser.parse_args()
    validate_args(args)
    dtype = getattr(torch, args.dtype)

    random.seed(SEED)
57 changes: 57 additions & 0 deletions tests/brevitas/graph/test_calibration.py
@@ -2,21 +2,27 @@
# SPDX-License-Identifier: BSD-3-Clause

import math
from typing import Union

from hypothesis import given
import pytest
import pytest_cases
from pytest_cases import fixture
import torch
import torch.nn as nn

from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import load_quant_model_mode
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.inject.enum import RestrictValueType
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant_tensor import QuantTensor
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue
from tests.brevitas.hyp_helper import float_tensor_random_size_st
@@ -307,3 +313,54 @@ def forward(self, inp):
    for m in model.modules():
        if isinstance(m, qnn.QuantLinear):
            assert m.bias is None


class TestDisableEnableQuantization():

    @fixture
    def model(self):

        class TestQuantModel(nn.Module):

            def __init__(self) -> None:
                super().__init__()
                # Note that the activation is configured to return a QuantTensor
                self.act = qnn.QuantIdentity(return_quant_tensor=True)

            def forward(self, x: Union[torch.Tensor, QuantTensor]) -> Union[torch.Tensor, QuantTensor]:
                return self.act(x)

        model = TestQuantModel()
        model.eval()
        return model

    def test_disable_enable_quantization(self, model):
        disable_quant_class = DisableEnableQuantization()
        # Sample input, not relevant to the task
        sample_input = torch.rand(size=(2, 3))

        # (1) Verify that an appropriate tensor is returned
        quant_out = model(sample_input)
        assert isinstance(quant_out, QuantTensor) and quant_out.is_valid

        # (2) Disable activation quantization
        disable_quant_class.disable_act_quantization(model, is_training=False)
        # Verify that an error is raised when return_quant_tensor=True and
        # disable_return_quant_tensor is not applied
        with pytest.raises(
                AssertionError,
                match="QuantLayer is not correctly configured, check if warnings were raised"):
            model(sample_input)

        # (3) Disable return quant tensor and verify no error is raised
        return_quant_tensor_state = disable_return_quant_tensor(model)
        fp_out = model(sample_input)
        assert isinstance(fp_out, torch.Tensor)

        # (4) Re-enable activation quantization and check that a QuantTensor
        # is returned
        restore_return_quant_tensor(model, return_quant_tensor_state)
        disable_quant_class.enable_act_quantization(model, is_training=False)
        quant_out = model(sample_input)
        assert isinstance(quant_out, QuantTensor) and quant_out.is_valid
