diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index 1e27b31af8403..81a07b71f8f44 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -715,10 +715,16 @@ def test_gpu_accelerator_backend_choice_cuda(*_):
     assert isinstance(connector.accelerator, CUDAAccelerator)
 
 
+class DeviceMock(Mock):
+    def __instancecheck__(self, instance):
+        return True
+
+
 @RunIf(min_torch="1.12")
 @mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=True)
 @mock.patch("lightning.fabric.accelerators.mps._get_all_available_mps_gpus", return_value=[0])
+@mock.patch("torch.device", DeviceMock)
 def test_gpu_accelerator_backend_choice_mps(*_):
     connector = _Connector(accelerator="gpu")
     assert connector._accelerator_flag == "mps"
     assert isinstance(connector.accelerator, MPSAccelerator)
@@ -931,10 +937,6 @@ def _mock_interactive():
         assert isinstance(connector.strategy, SingleDeviceStrategy)
         assert connector._devices_flag == [0]
 
-        class DeviceMock(Mock):
-            def __instancecheck__(self, instance):
-                return True
-
     # single TPU
     with no_cuda, no_mps, monkeypatch.context():
         mock_tpu_available(monkeypatch, True)
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index 7c5402df9c69b..1dc815920be64 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -725,7 +725,7 @@ def test_gpu_accelerator_backend_choice_cuda(cuda_count_1):
     assert isinstance(trainer.accelerator, CUDAAccelerator)
 
 
-def test_gpu_accelerator_backend_choice_mps(mps_count_1):
+def test_gpu_accelerator_backend_choice_mps(mps_count_1, cuda_count_0):
    trainer = Trainer(accelerator="gpu")
     assert trainer._accelerator_connector._accelerator_flag == "mps"
     assert isinstance(trainer.accelerator, MPSAccelerator)
@@ -805,6 +805,11 @@ def test_bagua_external_strategy(monkeypatch):
     assert isinstance(trainer.strategy, BaguaStrategy)
 
 
+class DeviceMock(Mock):
+    def __instancecheck__(self, instance):
+        return True
+
+
 @pytest.mark.parametrize("is_interactive", (False, True))
 def test_connector_auto_selection(monkeypatch, is_interactive):
     import lightning.fabric  # avoid breakage with standalone package
@@ -873,10 +878,6 @@ def _mock_tpu_available(value):
         assert isinstance(connector.strategy, SingleDeviceStrategy)
         assert connector._devices_flag == [0]
 
-        class DeviceMock(Mock):
-            def __instancecheck__(self, instance):
-                return True
-
     # single TPU
     with monkeypatch.context():
         mock_cuda_count(monkeypatch, 0)
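
Note (not part of the patch): `DeviceMock` works by overriding `__instancecheck__`, which Python consults when the object standing in for `torch.device` is used as the second argument of `isinstance`. Below is a minimal standalone sketch of that mechanism, assuming only `unittest.mock` and an installed `torch`; the usage shown is a hypothetical illustration, not code taken from the Lightning test suite.

```python
from unittest.mock import Mock, patch

import torch


class DeviceMock(Mock):
    # isinstance(obj, X) dispatches to type(X).__instancecheck__, so when a
    # DeviceMock instance stands in for torch.device, every
    # isinstance(obj, torch.device) check returns True.
    def __instancecheck__(self, instance):
        return True


# Hypothetical usage for illustration only:
with patch("torch.device", DeviceMock()):
    fake = torch.device("mps", 0)          # calling the mock just returns another Mock
    assert isinstance(fake, torch.device)  # passes via DeviceMock.__instancecheck__
    assert isinstance("anything", torch.device)  # the type check is effectively bypassed
```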