Test (graph/calibrate): add calibration reference (#1031)
Giuseppe5 authored Oct 1, 2024
1 parent b28ac0f commit 1e71a04
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tests/brevitas/graph/test_calibration.py
@@ -4,6 +4,7 @@
import math

from hypothesis import given
import pytest_cases
from pytest_cases import fixture
import torch
import torch.nn as nn
@@ -13,14 +14,23 @@
from brevitas.graph.calibrate import load_quant_model_mode
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue
from tests.brevitas.hyp_helper import float_tensor_random_size_st
from tests.conftest import SEED

torch.manual_seed(SEED)
IN_CH = 8
OUT_CH = 16
BATCH = 1
REFERENCE_SCALES = {
    'int_quant': (0.00935234408825635910, 0.01362917013466358185),
    'fp_quant': (0.00249395845457911491, 0.00363444536924362183)}
REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010],
                                  [1.4573, -0.9074, -0.2708]])


def compute_quantile(x, q):
@@ -65,6 +75,41 @@ def forward(self, x):
assert torch.allclose(expected_scale, scale)


QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}


@pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
def test_scale_factors_ptq_calibration_reference(act_quant):

reference, act_quant = act_quant

class TestModel(nn.Module):

def __init__(self):
super(TestModel, self).__init__()
self.act = qnn.QuantReLU(act_quant=act_quant)
self.linear_weights = REFERENCE_WEIGHTS
self.act_1 = qnn.QuantIdentity(act_quant=act_quant)

def forward(self, x):
o = self.act(x)
o = torch.matmul(o, self.linear_weights)
return self.act_1(o)

# Reference input
inp = REFERENCE_INP
model = TestModel()
model.eval()
with torch.no_grad():
with calibration_mode(model):
model(inp)

computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale()
reference_values = REFERENCE_SCALES[reference]
assert torch.allclose(computed_scale[0], torch.tensor(reference_values[0]))
assert torch.allclose(computed_scale[1], torch.tensor(reference_values[1]))


def test_calibration_training_state():

    class TestModel(nn.Module):

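Note (for context only, not part of the commit): the new test checks the scale factors that calibration_mode estimates against hard-coded reference values. As a rough illustration of what such a per-tensor scale looks like, the sketch below uses a simple max-abs statistic and an int8 divisor; these are illustrative assumptions and do not reproduce the exact statistics implemented by Int8ActPerTensorFloat or Fp8e4m3ActPerTensorFloat.

import torch

def per_tensor_scale_int8(x: torch.Tensor) -> torch.Tensor:
    # Illustrative only: map the observed absolute range of the tensor onto
    # the signed 8-bit grid with a single scale factor.
    return x.abs().max() / 127.0

x = torch.randn(1, 8)
print(per_tensor_scale_int8(x))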