Skip to content

Commit

Permalink
enhance 3d-party devices in mix-precision
Browse files Browse the repository at this point in the history
  • Loading branch information
uniartisan committed Oct 22, 2024
1 parent 06a8d5b commit 2a89640
Show file tree
Hide file tree
Showing 25 changed files with 187 additions and 24 deletions.
39 changes: 37 additions & 2 deletions docs/source-pytorch/extensions/accelerator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,45 +36,79 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc

.. code-block:: python
import torch
import xpulib
from functools import lru_cache
from typing import Any, Dict, Union
from lightning.pytorch.accelerators.accelerator import Accelerator
from typing_extensions import override
class XPUAccelerator(Accelerator):
"""Support for a hypothetical XPU, optimized for large-scale machine learning."""
@override
def setup_device(self, device: torch.device) -> None:
"""
Raises:
ValueError:
If the selected device is not of type hypothetical XPU.
"""
if device.type != "xpu":
raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.")
if device.index is None:
device = torch.device("xpu", 0)
xpulib.set_device(device.index)
@override
def teardown(self) -> None:
xpulib.empty_cache()
@staticmethod
@override
def parse_devices(devices: Any) -> Any:
# Put parsing logic here how devices can be passed into the Trainer
# via the `devices` argument
return devices
@staticmethod
@override
def get_parallel_devices(devices: Any) -> Any:
# Here, convert the device indices to actual device objects
return [torch.device("xpu", idx) for idx in devices]
@staticmethod
@override
def auto_device_count() -> int:
# Return a value for auto-device selection when `Trainer(devices="auto")`
return xpulib.available_devices()
@staticmethod
@override
def is_available() -> bool:
return xpulib.is_available()
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
# Return optional device statistics for loggers
return {}
@staticmethod
@override
def get_device() -> str:
return "xpu"
Finally, add the XPUAccelerator to the Trainer:

