diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py
index f907a3112a..184fa21343 100644
--- a/test/float8/test_float8_utils.py
+++ b/test/float8/test_float8_utils.py
@@ -6,47 +6,30 @@
 
 # source for notable single-precision cases:
 # https://en.wikipedia.org/wiki/Single-precision_floating-point_format
-#
-# TODO(danielvegamyhre):
-# 1. add case for largest normal fp32 value: 2**127 * (2 - 2**-23).
-#    need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value.
-# 2. add case for "nan"
-#    need to investigate why exp2(floor(log2(nan)))=nan, but bitshift returns inf.
-# 3. adjust cases for subnormal values so we aren't clamping the expected results
-#    into the normal range.
-#    preliminary investigation shows it may not be possible to support all subnormals
-#    with bitshifting, so we will need to debug/improve performance of exp2(floor(log2(x)))
-#    approach.
 @pytest.mark.parametrize(
-    "input",
+    "test_case",
     [
-        1.0,
-        float("inf"),
-        # smallest positive subnormal number
-        2**-126 * 2**-23,
-        # largest subnormal number
-        2**-126 * (1 - 2**-23),
-        # smallest positive normal number
-        2**-126,
-        # largest number less than one
-        1.0 - 2**-24,
-        # smallest number larger than one
-        1.0 + 2**-23,
+        # "test_case_name": [input, expected result]
+        ("one", [1.0, 1.0]),
+        ("inf", [float("inf"), float("inf")]),
+        ("smallest positive subnormal number", [2**-126 * 2**-23, 2**-126 * 2**-23]),
+        ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
+        ("largest normal number", [2**127 * (2 - 2**-23), float("inf")]),
+        ("smallest positive normal number", [2**-126, 2**-126]),
+        ("largest number less than one", [1.0 - 2**-24, 0.5]),
+        ("smallest number larger than one", [1.0 + 2**-23, 1.0]),
     ],
 )
-def test_round_scale_down_to_power_of_2_valid_inputs(input: float):
-    input_tensor = torch.tensor(input, dtype=torch.float32)
-    result = _round_scale_down_to_power_of_2(input_tensor)
-
-    # get expected value for comparison
-    # TODO(danielvegamyhre): support subnormal values
-    expected_result = torch.exp2(torch.floor(torch.log2(input_tensor)))
-    smallest_normal_fp32_value = torch.tensor(2**-126, dtype=torch.float32)
-    expected_result = torch.max(expected_result, smallest_normal_fp32_value)
+def test_round_scale_down_to_power_of_2_valid_inputs(
+    test_case: dict,
+):
+    test_case_name, (input, expected_result) = test_case
+    input_tensor, expected_tensor = torch.tensor(input), torch.tensor(expected_result)
+    result = _round_scale_down_to_power_of_2(input_tensor)
 
     assert torch.equal(
-        result, expected_result
-    ), f"input: {input_tensor}, expected {expected_result}, but got {result}"
+        result, expected_tensor
+    ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"
 
 
 @pytest.mark.parametrize(
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index ea669c08b4..926b97edb8 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -285,23 +285,6 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
     )
 
 
-def _round_scale_down_to_power_of_2(x: torch.Tensor):
-    assert x.dtype == torch.float32, "scale must be float32 tensor"
-
-    # eps = smallest normal fp32 value
-    # TODO(danielvegamyhre): support subnormal values
-    eps = 2**-126
-    x = torch.clamp(
-        x,
-        min=eps,
-    )
-
-    # view as int32 to allow bitshifting
-    x_int = x.view(torch.int32)
-
-    # clear mantissa bits (rightmost 23 bits)
-    x_int = (x_int >> 23) << 23
-
-    # return result as fp32
-    result = x_int.view(torch.float32)
-    return result
+def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+    assert scale.dtype == torch.float32, "scale must be float32 tensor"
+    return torch.exp2(torch.floor(torch.log2(scale)))
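
Reviewer note (not part of the patch): the behavioral difference between the removed bitshift implementation and the new exp2(floor(log2(x))) implementation shows up at the float32 extremes, and explains the new test expectations. For the largest normal value, log2(2**127 * (2 - 2**-23)) is approximately 128 - 8.6e-8, which rounds to exactly 128.0 in float32, so exp2(floor(...)) overflows to inf, whereas the bitshift kept the finite value 2**127. For the smallest subnormals, the bitshift path clamped to 2**-126, whereas the exp2 path preserves them. A minimal standalone sketch of this, assuming only PyTorch (the helper names below are illustrative, not from the codebase):

import torch

def _round_down_bitshift(x: torch.Tensor) -> torch.Tensor:
    # removed approach: clamp into the normal range, then clear the 23 mantissa bits
    x = torch.clamp(x, min=2**-126)
    bits = x.view(torch.int32)
    return ((bits >> 23) << 23).view(torch.float32)

def _round_down_exp2(x: torch.Tensor) -> torch.Tensor:
    # new approach: round down to the nearest power of two via log2/floor/exp2
    return torch.exp2(torch.floor(torch.log2(x)))

cases = {
    "largest normal": 2**127 * (2 - 2**-23),  # bitshift -> 2**127, exp2 -> inf
    "smallest subnormal": 2**-126 * 2**-23,   # bitshift -> 2**-126, exp2 -> 2**-149
    "inf": float("inf"),                      # both -> inf
}
for name, val in cases.items():
    t = torch.tensor(val, dtype=torch.float32)
    print(f"{name}: bitshift={_round_down_bitshift(t).item()}, exp2={_round_down_exp2(t).item()}")

The "largest normal" and "smallest subnormal" rows match the new "largest normal number" and "smallest positive subnormal number" expected values in the updated test above.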