enhance documentation
uniartisan committed Oct 19, 2024
1 parent baf3e5c commit ae3ae6b
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions docs/source-pytorch/extensions/accelerator.rst
@@ -36,45 +36,79 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have access to its hardware through a library ``xpulib``.

.. code-block:: python

    import torch
    import xpulib
    from functools import lru_cache
    from typing import Any, Dict, Union

    from lightning.pytorch.accelerators.accelerator import Accelerator
    from typing_extensions import override


    class XPUAccelerator(Accelerator):
        """Support for a hypothetical XPU, optimized for large-scale machine learning."""

        @override
        def setup_device(self, device: torch.device) -> None:
            """
            Raises:
                ValueError:
                    If the selected device is not of type hypothetical XPU.
            """
            if device.type != "xpu":
                raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.")
            if device.index is None:
                device = torch.device("xpu", 0)
            xpulib.set_device(device.index)

        @override
        def teardown(self) -> None:
            xpulib.empty_cache()

        @staticmethod
        @override
        def parse_devices(devices: Any) -> Any:
            # Put parsing logic here for how devices can be passed into the Trainer
            # via the ``devices`` argument
            return devices

        @staticmethod
        @override
        def get_parallel_devices(devices: Any) -> Any:
            # Here, convert the device indices to actual device objects
            return [torch.device("xpu", idx) for idx in devices]

        @staticmethod
        @override
        def auto_device_count() -> int:
            # Return a value for auto-device selection when `Trainer(devices="auto")`
            return xpulib.available_devices()

        @staticmethod
        @override
        def is_available() -> bool:
            return xpulib.is_available()

        def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
            # Return optional device statistics for loggers
            return {}

        @staticmethod
        @override
        def get_device() -> str:
            return "xpu"
Finally, add the XPUAccelerator to the Trainer:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DDPStrategy

    accelerator = XPUAccelerator()
    strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices([0, 1]))
    trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2)
:doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator.
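The same division of labor holds for a single device: the strategy owns device placement, while the accelerator only describes the hardware. A short sketch, assuming the ``XPUAccelerator`` from above:

.. code-block:: python

    import torch

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import SingleDeviceStrategy

    # The strategy receives the concrete device; the accelerator reports
    # availability, device counts, and statistics
    strategy = SingleDeviceStrategy(device=torch.device("xpu", 0), accelerator=XPUAccelerator())
    trainer = Trainer(strategy=strategy, devices=1)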
@@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes,
...

@classmethod
@override
def register_accelerators(cls, accelerator_registry):
    accelerator_registry.register(
        "xpu",
        cls,
        description=cls.__name__,
    )