.. code-block:: python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy
accelerator = XPUAccelerator()
trainer = Trainer(accelerator=accelerator, devices=2)
strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(2))
trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2)
:doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator.
Expand All @@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes,
...
@classmethod
@override
def register_accelerators(cls, accelerator_registry):
accelerator_registry.register(
"xpu",
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def parse_devices(devices: Any) -> Any:
def get_parallel_devices(devices: Any) -> Any:
"""Gets parallel devices for the Accelerator."""

@staticmethod
@abstractmethod
def get_device() -> Any:
"""Get the device for the current Accelerator."""

@staticmethod
@abstractmethod
def auto_device_count() -> int:
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices

@staticmethod
@override
def get_device() -> str:
return "cpu"

@staticmethod
@override
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def get_parallel_devices(devices: List[int]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
return [torch.device("cuda", i) for i in devices]

@staticmethod
@override
def get_device() -> str:
return "cuda"

@staticmethod
@override
def auto_device_count() -> int:
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi
assert parsed_devices is not None
return [torch.device("mps", i) for i in range(len(parsed_devices))]

@staticmethod
@override
def get_device() -> str:
return "mps"

@staticmethod
@override
def auto_device_count() -> int:
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
# accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
# it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy

@staticmethod
@override
def get_device() -> str:
return "xla"

@staticmethod
@override
# XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
Expand Down
9 changes: 8 additions & 1 deletion src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def __init__(
self._accelerator_flag = self._choose_auto_accelerator()
elif self._accelerator_flag == "gpu":
self._accelerator_flag = self._choose_gpu_accelerator_backend()
elif isinstance(self._accelerator_flag, Accelerator):
pass # for 3rd party accelerator, just do nothing

self._set_parallel_devices_and_init_accelerator()

Expand Down Expand Up @@ -461,7 +463,10 @@ def _check_and_init_precision(self) -> Precision:
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore
if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type]
return FSDPPrecision(
precision=self._precision_input, # type: ignore[arg-type]
device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None,
)
mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
raise ValueError(
Expand Down Expand Up @@ -493,6 +498,8 @@ def _check_and_init_precision(self) -> Precision:
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
if isinstance(self._accelerator_flag, Accelerator):
device = self._accelerator_flag.get_device()
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]

raise RuntimeError("No precision set")
Expand Down
10 changes: 9 additions & 1 deletion src/lightning/fabric/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
scaler = (
torch.amp.GradScaler(device=device)
if _TORCH_GREATER_EQUAL_2_4
else getattr(
torch,
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu"
else device.split(":")[0]
).amp.GradScaler()
)
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
Expand Down
9 changes: 7 additions & 2 deletions src/lightning/fabric/plugins/precision/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@ class FSDPPrecision(Precision):
"""

def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
def __init__(
self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None
) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in FSDP."
f" `precision` must be one of: {supported_precision}."
)
self.device = device if device is not None else "cuda"

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

Expand Down Expand Up @@ -110,7 +113,9 @@ def module_init_context(self) -> ContextManager:
@override
def forward_context(self) -> ContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return torch.autocast(
self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
)
return self.tensor_init_context()

@override
Expand Down
8 changes: 7 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self._determine_ddp_device_ids()
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
ctx = (
getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
)
if device_ids is not None
else nullcontext()
)
with ctx:
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

Expand Down
12 changes: 8 additions & 4 deletions src/lightning/fabric/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,9 @@ def load_checkpoint(

optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())

torch.cuda.empty_cache()
getattr(
torch, f"{self.root_device.type.split(':')[0]}"
).empty_cache() if self.root_device.type != "cpu" else None
_, client_state = engine.load_checkpoint(
path,
tag="checkpoint",
Expand Down Expand Up @@ -616,10 +618,12 @@ def _initialize_engine(

@override
def setup_environment(self) -> None:
if not isinstance(self.accelerator, CUDAAccelerator):
from deepspeed.runtime.utils import get_accelerator
if (not isinstance(self.accelerator, CUDAAccelerator)) and \
self.accelerator.get_device() != get_accelerator().device_name(): # type: ignore[union-attr]
raise RuntimeError(
f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
" is used."
f"The DeepSpeed strategy is only supported on {get_accelerator().device_name()} GPUs,"
f"but `{self.accelerator.__class__.__name__}` is used."
)
super().setup_environment()

Expand Down
4 changes: 3 additions & 1 deletion src/lightning/fabric/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def load_checkpoint(
given, the full checkpoint will be returned.
"""
torch.cuda.empty_cache()
getattr(
torch, f"{self.root_device.type.split(':')[0]}"
).empty_cache() if self.root_device.type != "cpu" else None
checkpoint = self.checkpoint_io.load_checkpoint(path)
if not state:
return checkpoint
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
"""
raise NotImplementedError

@staticmethod
def get_device() -> str:
"""Get the device for the current process."""
raise NotImplementedError
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
description=cls.__name__,
)

@staticmethod
@override
def get_device() -> str:
return "cpu"


# CPU device metrics
_CPU_VM_PERCENT = "cpu_vm_percent"
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
description=cls.__name__,
)

@staticmethod
@override
def get_device() -> str:
return "cuda"


def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
description=cls.__name__,
)

@staticmethod
@override
def get_device() -> str:
return "mps"


# device metrics
_VM_PERCENT = "M1_vm_percent"
Expand Down
10 changes: 9 additions & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
scaler = (
torch.amp.GradScaler(device=device)
if _TORCH_GREATER_EQUAL_2_4
else getattr(
torch,
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu"
else device.split(":")[0]
).amp.GradScaler()
)
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
Expand Down
9 changes: 7 additions & 2 deletions src/lightning/pytorch/plugins/precision/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@ class FSDPPrecision(Precision):
"""

def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
def __init__(
self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None
) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in FSDP."
f" `precision` must be one of: {supported_precision}."
)
self.device = device if device is not None else "cuda"

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

Expand Down Expand Up @@ -119,7 +122,9 @@ def module_init_context(self) -> ContextManager:
@override
def forward_context(self) -> ContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return torch.autocast(
self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
)
return _DtypeContextManager(self._desired_input_dtype)

@override
Expand Down
8 changes: 7 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,13 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
ctx = (
getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
)
if device_ids is not None
else nullcontext()
)
with ctx:
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

Expand Down
Loading

0 comments on commit 2a89640

Please sign in to comment.