[TPU] v4 support (#17227)
carmocca authored Apr 11, 2023
1 parent 7b8fd85 commit 0489f2e
Showing 32 changed files with 249 additions and 196 deletions.
2 changes: 1 addition & 1 deletion docs/source-pytorch/accelerators/tpu_advanced.rst
@@ -50,7 +50,7 @@ Example:
model = WeightSharingModule()
trainer = Trainer(max_epochs=1, accelerator="tpu", devices=8)
trainer = Trainer(max_epochs=1, accelerator="tpu")
See `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
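
For context (not part of the diff): a minimal sketch of what the ``WeightSharingModule`` used above might look like. The layer sizes and the sharing pattern are assumptions for illustration, not taken from the docs page.

import torch.nn as nn
from lightning.pytorch import LightningModule


class WeightSharingModule(LightningModule):
    """Hypothetical example of a module that ties two layers to one weight tensor."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # weight sharing: layer_3 reuses layer_1's parameter, which is the
        # situation the XLA tensor-quirks link above is concerned with
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        return self.layer_3(self.layer_2(self.layer_1(x)))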

6 changes: 3 additions & 3 deletions docs/source-pytorch/accelerators/tpu_basic.rst
@@ -50,8 +50,8 @@ To run on different cores, modify the ``devices`` argument.
trainer = Trainer(accelerator="tpu", devices=1)
# run on multiple TPU cores
trainer = Trainer(accelerator="tpu", devices=8)
# run on the 5th core
trainer = Trainer(accelerator="tpu", devices=[5])
# run on one specific TPU core: the 2nd core (index 1)
trainer = Trainer(accelerator="tpu", devices=[1])
# choose the number of cores automatically
trainer = Trainer(accelerator="tpu", devices="auto")
@@ -92,7 +92,7 @@ To get a TPU on colab, follow these steps:

Google Cloud (GCP)
^^^^^^^^^^^^^^^^^^
You could refer to this `page <https://cloud.google.com/tpu/docs/setup-gcp-account>`_ for getting started with Cloud TPU resources on GCP.
You could refer to this `page <https://cloud.google.com/tpu/docs/v4-users-guide>`_ for getting started with Cloud TPU resources on GCP.

Kaggle
^^^^^^
2 changes: 1 addition & 1 deletion docs/source-pytorch/accelerators/tpu_faq.rst
@@ -43,7 +43,7 @@ How to clear up the programs using TPUs in the background?

.. code-block:: bash
lsof -w /lib/libtpu.so | grep "python" | awk '{print $2}' | xargs -r kill -9
pgrep python | awk '{print $2}' | xargs -r kill -9
Sometimes, there can still be old programs running on the TPUs, which would make the TPUs unavailable to use. You could use the above command in the terminal to kill the running processes.
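
Not part of the diff: a rough Python equivalent of the shell one-liner above, as an illustrative sketch for running the cleanup from a notebook instead of a terminal.

import os
import signal
import subprocess

# find the PIDs of running python processes and send them SIGKILL,
# skipping the interpreter this snippet itself runs in
pids = subprocess.run(["pgrep", "python"], capture_output=True, text=True).stdout.split()
for pid in pids:
    if int(pid) != os.getpid():
        os.kill(int(pid), signal.SIGKILL)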

20 changes: 13 additions & 7 deletions docs/source-pytorch/accelerators/tpu_intermediate.rst
@@ -64,19 +64,26 @@ better performance and usability while working with TPUs.

The TPUVMs come pre-installed with latest versions of PyTorch and PyTorch XLA.
After connecting to the VM and before running your Lightning code, you would need
to set the XRT TPU device configuration.
to set the `XRT TPU device configuration <https://cloud.google.com/tpu/docs/v4-users-guide#train_ml_workloads_with_pytorch_xla>`__.

.. code-block:: bash
$ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
You could learn more about the Cloud TPU VM architecture `here <https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_vms_3>`_
# Set the environment variable to visible devices.
# You might need to change the value depending on how many chips you have
export TPU_NUM_DEVICES=4
# Allow LIBTPU LOAD by multiple processes
export ALLOW_MULTIPLE_LIBTPU_LOAD=1
You can learn more about the Cloud TPU VM architecture `here <https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_vms_3>`_
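
Not part of the diff: a quick sanity check, as a sketch that assumes ``torch_xla`` is installed and the variables above have been exported, to confirm the runtime sees the expected cores.

import os

import torch_xla.core.xla_model as xm

print(os.environ.get("XRT_TPU_CONFIG"))   # the XRT configuration exported above
print(os.environ.get("TPU_NUM_DEVICES"))  # the number of visible chips exported above
print(xm.get_xla_supported_devices())     # the XLA devices the runtime reports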

----------------

TPU Pod
-------
To train on more than 8 cores, your code actually doesn't change!
To train on more than the number of cores in a node, your code actually doesn't change!

All TPU VMs in a Pod setup are required to access the model code and data.
One easy way to achieve this is to use the following startup script when creating the TPU VM pod.
@@ -102,15 +109,14 @@ on how to set up the instance groups and VMs needed to run TPU Pods.
16 bit precision
----------------
Lightning also supports training in 16-bit precision with TPUs.
By default, TPU training will use 32-bit precision. To enable 16-bit,
set the 16-bit flag.
By default, TPU training will use 32-bit precision. To enable it, do

.. code-block:: python
import lightning.pytorch as pl
my_model = MyLightningModule()
trainer = pl.Trainer(accelerator="tpu", devices=8, precision=16)
trainer = pl.Trainer(accelerator="tpu", precision="16-mixed")
trainer.fit(my_model)
Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
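
Not part of the diff: a tiny sketch, runnable on CPU without a TPU, of what bfloat16 trades away compared to float32.

import torch

x = torch.tensor(1.0 + 2**-10)          # 1.0009765625 in float32
print(x.to(torch.bfloat16).item())      # 1.0: bfloat16 keeps only 7 stored mantissa bits, so the offset is lost
print(x.to(torch.float16).item())       # 1.0009765625: float16 keeps 10 mantissa bits
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38: the exponent range matches float32, unlike float16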
11 changes: 9 additions & 2 deletions src/lightning/fabric/CHANGELOG.md
@@ -9,10 +9,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


- Check for invalid TPU device inputs ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


- Added support for joint setup of model and optimizer with FSDP ([#17305](https://github.com/Lightning-AI/lightning/pull/17305))
- Added support for handling multiple parameter groups in optimizers set up with FSDP ([#17305](https://github.com/Lightning-AI/lightning/pull/17305))


- Added support for handling multiple parameter groups in optimizers set up with FSDP ([#17305](https://github.com/Lightning-AI/lightning/pull/17305))

### Changed

- Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))
@@ -32,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed issue where running on TPUs would select the wrong device index ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


## [2.0.0] - 2023-03-15
83 changes: 47 additions & 36 deletions src/lightning/fabric/accelerators/tpu.py
@@ -15,7 +15,7 @@
import queue as q
import traceback
from multiprocessing import Process, Queue
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Union

import torch
from lightning_utilities.core.imports import ModuleAvailableCache
@@ -42,21 +42,35 @@ def teardown(self) -> None:
pass

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]:
def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
"""Accelerator device parsing logic."""
return _parse_tpu_devices(devices)

@staticmethod
def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]:
def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_tpu_devices(devices)
# In XLA index 0 maps to CPU, in fact, a `xla_device()` with no arguments has index 1
# since the user passes a 0-based index, we need to adjust the indices
if isinstance(devices, int):
return list(range(devices))
return devices
return [torch.device("xla", i) for i in range(1, devices + 1)]
else:
# list of devices is not supported, just a specific index, fine to access [0]
return [torch.device("xla", devices[0] + 1)]
# we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
# accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
# it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
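
Not part of the diff: a usage sketch of the mapping implemented above. It assumes `torch_xla` is installed (validation consults `auto_device_count()`) and the default count of 8 devices; no TPU is needed to build the plain `torch.device` objects.

import torch

from lightning.fabric.accelerators.tpu import TPUAccelerator

# an integer requests that many cores; XLA indices start at 1 because index 0 maps to the CPU
assert TPUAccelerator.get_parallel_devices(8) == [torch.device("xla", i) for i in range(1, 9)]

# a one-element list picks a specific core, shifting the user's 0-based index by one
assert TPUAccelerator.get_parallel_devices([0]) == [torch.device("xla", 1)]
assert TPUAccelerator.get_parallel_devices([3]) == [torch.device("xla", 4)]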

@staticmethod
# XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
# https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/distributed/xla_multiprocessing.py#L280
@functools.lru_cache(maxsize=1)
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return 8
import torch_xla.core.xla_env_vars as xenv
from torch_xla.utils.utils import getenv_as

return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)

@staticmethod
@functools.lru_cache(maxsize=1)
@@ -131,52 +145,49 @@ def _tpu_distributed() -> bool:
return xm.xrt_world_size() > 1


def _parse_tpu_devices(devices: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
def _parse_tpu_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
"""
Parses the TPU devices given in the format as accepted by the
:class:`~lightning.pytorch.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
Args:
devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
A list of ints or a strings containing a list of comma separated integers
indicates the specific TPU core to use.
A single element list of int or string can be used to indicate the specific TPU core to use.
Returns:
A list of tpu_cores to be used or ``None`` if no TPU cores were requested
Raises:
TypeError:
If TPU devices aren't 1, 8 or [<1-8>]
A list of tpu cores to be used.
"""
_check_data_type(devices)

if isinstance(devices, str):
devices = _parse_tpu_devices_str(devices.strip())

if not _tpu_devices_valid(devices):
raise TypeError("`devices` can only be 1, 8 or [<1-8>] for TPUs.")

devices = _parse_tpu_devices_str(devices)
_check_tpu_devices_valid(devices)
return devices


def _tpu_devices_valid(devices: Any) -> bool:
# allow 1 or 8 cores
if devices in (1, 8, None):
return True

# allow picking 1 of 8 indexes
if isinstance(devices, (list, tuple, set)):
has_1_tpu_idx = len(devices) == 1
is_valid_tpu_idx = 1 <= list(devices)[0] <= 8

is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx
return is_valid_tpu_core_choice

return False
def _check_tpu_devices_valid(devices: object) -> None:
device_count = TPUAccelerator.auto_device_count()
if (
# support number of devices
isinstance(devices, int)
and devices in {1, device_count}
# support picking a specific device
or isinstance(devices, (list, tuple))
and len(devices) == 1
and 0 <= devices[0] <= device_count - 1
):
return
raise ValueError(
f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
)


def _parse_tpu_devices_str(devices: str) -> Union[int, List[int]]:
if devices in ("1", "8"):
devices = devices.strip()
try:
return int(devices)
return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
except ValueError:
try:
return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
except ValueError:
raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
2 changes: 1 addition & 1 deletion src/lightning/fabric/connector.py
@@ -379,7 +379,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
return "xla"
else:
# TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore
return SingleTPUStrategy(device=self._parallel_devices[0])
if self._num_nodes_flag > 1:
return "ddp"
if len(self._parallel_devices) <= 1:
3 changes: 2 additions & 1 deletion src/lightning/fabric/plugins/environments/xla.py
@@ -76,5 +76,6 @@ def local_rank(self) -> int:

def node_rank(self) -> int:
import torch_xla.core.xla_env_vars as xenv
from torch_xla.utils.utils import getenv_as

return int(os.environ.get(xenv.HOST_ORDINAL, 0))
return getenv_as(xenv.HOST_ORDINAL, int, 0)
4 changes: 3 additions & 1 deletion src/lightning/fabric/strategies/single_device.py
@@ -37,7 +37,9 @@ def __init__(
precision: Precision | None = None,
):
super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
self._root_device = torch.device(device)
if not isinstance(device, torch.device):
device = torch.device(device)
self._root_device = device
self.global_rank = 0
self.local_rank = 0
self.world_size = 1
12 changes: 11 additions & 1 deletion src/lightning/fabric/strategies/single_tpu.py
@@ -13,23 +13,33 @@
# limitations under the License.
from typing import Dict, Optional

import torch

from lightning.fabric.accelerators import Accelerator
from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.single_device import SingleDeviceStrategy
from lightning.fabric.utilities.types import _DEVICE


class SingleTPUStrategy(SingleDeviceStrategy):
"""Strategy for training on a single TPU device."""

def __init__(
self,
device: int,
device: _DEVICE,
accelerator: Optional[Accelerator] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
):
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
if isinstance(device, torch.device):
# unwrap the `torch.device` in favor of `xla_device`
device = device.index

import torch_xla.core.xla_model as xm

super().__init__(
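
Not part of the diff: what the unwrapping above enables, as a sketch that assumes a TPU VM with `torch_xla` set up, since the strategy resolves the index through `xm.xla_device`.

import torch

from lightning.fabric.strategies.single_tpu import SingleTPUStrategy

# the strategy accepts either a plain index or a `torch.device`; the latter is
# reduced to its index before being handed to `xm.xla_device(...)`
assert SingleTPUStrategy(device=1).root_device == SingleTPUStrategy(device=torch.device("xla", 1)).root_device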
21 changes: 10 additions & 11 deletions src/lightning/fabric/utilities/device_parser.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, MutableSequence, Optional, Tuple, Union
from typing import List, MutableSequence, Optional, Tuple, Union

import lightning.fabric.accelerators as accelerators # avoid circular dependency
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
@@ -179,23 +179,22 @@ def _check_unique(device_ids: List[int]) -> None:
raise MisconfigurationException("Device ID's (GPU) must be unique.")


def _check_data_type(device_ids: Any) -> None:
"""Checks that the device_ids argument is one of the following: None, int, string, or sequence of integers.
def _check_data_type(device_ids: object) -> None:
"""Checks that the device_ids argument is one of the following: int, string, or sequence of integers.
Args:
device_ids: gpus/tpu_cores parameter as passed to the Trainer
Raises:
MisconfigurationException:
If ``device_ids`` of GPU/TPUs aren't ``int``, ``str``, sequence of ``int`` or ``None``
TypeError:
If ``device_ids`` of GPU/TPUs aren't ``int``, ``str`` or sequence of ``int```
"""
msg = "Device IDs (GPU/TPU) must be an int, a string, a sequence of ints or None, but you passed"

msg = "Device IDs (GPU/TPU) must be an int, a string, a sequence of ints, but you passed"
if device_ids is None:
return
elif isinstance(device_ids, (MutableSequence, tuple)):
raise TypeError(f"{msg} None")
if isinstance(device_ids, (MutableSequence, tuple)):
for id_ in device_ids:
if type(id_) is not int:
raise MisconfigurationException(f"{msg} a sequence of {type(id_).__name__}.")
raise TypeError(f"{msg} a sequence of {type(id_).__name__}.")
elif type(device_ids) not in (int, str):
raise MisconfigurationException(f"{msg} {type(device_ids).__name__}.")
raise TypeError(f"{msg} {device_ids!r}.")
8 changes: 8 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -9,6 +9,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


- Check for invalid TPU device inputs ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


- Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))


@@ -23,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* It now supports neptune-client 0.16.16 and neptune >=1.0, and we have replaced the `log()` method with `append()` and `extend()`.
* It now accepts a namespace `Handler` as an alternative to `Run` for the `run` argument. This means that you can call it like `NeptuneLogger(run=run["some/namespace"])` to log everything to the `some/namespace/` location of the run.
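
Not part of the diff: a sketch of the Neptune entry above. It assumes `neptune>=1.0` is installed; the metric path is a placeholder.

import neptune
from lightning.pytorch.loggers import NeptuneLogger

run = neptune.init_run()                           # neptune >=1.0 style run creation
logger = NeptuneLogger(run=run["some/namespace"])  # a namespace Handler is now accepted for `run`
run["some/namespace/metrics/acc"].append(0.95)     # `append()` replaces the removed `log()`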

- `Trainer(accelerator="tpu", devices=[i])` now selects the i-th TPU core (0-based, previously it was 1-based) ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


- Pickling the `LightningModule` no longer pickles the `Trainer` ([#17133](https://github.com/Lightning-AI/lightning/pull/17133))
