Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Mar 30, 2023
1 parent 037df23 commit ee33a21
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
12 changes: 7 additions & 5 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,16 @@ def test_gpu_accelerator_backend_choice_cuda(*_):
assert isinstance(connector.accelerator, CUDAAccelerator)


class DeviceMock(Mock):
def __instancecheck__(self, instance):
return True


@RunIf(min_torch="1.12")
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=True)
@mock.patch("lightning.fabric.accelerators.mps._get_all_available_mps_gpus", return_value=[0])
def test_gpu_accelerator_backend_choice_mps(*_):
@mock.patch("torch.device", DeviceMock)
def test_gpu_accelerator_backend_choice_mps(*_: object) -> object:
connector = _Connector(accelerator="gpu")
assert connector._accelerator_flag == "mps"
assert isinstance(connector.accelerator, MPSAccelerator)
Expand Down Expand Up @@ -931,10 +937,6 @@ def _mock_interactive():
assert isinstance(connector.strategy, SingleDeviceStrategy)
assert connector._devices_flag == [0]

class DeviceMock(Mock):
def __instancecheck__(self, instance):
return True

# single TPU
with no_cuda, no_mps, monkeypatch.context():
mock_tpu_available(monkeypatch, True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ def test_gpu_accelerator_backend_choice_cuda(cuda_count_1):
assert isinstance(trainer.accelerator, CUDAAccelerator)


def test_gpu_accelerator_backend_choice_mps(mps_count_1):
def test_gpu_accelerator_backend_choice_mps(mps_count_1, cuda_count_0):
trainer = Trainer(accelerator="gpu")
assert trainer._accelerator_connector._accelerator_flag == "mps"
assert isinstance(trainer.accelerator, MPSAccelerator)
Expand Down Expand Up @@ -805,6 +805,11 @@ def test_bagua_external_strategy(monkeypatch):
assert isinstance(trainer.strategy, BaguaStrategy)


class DeviceMock(Mock):
def __instancecheck__(self, instance):
return True


@pytest.mark.parametrize("is_interactive", (False, True))
def test_connector_auto_selection(monkeypatch, is_interactive):
import lightning.fabric # avoid breakage with standalone package
Expand Down Expand Up @@ -873,10 +878,6 @@ def _mock_tpu_available(value):
assert isinstance(connector.strategy, SingleDeviceStrategy)
assert connector._devices_flag == [0]

class DeviceMock(Mock):
def __instancecheck__(self, instance):
return True

# single TPU
with monkeypatch.context():
mock_cuda_count(monkeypatch, 0)
Expand Down

0 comments on commit ee33a21

Please sign in to comment.