Skip to content

Commit

Permalink
Fix (minifloat): correct minifloat computation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 23, 2024
1 parent 59f8df7 commit b4c9d34
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/brevitas/quant_tensor/float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ def _pre_round_float_value(self):
scale = self.scale.type(torch.float32)
minifloat_value = value / scale
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
eps = torch.finfo(self.scale.dtype).tiny
int_scale = float_internal_scale(
self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
minifloat_value = minifloat_value / int_scale
return minifloat_value

Expand Down Expand Up @@ -140,8 +141,9 @@ def minifloat(self, float_datatype=True):

if self.is_valid:
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
eps = torch.finfo(self.scale.dtype).tiny
int_scale = float_internal_scale(
self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
float_value = torch.round(self._pre_round_float_value) * int_scale
return float_value.type(self.scale.dtype)
else:
Expand Down
8 changes: 6 additions & 2 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def _pre_round_float_value(self):
scale = scale.type(torch.float32)
minifloat_value = value / scale
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
eps = torch.finfo(self.scale_.dtype).tiny
int_scale = float_internal_scale(
self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
minifloat_value = minifloat_value / int_scale
return minifloat_value

Expand Down Expand Up @@ -180,7 +182,9 @@ def minifloat(self, float_datatype=True):

if self.is_valid:
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
eps = torch.finfo(self.scale_.dtype).tiny
int_scale = float_internal_scale(
self.value / self.scale, self.mantissa_bit_width, fp_internal_scale, eps)
float_value = torch.round(self._pre_round_float_value) * int_scale
return float_value.type(self.scale.dtype)
else:
Expand Down
19 changes: 19 additions & 0 deletions tests/brevitas/quant_tensor/test_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

from packaging import version
import pytest
import pytest_cases
import torch

from brevitas import torch_version
from brevitas.nn import QuantIdentity
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor import IntQuantTensor

Expand Down Expand Up @@ -119,3 +122,19 @@ def test_quant_tensor_view():
assert torch.allclose(a.view(2, -1), b.view(2, -1), atol=0.01)
assert torch.allclose(a.view(16, -1), b.view(16, -1), atol=0.01)
assert torch.allclose(a.view(8, 2), b.view(8, 2), atol=0.01)


# Quantizer classes to smoke-test: per-tensor fp8 and groupwise (MX) fp8.
QUANT_CLASS = {'fp8': Fp8e4m3ActPerTensorFloat, 'mxfp8': MXFloat8e4m3Act}


@pytest_cases.parametrize('quant_class_key_value', QUANT_CLASS.items())
def test_minifloat(quant_class_key_value):
    """Smoke test that FloatQuantTensor.minifloat() runs without raising.

    Covers both the per-tensor float quantizer ('fp8') and the groupwise
    MX float quantizer ('mxfp8'); the key is only used as the test id.
    """
    _, quant_class = quant_class_key_value

    x = torch.randn((1, 32))
    # group_dim is ignored by the non-groupwise quantizer, so the same
    # construction works for both parametrized cases.
    q = QuantIdentity(quant_class, group_dim=-1, return_quant_tensor=True)
    q.eval()

    qx = q(x)
    # Regression check: minifloat() must not raise (it previously passed the
    # wrong arguments to float_internal_scale).
    qx.minifloat()

0 comments on commit b4c9d34

Please sign in to comment.