Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix (minifloat): correct minifloat computation and tests #1067

Merged
merged 3 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/brevitas/quant_tensor/float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ def _pre_round_float_value(self):
scale = self.scale.type(torch.float32)
minifloat_value = value / scale
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
eps = torch.finfo(scale.dtype).tiny
int_scale = float_internal_scale(
self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
minifloat_value = minifloat_value / int_scale
return minifloat_value

Expand Down Expand Up @@ -137,11 +138,17 @@ def device(self):
def minifloat(self, float_datatype=True):
# TODO: Check if OCP and cast to proper data-type if matching
assert float_datatype, "Minifloat quant returns only higher precision dtype"

if self.is_valid:
value = self.value
scale = self.scale
if self.scale.dtype == torch.bfloat16:
value = self.value.type(torch.float32)
scale = self.scale.type(torch.float32)
minifloat_value = value / scale
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
eps = torch.finfo(scale.dtype).tiny
int_scale = float_internal_scale(
self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
float_value = torch.round(self._pre_round_float_value) * int_scale
return float_value.type(self.scale.dtype)
else:
Expand Down
13 changes: 11 additions & 2 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def _pre_round_float_value(self):
scale = scale.type(torch.float32)
minifloat_value = value / scale
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
eps = torch.finfo(scale.dtype).tiny
int_scale = float_internal_scale(
minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
minifloat_value = minifloat_value / int_scale
return minifloat_value

Expand Down Expand Up @@ -179,8 +181,15 @@ def minifloat(self, float_datatype=True):
assert float_datatype, "Minifloat quant returns only higher precision dtype"

if self.is_valid:
value, scale, zp = self.expand()
if self.scale.dtype == torch.bfloat16:
value = value.type(torch.float32)
scale = scale.type(torch.float32)
minifloat_value = value / scale
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
eps = torch.finfo(scale.dtype).tiny
int_scale = float_internal_scale(
minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
float_value = torch.round(self._pre_round_float_value) * int_scale
return float_value.type(self.scale.dtype)
else:
Expand Down
19 changes: 19 additions & 0 deletions tests/brevitas/quant_tensor/test_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

from packaging import version
import pytest
import pytest_cases
import torch

from brevitas import torch_version
from brevitas.nn import QuantIdentity
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor import IntQuantTensor

Expand Down Expand Up @@ -119,3 +122,19 @@ def test_quant_tensor_view():
assert torch.allclose(a.view(2, -1), b.view(2, -1), atol=0.01)
assert torch.allclose(a.view(16, -1), b.view(16, -1), atol=0.01)
assert torch.allclose(a.view(8, 2), b.view(8, 2), atol=0.01)


# Quantizer classes under test, keyed by a short human-readable id:
# a per-tensor FP8 quantizer and a groupwise (MX) FP8 quantizer.
QUANT_CLASS = {'fp8': Fp8e4m3ActPerTensorFloat, 'mxfp8': MXFloat8e4m3Act}


# Fix typo in the parametrize argname: 'vale' -> 'value' (must match the
# function parameter name for pytest-cases to bind it).
@pytest_cases.parametrize('quant_class_key_value', QUANT_CLASS.items())
def test_minifloat(quant_class_key_value):
    """Smoke test: `minifloat()` on a float QuantTensor must not raise.

    Covers both the per-tensor and the groupwise (MX) float quantizers,
    exercising the `float_internal_scale` call paths fixed in this PR.
    """
    # The dict key is only a test id; the class is what we instantiate.
    _key, quant_class = quant_class_key_value

    x = torch.randn((1, 32))
    # group_dim is only meaningful for the groupwise (MX) quantizer;
    # NOTE(review): assumes the per-tensor quantizer ignores it — confirm.
    q = QuantIdentity(quant_class, group_dim=-1, return_quant_tensor=True)
    q.eval()

    qx = q(x)
    # Check that minifloat doesn't raise an error.
    qx.minifloat()
Loading