Commit 037df23: Fixes
carmocca committed Mar 30, 2023 (1 parent: a2ac44f)
Showing 11 changed files with 25 additions and 10 deletions.

2 changes: 1 addition & 1 deletion docs/source-pytorch/accelerators/tpu_faq.rst
@@ -43,7 +43,7 @@ How to clear up the programs using TPUs in the background?

 .. code-block:: bash
-    lsof -w /lib/libtpu.so | grep "python" | awk '{print $2}' | xargs -r kill -9
+    pgrep python | awk '{print $2}' | xargs -r kill -9
 Sometimes, there can still be old programs running on the TPUs, which would make the TPUs unavailable to use. You could use the above command in the terminal to kill the running processes.

3 changes: 3 additions & 0 deletions src/lightning/fabric/accelerators/tpu.py
@@ -62,6 +62,9 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
         # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy

     @staticmethod
+    # XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
+    # https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/distributed/xla_multiprocessing.py#L280
+    @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         import torch_xla.core.xla_env_vars as xenv
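The added comment and `functools.lru_cache(maxsize=1)` decorator cache the result of `auto_device_count` because, per the linked line, XLA's `xla_multiprocessing` pops the `TPU_NUM_DEVICES` key. A minimal sketch of why the cache helps, using illustrative names rather than Lightning's actual code:

import functools
import os

@functools.lru_cache(maxsize=1)
def cached_device_count() -> int:
    # The first call reads the environment; the decorator memoizes the result.
    return int(os.environ.get("TPU_NUM_DEVICES", 0))

os.environ["TPU_NUM_DEVICES"] = "8"
assert cached_device_count() == 8   # reads the variable and caches 8
os.environ.pop("TPU_NUM_DEVICES")   # simulate xla_multiprocessing popping the key
assert cached_device_count() == 8   # later calls still see the cached value

Because the function takes no arguments, `maxsize=1` is enough to pin the first computed value for the rest of the process.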
2 changes: 1 addition & 1 deletion src/lightning/fabric/connector.py
@@ -379,7 +379,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
return "xla"
else:
# TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore
return SingleTPUStrategy(device=self._parallel_devices[0])
if self._num_nodes_flag > 1:
return "ddp"
if len(self._parallel_devices) <= 1:
4 changes: 3 additions & 1 deletion src/lightning/fabric/strategies/single_device.py
@@ -37,7 +37,9 @@ def __init__(
         precision: Precision | None = None,
     ):
         super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
-        self._root_device = torch.device(device)
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
+        self._root_device = device
         self.global_rank = 0
         self.local_rank = 0
         self.world_size = 1
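This change, mirrored in the PyTorch strategy below, only constructs a `torch.device` when the caller did not already pass one, so an existing device object is kept as-is. A minimal sketch of the pattern, with an illustrative helper name:

import torch

def as_torch_device(device):
    # Accept either a spec such as "cpu" / "cuda:0" or an existing torch.device;
    # build a new torch.device only when needed.
    if not isinstance(device, torch.device):
        device = torch.device(device)
    return device

assert as_torch_device("cpu") == torch.device("cpu")
assert as_torch_device(torch.device("cpu")) == torch.device("cpu")

This also matters for the connector tests further down, where `torch.device` is monkeypatched: the `isinstance` branch means the patched constructor is not called on a value that already counts as a device.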
4 changes: 3 additions & 1 deletion src/lightning/pytorch/strategies/single_device.py
@@ -38,7 +38,9 @@ def __init__(
         precision_plugin: PrecisionPlugin | None = None,
     ):
         super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
-        self._root_device = torch.device(device)
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
+        self._root_device = device
         self.global_rank = 0
         self.local_rank = 0
         self.world_size = 1
@@ -444,7 +444,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
                 return XLAStrategy.strategy_name
             else:
                 # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
-                return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore
+                return SingleTPUStrategy(device=self._parallel_devices[0])
         if self._num_nodes_flag > 1:
             return "ddp"
         if len(self._parallel_devices) <= 1:
1 change: 1 addition & 0 deletions tests/tests_fabric/conftest.py
@@ -88,6 +88,7 @@ def reset_cudnn_benchmark():
 def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
     monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(lightning.fabric.strategies.single_tpu, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)

6 changes: 5 additions & 1 deletion tests/tests_fabric/test_connector.py
@@ -931,11 +931,15 @@ def _mock_interactive():
         assert isinstance(connector.strategy, SingleDeviceStrategy)
         assert connector._devices_flag == [0]

+    class DeviceMock(Mock):
+        def __instancecheck__(self, instance):
+            return True
+
     # single TPU
     with no_cuda, no_mps, monkeypatch.context():
         mock_tpu_available(monkeypatch, True)
         monkeypatch.setattr(lightning.fabric.accelerators.TPUAccelerator, "auto_device_count", lambda *_: 1)
-        monkeypatch.setattr(torch, "device", Mock())
+        monkeypatch.setattr(torch, "device", DeviceMock())
         connector = _Connector()
         assert isinstance(connector.accelerator, TPUAccelerator)
         assert isinstance(connector.strategy, SingleTPUStrategy)
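The `DeviceMock` helper added above relies on the fact that `isinstance(obj, x)` dispatches to `__instancecheck__` looked up on `type(x)`. Defining it on a `Mock` subclass lets an instance of that subclass stand in for `torch.device` and report every object as an instance, so the new `isinstance(device, torch.device)` check in the strategies passes under the patch. A small self-contained sketch of the mechanism (the class name here is illustrative, not the test code itself):

from unittest.mock import Mock

class AlwaysInstance(Mock):
    def __instancecheck__(self, instance):
        # Invoked for isinstance(obj, fake) because the lookup happens on
        # type(fake), i.e. on this class.
        return True

fake_device_type = AlwaysInstance()
assert isinstance(object(), fake_device_type)
assert isinstance("anything", fake_device_type)

With `monkeypatch.setattr(torch, "device", DeviceMock())` as in the test, the strategy keeps the passed-in value untouched instead of calling the mocked constructor on it.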
2 changes: 1 addition & 1 deletion tests/tests_pytorch/models/test_tpu.py
@@ -235,7 +235,7 @@ def test_dataloaders_passed_to_fit(tmpdir):
     trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader())


-@pytest.mark.parametrize("devices", [[1, 8], "9, ", [9], [0], 2, 10])
+@pytest.mark.parametrize("devices", [[1, 8], "9, ", [9], [-1], 2, 10])
 def test_tpu_misconfiguration(devices, tpu_available):
     with pytest.raises(ValueError, match="`devices` can only be"):
         Trainer(accelerator="tpu", devices=devices)
@@ -873,14 +873,18 @@ def _mock_tpu_available(value):
         assert isinstance(connector.strategy, SingleDeviceStrategy)
         assert connector._devices_flag == [0]

+    class DeviceMock(Mock):
+        def __instancecheck__(self, instance):
+            return True
+
     # single TPU
     with monkeypatch.context():
         mock_cuda_count(monkeypatch, 0)
         mock_mps_count(monkeypatch, 0)
         mock_ipu_available(monkeypatch, False)
         _mock_tpu_available(True)
         monkeypatch.setattr(lightning.pytorch.accelerators.TPUAccelerator, "auto_device_count", lambda *_: 1)
-        monkeypatch.setattr(torch, "device", Mock())
+        monkeypatch.setattr(torch, "device", DeviceMock())
         connector = _AcceleratorConnector()
         assert isinstance(connector.accelerator, TPUAccelerator)
         assert isinstance(connector.strategy, SingleTPUStrategy)
@@ -21,7 +21,6 @@
 from torch.utils.data import DataLoader

 from lightning.pytorch import Trainer
-from lightning.pytorch.accelerators import TPUAccelerator
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
 from lightning.pytorch.strategies.ipu import IPUStrategy
 from tests_pytorch.conftest import mock_cuda_count
@@ -144,7 +143,7 @@ def test_num_stepping_batches_with_tpu_single():

 class MultiprocessModel(BoringModel):
     def on_train_start(self):
-        device_count = TPUAccelerator.auto_device_count()
+        device_count = self.trainer.accelerator.auto_device_count()
         assert self.trainer.world_size == device_count
         assert self.trainer.estimated_stepping_batches == len(self.train_dataloader()) // device_count

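The second hunk calls `auto_device_count` through the trainer's accelerator instance, which is what allows the first hunk to drop the `TPUAccelerator` import. Since the method is a `@staticmethod` wrapped in `functools.lru_cache` (see the `tpu.py` hunk earlier), the class-level and instance-level spellings resolve to the same cached function; a small sketch with an illustrative class:

import functools

class FakeAccelerator:
    @staticmethod
    @functools.lru_cache(maxsize=1)
    def auto_device_count() -> int:
        return 8  # imagine this reading an environment variable once

assert FakeAccelerator.auto_device_count() == 8    # class access populates the cache
assert FakeAccelerator().auto_device_count() == 8  # instance access hits the same cache
assert FakeAccelerator.auto_device_count.cache_info().hits == 1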

