
Commit

fix linter issues
danielvegamyhre committed Feb 5, 2025
1 parent f2433b1 commit ecc23ae
Showing 2 changed files with 11 additions and 10 deletions.
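
Note: the diff below only reorders and rewrites imports; there are no functional changes. The new ordering is consistent with ruff's isort rules under the default order-by-type setting (an assumption — the linter configuration is not part of this commit): within a from-import, ALL_CAPS constants sort first, then CamelCase classes, then lowercase names, alphabetically within each group. A minimal sketch of that grouping rule, using an illustrative key function that is not part of the commit:

    # Illustrative sketch (not from the commit): approximate the
    # "constants, then classes, then lowercase" member ordering
    # that the linter enforces for from-imports.
    names = [
        "CastConfig",
        "e4m3_dtype",
        "e5m2_dtype",
        "Float8LinearConfig",
        "Float8LinearRecipeName",
        "recipe_name_to_linear_config",
        "ScalingGranularity",
        "ScalingType",
    ]

    def member_sort_key(name: str) -> tuple[int, str]:
        if name.isupper():        # ALL_CAPS constants, e.g. FP8_TYPES
            return (0, name)
        if name[:1].isupper():    # CamelCase classes, e.g. CastConfig
            return (1, name)
        return (2, name)          # lowercase names, e.g. e4m3_dtype

    print("\n".join(sorted(names, key=member_sort_key)))
    # CastConfig, Float8LinearConfig, Float8LinearRecipeName,
    # ScalingGranularity, ScalingType, e4m3_dtype, e5m2_dtype,
    # recipe_name_to_linear_config
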
12 changes: 6 additions & 6 deletions test/float8/test_base.py
@@ -15,9 +15,9 @@
 import torch.nn as nn

 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
-    TORCH_VERSION_AT_LEAST_2_5,
 )

 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -26,13 +26,13 @@

 from torchao.float8.config import (
     CastConfig,
-    e4m3_dtype,
-    e5m2_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    recipe_name_to_linear_config,
     ScalingGranularity,
     ScalingType,
+    e4m3_dtype,
+    e5m2_dtype,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -48,15 +48,15 @@
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
-    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
+    hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
+    FP8_TYPES,
     compute_error,
     config_has_stateful_scaling,
     fp8_tensor_statistics,
-    FP8_TYPES,
     tensor_to_scale,
 )
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
9 changes: 5 additions & 4 deletions test/float8/test_compile.py
@@ -13,9 +13,9 @@
 import pytest

 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
-    TORCH_VERSION_AT_LEAST_2_5,
 )

 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -29,11 +29,11 @@
 from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
     CastConfig,
-    e4m3_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    recipe_name_to_linear_config,
     ScalingType,
+    e4m3_dtype,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -479,7 +479,8 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
-    from torch._inductor import config as inductor_config, metrics
+    from torch._inductor import config as inductor_config
+    from torch._inductor import metrics

     inductor_config.loop_ordering_after_fusion = True

