Fix (learned_round): disable return QuantTensor during float inference #1059

Merged 3 commits on Oct 21, 2024
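The fix pairs the existing quantization toggles with disable_return_quant_tensor / restore_return_quant_tensor so that layers built with return_quant_tensor=True emit plain tensors while quantization is disabled. A minimal sketch of that pattern, mirroring the diff below; the helper names come from the imports in this PR, while the QuantIdentity model and random input here are illustrative only:

```python
import torch

import brevitas.nn as qnn
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor

# Illustrative model: any module containing return_quant_tensor=True layers would do
model = qnn.QuantIdentity(return_quant_tensor=True)
model.eval()

disable_quant = DisableEnableQuantization()
disable_quant.disable_act_quantization(model, False)
disable_quant.disable_param_quantization(model, False)
# Remember the per-layer return_quant_tensor settings before disabling them
return_quant_tensor_state = disable_return_quant_tensor(model)

# Float inference: outputs are plain torch.Tensor instead of QuantTensor
with torch.no_grad():
    fp_out = model(torch.rand(size=(2, 3)))

# Put everything back once float inference is done
restore_return_quant_tensor(model, return_quant_tensor_state)
disable_quant.enable_act_quantization(model, False)
disable_quant.enable_param_quantization(model, False)
```

save_inp_out_data in the first file below applies exactly this pattern around its cached forward passes.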
@@ -32,7 +32,9 @@

from brevitas import config
from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.inject.enum import FloatToIntImplType
from brevitas.inject.enum import LearnedRoundImplType
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
@@ -185,6 +187,7 @@ def save_inp_out_data(
        disable_quant_class = DisableEnableQuantization()
        disable_quant_class.disable_act_quantization(model, False)
        disable_quant_class.disable_param_quantization(model, False)
        return_quant_tensor_state = disable_return_quant_tensor(model)
    device = next(model.parameters()).device
    data_saver = DataSaverHook(store_output=store_out)
    handle = module.register_forward_hook(data_saver)
@@ -213,4 +216,5 @@ def save_inp_out_data(
    if disable_quant:
        disable_quant_class.enable_act_quantization(model, False)
        disable_quant_class.enable_param_quantization(model, False)
        restore_return_quant_tensor(model, return_quant_tensor_state)
    return cached
@@ -47,6 +47,11 @@ def parse_type(v, default_type):
    return default_type(v)


def validate_args(args):
    if args.learned_round:
        assert args.target_backend == "layerwise", "Currently, learned round is only supported with target-backend=layerwise"


model_names = sorted(
    name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and
    callable(torchvision.models.__dict__[name]) and not name.startswith("get_"))
@@ -280,6 +285,7 @@ def generate_ref_input(args, device, dtype):

def main():
    args = parser.parse_args()
    validate_args(args)
    dtype = getattr(torch, args.dtype)

    random.seed(SEED)
tests/brevitas/graph/test_calibration.py: 57 additions, 0 deletions
@@ -2,21 +2,27 @@
# SPDX-License-Identifier: BSD-3-Clause

import math
from typing import Union

from hypothesis import given
import pytest
import pytest_cases
from pytest_cases import fixture
import torch
import torch.nn as nn

from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import load_quant_model_mode
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.inject.enum import RestrictValueType
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant_tensor import QuantTensor
# Use custom implementation of kthvalue as a workaround for (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue
from tests.brevitas.hyp_helper import float_tensor_random_size_st
@@ -307,3 +313,54 @@ def forward(self, inp):
    for m in model.modules():
        if isinstance(m, qnn.QuantLinear):
            assert m.bias is None


class TestDisableEnableQuantization():

    @fixture
    def model(self):

        class TestQuantModel(nn.Module):

            def __init__(self) -> None:
                super().__init__()
                # Note that the activation is configured to return a QuantTensor
                self.act = qnn.QuantIdentity(return_quant_tensor=True,)

            def forward(self, x: Union[torch.Tensor,
                                       QuantTensor]) -> Union[torch.Tensor, QuantTensor]:
                return self.act(x)

        model = TestQuantModel()
        model.eval()
        return model

    def test_disable_enable_quantization(self, model):
        disable_quant_class = DisableEnableQuantization()
        # Sample input, not relevant to the task
        sample_input = torch.rand(size=(2, 3))

        # (1) Verify that an appropriate tensor is returned
        quant_out = model(sample_input)
        assert isinstance(quant_out, QuantTensor) and quant_out.is_valid

        # (2) Disable activation quantization
        disable_quant_class.disable_act_quantization(model, is_training=False)
        # Verify that an error is raised when return_quant_tensor=True and
        # disable_return_quant_tensor is not applied
        with pytest.raises(
                AssertionError,
                match="QuantLayer is not correctly configured, check if warnings were raised"):
            model(sample_input)

        # (3) Disable return quant tensor and verify that no error is raised
        return_quant_tensor_state = disable_return_quant_tensor(model)
        fp_out = model(sample_input)
        assert isinstance(fp_out, torch.Tensor)

        # (4) Enable activation quantization again and check that a QuantTensor
        # is returned
        restore_return_quant_tensor(model, return_quant_tensor_state)
        disable_quant_class.enable_act_quantization(model, is_training=False)
        quant_out = model(sample_input)
        assert isinstance(quant_out, QuantTensor) and quant_out.is_valid