
Commit
convert to fp32 before rounding scale down to power of 2; update unit tests
danielvegamyhre committed Feb 6, 2025
1 parent c434498 commit 40166e1
Showing 2 changed files with 21 additions and 55 deletions.
53 changes: 18 additions & 35 deletions test/float8/test_float8_utils.py
@@ -6,47 +6,30 @@
 
 # source for notable single-precision cases:
 # https://en.wikipedia.org/wiki/Single-precision_floating-point_format
-#
-# TODO(danielvegamyhre):
-# 1. add case for largest normal fp32 value: 2**127 * (2 - 2**-23).
-#    need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value.
-# 2. add case for "nan"
-#    need to investigate why exp2(floor(log2(nan)))=nan, but bitshift returns inf.
-# 3. adjust cases for subnormal values so we aren't clamping the expected results
-#    into the normal range.
-#    preliminary investigation shows it may not be possible to support all subnormals
-#    with bitshifting, so we will need to debug/improve performance of exp2(floor(log2(x)))
-#    approach.
 @pytest.mark.parametrize(
-    "input",
+    "test_case",
     [
-        1.0,
-        float("inf"),
-        # smallest positive subnormal number
-        2**-126 * 2**-23,
-        # largest subnormal number
-        2**-126 * (1 - 2**-23),
-        # smallest positive normal number
-        2**-126,
-        # largest number less than one
-        1.0 - 2**-24,
-        # smallest number larger than one
-        1.0 + 2**-23,
+        # "test_case_name": [input, expected result]
+        ("one", [1.0, 1.0]),
+        ("inf", [float("inf"), float("inf")]),
+        ("smallest positive subnormal number", [2**-126 * 2**-23, 2**-126 * 2**-23]),
+        ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
+        ("largest normal number", [2**127 * (2 - 2**-23), float("inf")]),
+        ("smallest positive normal number", [2**-126, 2**-126]),
+        ("largest number less than one", [1.0 - 2**-24, 0.5]),
+        ("smallest number larger than one", [1.0 + 2**-23, 1.0]),
     ],
 )
-def test_round_scale_down_to_power_of_2_valid_inputs(input: float):
-    input_tensor = torch.tensor(input, dtype=torch.float32)
-    result = _round_scale_down_to_power_of_2(input_tensor)
-
-    # get expected value for comparison
-    # TODO(danielvegamyhre): support subnormal values
-    expected_result = torch.exp2(torch.floor(torch.log2(input_tensor)))
-    smallest_normal_fp32_value = torch.tensor(2**-126, dtype=torch.float32)
-    expected_result = torch.max(expected_result, smallest_normal_fp32_value)
+def test_round_scale_down_to_power_of_2_valid_inputs(
+    test_case: dict,
+):
+    test_case_name, (input, expected_result) = test_case
+    input_tensor, expected_tensor = torch.tensor(input), torch.tensor(expected_result)
 
+    result = _round_scale_down_to_power_of_2(input_tensor)
     assert torch.equal(
-        result, expected_result
-    ), f"input: {input_tensor}, expected {expected_result}, but got {result}"
+        result, expected_tensor
+    ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"
 
 
 @pytest.mark.parametrize(
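As a quick sanity check, the hardcoded expectations above can be reproduced with the exp2(floor(log2(x))) formula this commit adopts. A minimal standalone sketch, assuming only that torch is installed; the helper name is illustrative and simply mirrors the new _round_scale_down_to_power_of_2:

import torch

def round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor:
    # same formula the new implementation uses
    return torch.exp2(torch.floor(torch.log2(x)))

# the notable single-precision cases from the parametrized test above
cases = [
    ("one", 1.0),
    ("inf", float("inf")),
    ("smallest positive subnormal number", 2**-126 * 2**-23),
    ("largest subnormal number", 2**-126 * (1 - 2**-23)),
    ("largest normal number", 2**127 * (2 - 2**-23)),
    ("smallest positive normal number", 2**-126),
    ("largest number less than one", 1.0 - 2**-24),
    ("smallest number larger than one", 1.0 + 2**-23),
]
for name, value in cases:
    x = torch.tensor(value, dtype=torch.float32)
    print(f"{name}: {x.item():.9g} -> {round_down_to_power_of_2(x).item():.9g}")

Note that the largest normal number rounds to inf under this formula: its exact log2 is about 127.99999991, which rounds up to 128.0 in fp32, and exp2(128) overflows. The updated test encodes that behavior as the expected result rather than working around it.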
23 changes: 3 additions & 20 deletions torchao/float8/float8_utils.py
@@ -285,23 +285,6 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
     )
 
 
-def _round_scale_down_to_power_of_2(x: torch.Tensor):
-    assert x.dtype == torch.float32, "scale must be float32 tensor"
-
-    # eps = smallest normal fp32 value
-    # TODO(danielvegamyhre): support subnormal values
-    eps = 2**-126
-    x = torch.clamp(
-        x,
-        min=eps,
-    )
-
-    # view as int32 to allow bitshifting
-    x_int = x.view(torch.int32)
-
-    # clear mantissa bits (rightmost 23 bits)
-    x_int = (x_int >> 23) << 23
-
-    # return result as fp32
-    result = x_int.view(torch.float32)
-    return result
+def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+    assert scale.dtype == torch.float32, "scale must be float32 tensor"
+    return torch.exp2(torch.floor(torch.log2(scale)))
