diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml
index 14c31014c3..0980975403 100644
--- a/.github/workflows/regression_test.yml
+++ b/.github/workflows/regression_test.yml
@@ -33,13 +33,19 @@ jobs:
             torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cpu'
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
-
+          - name: ROCM Nightly
+            runs-on: linux.rocm.gpu.torchao
+            if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
+            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3'
+            gpu-arch-type: "rocm"
+            gpu-arch-version: "6.3"
     permissions:
       id-token: write
       contents: read
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       timeout: 120
+      no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }}
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
@@ -74,7 +80,6 @@
             torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.1"
-
           - name: CPU 2.3
             runs-on: linux.4xlarge
             torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
@@ -102,8 +107,6 @@
         conda create -n venv python=3.9 -y
         conda activate venv
         echo "::group::Install newer objcopy that supports --set-section-alignment"
-        yum install -y devtoolset-10-binutils
-        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
         python -m pip install --upgrade pip
         pip install ${{ matrix.torch-spec }}
         pip install -r dev-requirements.txt
diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 52b25dab82..a097972515 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -22,6 +22,7 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     is_sm_at_least_89,
+    skip_if_rocm,
 )

 is_cusparselt_available = (
@@ -100,6 +101,7 @@ def test_tensor_core_layout_transpose(self):
         "apply_quant",
         get_quantization_functions(is_cusparselt_available, True, "cuda", True),
     )
+    @skip_if_rocm("ROCm enablement in progress")
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         ql = apply_quant(linear)
@@ -179,6 +181,7 @@ def apply_uint6_weight_only_quant(linear):
         "apply_quant", get_quantization_functions(is_cusparselt_available, True)
     )
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_print_quantized_module(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         ql = apply_quant(linear)
@@ -191,6 +194,7 @@ class TestAffineQuantizedBasic(TestCase):

     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_flatten_unflatten(self, device, dtype):
         apply_quant_list = get_quantization_functions(False, True, device)
         for apply_quant in apply_quant_list:
diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py
index 76b6b74a3d..b60f3251dc 100644
--- a/test/dtypes/test_affine_quantized_tensor_parallel.py
+++ b/test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -1,5 +1,6 @@
 import unittest

+import pytest
 import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.testing._internal import common_utils
@@ -27,6 +28,9 @@
 except ModuleNotFoundError:
     has_gemlite = False

+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+

 class TestAffineQuantizedTensorParallel(DTensorTestBase):
     """Basic test case for tensor subclasses"""
diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py
index 8bb39b2cc8..f321d81b9e 100644
--- a/test/dtypes/test_floatx.py
+++ b/test/dtypes/test_floatx.py
@@ -27,7 +27,7 @@
     fpx_weight_only,
     quantize_,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm

 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 _Floatx_DTYPES = [(3, 2), (2, 2)]
@@ -109,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
     @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
         device = "cuda"
diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index caa1a6c7bd..a5190fb679 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -33,6 +33,7 @@
     nf4_weight_only,
     to_nf4,
 )
+from torchao.utils import skip_if_rocm

 bnb_available = False

@@ -111,6 +112,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):

     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
         # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
@@ -133,6 +135,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):

     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_nf4_bnb_linear(self, dtype: torch.dtype):
         """
diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py
index e148d68abb..9d0c4e82df 100644
--- a/test/dtypes/test_uint4.py
+++ b/test/dtypes/test_uint4.py
@@ -28,7 +28,7 @@
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm


 def _apply_weight_only_uint4_quant(model):
@@ -92,6 +92,7 @@ def test_basic_tensor_ops(self):
         # only test locally
         # print("x:", x[0])

+    @skip_if_rocm("ROCm enablement in progress")
     def test_gpu_quant(self):
         for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
             x = torch.randn(*x_shape)
@@ -104,6 +105,7 @@ def test_gpu_quant(self):
             # make sure it runs
             opt(x)

+    @skip_if_rocm("ROCm enablement in progress")
     def test_pt2e_quant(self):
         from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
             QuantizationConfig,
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 3e894c02b9..7bd287b537 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -18,6 +18,7 @@
     TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    skip_if_rocm,
 )

 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -423,6 +424,7 @@ def test_linear_from_config_params(
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py
index fbe5c9b508..0beb012406 100644
--- a/test/float8/test_fsdp2/test_fsdp2.py
+++ b/test/float8/test_fsdp2/test_fsdp2.py
@@ -43,6 +43,9 @@
 if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

+if torch.version.hip is not None:
+    pytest.skip("ROCm enablement in progress", allow_module_level=True)
+

 class TestFloat8Common:
     def broadcast_module(self, module: nn.Module) -> None:
diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 381886d594..41833859c3 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -10,6 +10,7 @@
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
+    skip_if_rocm,
 )

 cuda_available = torch.cuda.is_available()
@@ -110,6 +111,7 @@ def test_hqq_plain_5bit(self):
             ref_dot_product_error=0.000704,
         )

+    @skip_if_rocm("ROCm enablement in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 56bcaf17df..8327580748 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -85,6 +85,7 @@
     benchmark_model,
     is_fbcode,
     is_sm_at_least_90,
+    skip_if_rocm,
     unwrap_tensor_subclass,
 )

@@ -95,6 +96,7 @@
 except ModuleNotFoundError:
     has_gemlite = False

+
 logger = logging.getLogger("INFO")

 torch.manual_seed(0)
@@ -582,6 +584,7 @@ def test_per_token_linear_cpu(self):
             self._test_per_token_linear_impl("cpu", dtype)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_per_token_linear_cuda(self):
         for dtype in (torch.float32, torch.float16, torch.bfloat16):
             self._test_per_token_linear_impl("cuda", dtype)
@@ -700,6 +703,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -719,6 +723,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -912,6 +917,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -931,6 +937,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -1102,6 +1109,7 @@ def test_gemlite_layout(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
diff --git a/test/kernel/test_fused_kernels.py b/test/kernel/test_fused_kernels.py
index c5bf6e17f0..cad1f001ff 100644
--- a/test/kernel/test_fused_kernels.py
+++ b/test/kernel/test_fused_kernels.py
@@ -11,6 +11,8 @@
 import torch
 from galore_test_utils import get_kernel, make_copy, make_data

+from torchao.utils import skip_if_rocm
+
 torch.manual_seed(0)
 MAX_DIFF_no_tf32 = 1e-5
 MAX_DIFF_tf32 = 1e-3
@@ -104,6 +106,7 @@ def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32):

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
 @pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS)
+@skip_if_rocm("ROCm enablement in progress")
 def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32):
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32

diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py
index bab65fc2fb..2388f0be63 100644
--- a/test/kernel/test_galore_downproj.py
+++ b/test/kernel/test_galore_downproj.py
@@ -11,6 +11,7 @@
 from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
 from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
+from torchao.utils import skip_if_rocm

 torch.manual_seed(0)

@@ -29,6 +30,7 @@

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
 @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
+@skip_if_rocm("ROCm enablement in progress")
 def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
     MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32
diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py
index 1b91983bc0..409518ae9a 100644
--- a/test/prototype/test_awq.py
+++ b/test/prototype/test_awq.py
@@ -5,7 +5,11 @@
 import torch

 from torchao.quantization import quantize_
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_5,
+    skip_if_rocm,
+)

 if TORCH_VERSION_AT_LEAST_2_3:
     from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
@@ -113,6 +117,7 @@ def test_awq_loading(device, qdtype):

 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_rocm("ROCm enablement in progress")
 def test_save_weights_only():
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128
diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index d7d6fe7dc8..5ce3d08b81 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -30,6 +30,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     get_available_devices,
+    skip_if_rocm,
 )

 try:
@@ -42,6 +43,8 @@
 except ImportError:
     lpmm = None

+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)

 _DEVICES = get_available_devices()

@@ -112,6 +115,7 @@ class TestOptim(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
             if not TORCH_VERSION_AT_LEAST_2_4:
@@ -185,6 +189,7 @@ def test_subclass_slice(self, subclass, shape, device):
         not torch.cuda.is_available(),
         reason="bitsandbytes 8-bit Adam only works for CUDA",
     )
+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
@@ -413,6 +418,7 @@ def world_size(self) -> int:
         not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
     )
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_fsdp2(self):
         optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
         if torch.cuda.get_device_capability() >= (8, 9):
@@ -523,6 +529,7 @@ def _test_fsdp2(self, optim_cls):
         not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
     )
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_uneven_shard(self):
         in_dim = 512
         out_dim = _FSDP_WORLD_SIZE * 16 + 1
diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py
index 02b41e8e32..d90990143c 100644
--- a/test/prototype/test_smoothquant.py
+++ b/test/prototype/test_smoothquant.py
@@ -20,6 +20,9 @@
     TORCH_VERSION_AT_LEAST_2_5,
 )

+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
diff --git a/test/prototype/test_splitk.py b/test/prototype/test_splitk.py
index 48793ba907..04fdd7cff2 100644
--- a/test/prototype/test_splitk.py
+++ b/test/prototype/test_splitk.py
@@ -13,13 +13,15 @@
 except ImportError:
     triton_available = False

-from torchao.utils import skip_if_compute_capability_less_than
+
+from torchao.utils import skip_if_compute_capability_less_than, skip_if_rocm


 @unittest.skipIf(not triton_available, "Triton is required but not available")
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
 class TestFP8Gemm(TestCase):
     @skip_if_compute_capability_less_than(9.0)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_gemm_split_k(self):
         dtype = torch.float16
         qdtype = torch.float8_e4m3fn
diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py
index 3eb9b0a2c5..277bf6a49f 100644
--- a/test/quantization/test_galore_quant.py
+++ b/test/quantization/test_galore_quant.py
@@ -18,6 +18,7 @@
     triton_dequant_blockwise,
     triton_quantize_blockwise,
 )
+from torchao.utils import skip_if_rocm

 SEED = 0
 torch.manual_seed(SEED)
@@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     "dim1,dim2,dtype,signed,blocksize",
     TEST_CONFIGS,
 )
+@skip_if_rocm("ROCm enablement in progress")
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01
diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py
index ebdf2281e0..2dc2377f02 100644
--- a/test/quantization/test_marlin_qqq.py
+++ b/test/quantization/test_marlin_qqq.py
@@ -19,13 +19,14 @@
     MappingType,
     choose_qparams_and_quantize_affine_qqq,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm


 @unittest.skipIf(
     is_fbcode(),
     "Skipping the test in fbcode since we don't have TARGET file for kernels",
 )
+@skip_if_rocm("ROCm enablement in progress")
 class TestMarlinQQQ(TestCase):
     def setUp(self):
         super().setUp()
@@ -45,6 +46,7 @@ def setUp(self):
     )

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq(self):
         output_ref = self.model(self.input)
         for group_size in [-1, 128]:
@@ -66,6 +68,7 @@ def test_marlin_qqq(self):

     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq_compile(self):
         model_copy = copy.deepcopy(self.model)
         model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py
index 4da7304a24..c8bdee5e2f 100644
--- a/test/sparsity/test_marlin.py
+++ b/test/sparsity/test_marlin.py
@@ -15,7 +15,7 @@
 )
 from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24
 from torchao.sparsity.sparse_api import apply_fake_sparsity
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm


 class SparseMarlin24(TestCase):
@@ -37,6 +37,7 @@ def setUp(self):
     )

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_quant_sparse_marlin_layout_eager(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
@@ -48,13 +49,13 @@ def test_quant_sparse_marlin_layout_eager(self):
         # Sparse + quantized
         quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)
-
         assert torch.allclose(
             dense_result, sparse_result, atol=3e-1
         ), "Results are not close"

     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_quant_sparse_marlin_layout_compile(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
diff --git a/test/test_ops.py b/test/test_ops.py
index 54efefb026..107b7e8389 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -20,6 +20,9 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode

+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+
 if is_fbcode():
     pytest.skip(
         "Skipping the test in fbcode since we don't have TARGET file for kernels"
@@ -37,6 +40,9 @@
     pack_tinygemm_scales_and_zeros,
 )

+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+

 class TestOps(TestCase):
     def _create_floatx_inputs(
diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py
index 95175caacf..abf09cd2f9 100644
--- a/torchao/dtypes/uintx/marlin_qqq_tensor.py
+++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py
@@ -183,7 +183,7 @@ def __tensor_unflatten__(
     def get_plain(self):
         from torchao.quantization.marlin_qqq import (
             unpack_from_marlin_qqq,
-        )  # avoid circular import
+        )

         int_data_expanded, s_group_expanded, s_channel_expanded = (
             unpack_from_marlin_qqq(
@@ -211,7 +211,7 @@ def from_plain(
         from torchao.quantization.marlin_qqq import (
             const,
             pack_to_marlin_qqq,
-        )  # avoid circular import
+        )

         assert isinstance(_layout, MarlinQQQLayout)
diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py
index 22763eb0c2..01d4562b7f 100644
--- a/torchao/dtypes/uintx/marlin_sparse_layout.py
+++ b/torchao/dtypes/uintx/marlin_sparse_layout.py
@@ -206,7 +206,7 @@ def __tensor_unflatten__(
     def get_plain(self):
         from torchao.sparsity.marlin import (
             unpack_from_marlin_24,
-        )  # avoid circular import
+        )

         int_data_expanded, scales_expanded = unpack_from_marlin_24(
             self.int_data,
@@ -231,7 +231,7 @@ def from_plain(
         from torchao.sparsity.marlin import (
             const,
             pack_to_marlin_24,
-        )  # avoid circular import
+        )

         assert isinstance(_layout, MarlinSparseLayout)
diff --git a/torchao/utils.py b/torchao/utils.py
index f67463f9f7..b2481440c6 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -7,6 +7,7 @@
 from math import gcd
 from typing import Any, Callable, Tuple

+import pytest
 import torch
 import torch.nn.utils.parametrize as parametrize

@@ -161,6 +162,33 @@ def wrapper(*args, **kwargs):
     return decorator


+def skip_if_rocm(message=None):
+    """Decorator to skip tests on ROCm platform with custom message.
+
+    Args:
+        message (str, optional): Additional information about why the test is skipped.
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.version.hip is not None:
+                skip_message = "Skipping the test in ROCm"
+                if message:
+                    skip_message += f": {message}"
+                pytest.skip(skip_message)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    # Handle both @skip_if_rocm and @skip_if_rocm() syntax
+    if callable(message):
+        func = message
+        message = None
+        return decorator(func)
+    return decorator
+
+
 def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor:
     return torch.mean(torch.abs(output - output_ref)) / torch.mean(
         torch.abs(output_ref)
@@ -607,7 +635,7 @@ def _torch_version_at_least(min_version):
 def is_MI300():
     if torch.cuda.is_available() and torch.version.hip:
         mxArchName = ["gfx940", "gfx941", "gfx942"]
-        archName = torch.cuda.get_device_properties().gcnArchName
+        archName = torch.cuda.get_device_properties(0).gcnArchName
         for arch in mxArchName:
            if arch in archName:
                return True
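For reference, a minimal usage sketch of the `skip_if_rocm` helper that this diff adds to `torchao/utils.py`. The test names below are hypothetical; the point is that the decorator supports both the bare form and the form that passes a reason, as handled by the `callable(message)` branch above.

```python
# Illustrative only: hypothetical tests showing both supported decorator forms.
import torch

from torchao.utils import skip_if_rocm


@skip_if_rocm  # bare form: skips on ROCm with the default message
def test_cuda_only_kernel():
    assert torch.cuda.is_available()


@skip_if_rocm("ROCm enablement in progress")  # parameterized form: appends the reason
def test_cuda_only_quant():
    x = torch.randn(4, 4, device="cuda")
    assert x.shape == (4, 4)
```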