From af90eb9968d745c3a53595d090bda38551ee9749 Mon Sep 17 00:00:00 2001 From: "Xu, Jing" Date: Fri, 26 May 2023 20:11:26 +0900 Subject: [PATCH] add xpu support [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci update typos and bug fixes [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci xpu seeding PR1 [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci add seeding for pytorch utilities mp_fabric xpu forking xpu multiprocess pytorch add header for xpu rename change to lightning.pytorch [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Teardown from lightning-xpu (from #PR- 3) From #3 [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci add torch.xpu.stream to ddp update docs [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci update _LIGHTNING_XPU_AVAILABLE to _lightning_xpu_available correct fabric imports.py 1. remove xpu.py from _graveyard 2. correct _lightning_xpu_available() usage fix _try_import function not defined issue in fabric add docs [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source-fabric/fundamentals/launch.rst | 3 +- docs/source-pytorch/common/index.rst | 8 ++++ docs/source-pytorch/common_usecases.rst | 7 ++++ docs/source-pytorch/conf.py | 6 +++ .../source-pytorch/extensions/accelerator.rst | 31 +++++++------- docs/source-pytorch/glossary/index.rst | 8 ++++ .../source-pytorch/integrations/xpu/index.rst | 40 ++++++++++++++++++ .../levels/advanced_level_23.rst | 37 ++++++++++++++++ requirements/_integrations/accelerators.txt | 3 ++ src/lightning/fabric/accelerators/__init__.py | 10 +++++ src/lightning/fabric/cli.py | 11 ++++- src/lightning/fabric/connector.py | 38 +++++++++++++++-- src/lightning/fabric/strategies/ddp.py | 13 ++++-- src/lightning/fabric/strategies/deepspeed.py | 22 ++++++++-- .../strategies/launchers/multiprocessing.py | 27 +++++++++++- .../fabric/utilities/device_parser.py | 42 ++++++++++++++----- src/lightning/fabric/utilities/distributed.py | 9 +++- src/lightning/fabric/utilities/imports.py | 27 ++++++++++-- src/lightning/fabric/utilities/seed.py | 14 ++++++- src/lightning/pytorch/strategies/ddp.py | 13 ++++-- src/lightning/pytorch/strategies/deepspeed.py | 3 +- .../strategies/launchers/multiprocessing.py | 7 ++++ .../connectors/accelerator_connector.py | 24 +++++++++++ src/lightning/pytorch/trainer/setup.py | 19 ++++++++- src/lightning/pytorch/trainer/trainer.py | 2 +- src/lightning/pytorch/utilities/imports.py | 7 ++++ src/lightning/pytorch/utilities/seed.py | 4 +- src/pytorch_lightning/README.md | 4 +- 28 files changed, 380 insertions(+), 59 deletions(-) create mode 100644 docs/source-pytorch/integrations/xpu/index.rst create mode 100644 docs/source-pytorch/levels/advanced_level_23.rst diff --git a/docs/source-fabric/fundamentals/launch.rst b/docs/source-fabric/fundamentals/launch.rst index bc4a5bafd42db7..5f231166b43c25 100644 --- a/docs/source-fabric/fundamentals/launch.rst +++ b/docs/source-fabric/fundamentals/launch.rst @@ -93,8 +93,9 @@ This is essentially the same as running ``python path/to/your/script.py``, but i itself and are expected to be parsed there. Options: - --accelerator [cpu|gpu|cuda|mps|tpu] + --accelerator [cpu|gpu|cuda|mps|tpu|xpu] The hardware accelerator to run on. + Install Lightning-XPU to enable ``xpu``. --strategy [ddp|dp|deepspeed] Strategy for how to run across multiple devices. --devices TEXT Number of devices to run on (``int``), which diff --git a/docs/source-pytorch/common/index.rst b/docs/source-pytorch/common/index.rst index 03647c70b9caad..99615491bda87f 100644 --- a/docs/source-pytorch/common/index.rst +++ b/docs/source-pytorch/common/index.rst @@ -17,6 +17,7 @@ ../advanced/model_parallel Train on single or multiple GPUs <../accelerators/gpu> Train on single or multiple HPUs <../integrations/hpu/index> + Train on single or multiple XPUs <../integrations/xpu/index> Train on single or multiple IPUs <../accelerators/ipu> Train on single or multiple TPUs <../accelerators/tpu> Train on MPS <../accelerators/mps> @@ -168,6 +169,13 @@ How-to Guides :col_css: col-md-4 :height: 180 +.. displayitem:: + :header: Train on single or multiple XPUs + :description: Train models faster with XPU accelerators + :button_link: ../integrations/xpu/index.html + :col_css: col-md-4 + :height: 180 + .. displayitem:: :header: Train on single or multiple IPUs :description: Train models faster with IPU accelerators diff --git a/docs/source-pytorch/common_usecases.rst b/docs/source-pytorch/common_usecases.rst index 9bbc9856b547ac..de4f926db1d25c 100644 --- a/docs/source-pytorch/common_usecases.rst +++ b/docs/source-pytorch/common_usecases.rst @@ -133,6 +133,13 @@ Customize and extend Lightning for things like custom hardware or distributed st :button_link: integrations/hpu/index.html :height: 100 +.. displayitem:: + :header: Train on single or multiple XPUs + :description: Train models faster with XPUs. + :col_css: col-md-12 + :button_link: integrations/xpu/index.html + :height: 100 + .. displayitem:: :header: Train on single or multiple IPUs :description: Train models faster with IPUs. diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py index 8f386a7f495da2..2848713de67e26 100644 --- a/docs/source-pytorch/conf.py +++ b/docs/source-pytorch/conf.py @@ -96,6 +96,11 @@ def _load_py_module(name: str, location: str) -> ModuleType: target_dir="docs/source-pytorch/integrations/hpu", checkout="tags/1.0.0", ) +assist_local.AssistantCLI.pull_docs_files( + gh_user_repo="Lightning-AI/lightning-XPU", + target_dir="docs/source-pytorch/integrations/xpu", + checkout="tags/1.0.0", +) if not _FAST_DOCS_DEV: fetch_external_assets( @@ -324,6 +329,7 @@ def _load_py_module(name: str, location: str) -> ModuleType: "torchmetrics": ("https://torchmetrics.readthedocs.io/en/stable/", None), "graphcore": ("https://docs.graphcore.ai/en/latest/", None), "habana": ("https://lightning-ai.github.io/lightning-Habana/", None), + "intel-xpu": ("https://lightning-ai.github.io/lightning-XPU/", None), } # -- Options for todo extension ---------------------------------------------- diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst index 45f4b72500c381..891c919091ae98 100644 --- a/docs/source-pytorch/extensions/accelerator.rst +++ b/docs/source-pytorch/extensions/accelerator.rst @@ -12,6 +12,7 @@ Currently there are accelerators for: - :doc:`TPU <../accelerators/tpu>` - :doc:`IPU <../accelerators/ipu>` - :doc:`HPU <../integrations/hpu/index>` +- :doc:`XPU <../integrations/xpu/index>` - :doc:`MPS <../accelerators/mps>` The Accelerator is part of the Strategy which manages communication across multiple devices (distributed communication). @@ -32,16 +33,16 @@ Create a Custom Accelerator .. warning:: This is an :ref:`experimental ` feature. Here is how you create a new Accelerator. -Let's pretend we want to integrate the fictional XPU accelerator and we have access to its hardware through a library -``xpulib``. +Let's pretend we want to integrate the fictional YPU accelerator and we have access to its hardware through a library +``ypulib``. .. code-block:: python - import xpulib + import ypulib - class XPUAccelerator(Accelerator): - """Support for a hypothetical XPU, optimized for large-scale machine learning.""" + class YPUAccelerator(Accelerator): + """Support for a hypothetical YPU, optimized for large-scale machine learning.""" @staticmethod def parse_devices(devices: Any) -> Any: @@ -52,29 +53,29 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc @staticmethod def get_parallel_devices(devices: Any) -> Any: # Here, convert the device indices to actual device objects - return [torch.device("xpu", idx) for idx in devices] + return [torch.device("ypu", idx) for idx in devices] @staticmethod def auto_device_count() -> int: # Return a value for auto-device selection when `Trainer(devices="auto")` - return xpulib.available_devices() + return ypulib.available_devices() @staticmethod def is_available() -> bool: - return xpulib.is_available() + return ypulib.is_available() def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: # Return optional device statistics for loggers return {} -Finally, add the XPUAccelerator to the Trainer: +Finally, add the YPUAccelerator to the Trainer: .. code-block:: python from lightning.pytorch import Trainer - accelerator = XPUAccelerator() + accelerator = YPUAccelerator() trainer = Trainer(accelerator=accelerator, devices=2) @@ -90,28 +91,28 @@ If you wish to switch to a custom accelerator from the CLI without code changes, .. code-block:: python - class XPUAccelerator(Accelerator): + class YPUAccelerator(Accelerator): ... @classmethod def register_accelerators(cls, accelerator_registry): accelerator_registry.register( - "xpu", + "ypu", cls, - description=f"XPU Accelerator - optimized for large-scale machine learning.", + description=f"YPU Accelerator - optimized for large-scale machine learning.", ) Now, this is possible: .. code-block:: python - trainer = Trainer(accelerator="xpu") + trainer = Trainer(accelerator="ypu") Or if you are using the Lightning CLI, for example: .. code-block:: bash - python train.py fit --trainer.accelerator=xpu --trainer.devices=2 + python train.py fit --trainer.accelerator=ypu --trainer.devices=2 ---------- diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst index d632ecca93024b..5b72511a52258b 100644 --- a/docs/source-pytorch/glossary/index.rst +++ b/docs/source-pytorch/glossary/index.rst @@ -18,6 +18,7 @@ GPU <../accelerators/gpu> Half precision <../common/precision> HPU <../integrations/hpu/index> + XPU <../integrations/xpu/index> Inference <../deploy/production_intermediate> IPU <../accelerators/ipu> Lightning CLI <../cli/lightning_cli> @@ -159,6 +160,13 @@ Glossary :button_link: ../integrations/hpu/index.html :height: 100 +.. displayitem:: + :header: XPU + :description: IntelĀ® Graphics Cards for faster training + :col_css: col-md-12 + :button_link: ../integrations/xpu/index.html + :height: 100 + .. displayitem:: :header: Inference :description: Making predictions by applying a trained model to unlabeled examples diff --git a/docs/source-pytorch/integrations/xpu/index.rst b/docs/source-pytorch/integrations/xpu/index.rst new file mode 100644 index 00000000000000..3fb22d6e36541d --- /dev/null +++ b/docs/source-pytorch/integrations/xpu/index.rst @@ -0,0 +1,40 @@ +.. _xpu: + +Accelerator: XPU training +========================= + +.. raw:: html + +
+
+ +.. Add callout items below this line + +.. displayitem:: + :header: Basic + :description: Learn the basics of single and multi-XPU core training. + :col_css: col-md-4 + :button_link: basic.html + :height: 150 + :tag: basic + +.. displayitem:: + :header: Intermediate + :description: Enable state-of-the-art scaling with advanced mix-precision settings. + :col_css: col-md-4 + :button_link: intermediate.html + :height: 150 + :tag: intermediate + +.. displayitem:: + :header: Advanced + :description: Explore state-of-the-art scaling with additional advanced configurations. + :col_css: col-md-4 + :button_link: advanced.html + :height: 150 + :tag: advanced + +.. raw:: html + +
+
diff --git a/docs/source-pytorch/levels/advanced_level_23.rst b/docs/source-pytorch/levels/advanced_level_23.rst new file mode 100644 index 00000000000000..895f4c538398c6 --- /dev/null +++ b/docs/source-pytorch/levels/advanced_level_23.rst @@ -0,0 +1,37 @@ +:orphan: + +###################### +Level 19: Explore XPUs +###################### + +Explore IntelĀ® Graphics Cards (XPU) for model scaling. + +---- + +.. raw:: html + +
+
+ +.. Add callout items below this line + +.. displayitem:: + :header: Train models on XPUs + :description: Learn the basics of single and multi-XPU core training. + :col_css: col-md-6 + :button_link: ../integrations/xpu/basic.html + :height: 150 + :tag: basic + +.. displayitem:: + :header: Optimize models training on XPUs + :description: Enable state-of-the-art scaling with advanced mixed-precision settings. + :col_css: col-md-6 + :button_link: ../integrations/xpu/intermediate.html + :height: 150 + :tag: intermediate + +.. raw:: html + +
+
diff --git a/requirements/_integrations/accelerators.txt b/requirements/_integrations/accelerators.txt index d981df11be4a02..cd28fb73f778ee 100644 --- a/requirements/_integrations/accelerators.txt +++ b/requirements/_integrations/accelerators.txt @@ -1,3 +1,6 @@ # validation HPU connectors lightning-habana >=0.1.0 lightning-graphcore >=0.1.0.rc4 + +# validation XPU connectors +lightning-xpu >=0.1.0 diff --git a/src/lightning/fabric/accelerators/__init__.py b/src/lightning/fabric/accelerators/__init__.py index 3d4b43f75c7626..c2ce0b673378c5 100644 --- a/src/lightning/fabric/accelerators/__init__.py +++ b/src/lightning/fabric/accelerators/__init__.py @@ -22,3 +22,13 @@ ACCELERATOR_REGISTRY = _AcceleratorRegistry() _register_classes(ACCELERATOR_REGISTRY, "register_accelerators", sys.modules[__name__], Accelerator) + +from lightning.fabric.utilities.imports import _lightning_xpu_available + +_ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators" +ACCELERATOR_REGISTRY = _AcceleratorRegistry() +call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE) +if _lightning_xpu_available() and "xpu" not in ACCELERATOR_REGISTRY: + from lightning_xpu.fabric import XPUAccelerator + + XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 24943a528b47e9..c8c96b5d3892fd 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -24,12 +24,15 @@ from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS from lightning.fabric.strategies import STRATEGY_REGISTRY from lightning.fabric.utilities.device_parser import _parse_gpu_ids +from lightning.fabric.utilities.imports import _lightning_xpu_available _log = logging.getLogger(__name__) _CLICK_AVAILABLE = RequirementCache("click") -_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") +_SUPPORTED_ACCELERATORS = ["cpu", "gpu", "cuda", "mps", "tpu"] +if _lightning_xpu_available(): + _SUPPORTED_ACCELERATORS.append("xpu") def _get_supported_strategies() -> List[str]: @@ -148,13 +151,17 @@ def _set_env_variables(args: Namespace) -> None: def _get_num_processes(accelerator: str, devices: str) -> int: """Parse the `devices` argument to determine how many processes need to be launched on the current machine.""" if accelerator == "gpu": - parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) + parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True, include_xpu=True) elif accelerator == "cuda": parsed_devices = CUDAAccelerator.parse_devices(devices) elif accelerator == "mps": parsed_devices = MPSAccelerator.parse_devices(devices) elif accelerator == "tpu": raise ValueError("Launching processes for TPU through the CLI is not supported.") + elif accelerator == "xpu": + from lightning_xpu.fabric import XPUAccelerator + + parsed_devices = XPUAccelerator.parse_devices(devices) else: return CPUAccelerator.parse_devices(devices) return len(parsed_devices) if parsed_devices is not None else 0 diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 1723b341be4f15..728edffa28b4dc 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -63,7 +63,7 @@ from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy from lightning.fabric.utilities import rank_zero_info, rank_zero_warn from lightning.fabric.utilities.device_parser import _determine_root_gpu_device -from lightning.fabric.utilities.imports import _IS_INTERACTIVE +from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _lightning_xpu_available _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO] _PLUGIN_INPUT = Union[_PLUGIN, str] @@ -290,6 +290,13 @@ def _check_config_and_set_final_flags( f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cuda" + if self._strategy_flag.parallel_devices[0].type == "xpu": + if self._accelerator_flag and self._accelerator_flag not in ("auto", "xpu", "gpu"): + raise ValueError( + f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," + f" but accelerator set to {self._accelerator_flag}, please choose one device type" + ) + self._accelerator_flag = "xpu" self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: @@ -315,6 +322,12 @@ def _choose_auto_accelerator(self) -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" + return "cpu" @staticmethod @@ -323,6 +336,11 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" raise RuntimeError("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -392,8 +410,15 @@ def _choose_strategy(self) -> Union[Strategy, str]: if self._num_nodes_flag > 1: return "ddp" if len(self._parallel_devices) <= 1: - if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( - isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") + supported_accelerators = [CUDAAccelerator, MPSAccelerator] + supported_accelerators_str = ["cuda", "gpu", "mps"] + if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + + supported_accelerators.append(XPUAccelerator) + supported_accelerators_str.append("xpu") + if isinstance(self._accelerator_flag, tuple(supported_accelerators)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in tuple(supported_accelerators_str) ): device = _determine_root_gpu_device(self._parallel_devices) else: @@ -473,7 +498,12 @@ def _check_and_init_precision(self) -> Precision: if self._precision_input == "16-mixed" else "Using bfloat16 Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + device = "cuda" + if self._accelerator_flag == "cpu": + device = "cpu" + elif self._accelerator_flag == "xpu": + device = "xpu" + return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index 8d7468ba884240..84f5da9e832b81 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -115,10 +115,17 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self._determine_ddp_device_ids() - # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() - with ctx: + ctx = None + if self.root_device.type == "cuda": + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + if self.root_device.type == "xpu": + ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext() + if ctx is None: return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) + else: + with ctx: + return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) def module_to_device(self, module: Module) -> None: module.to(self.root_device) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 213dd9bf503537..a62b71039be739 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -33,6 +33,7 @@ from lightning.fabric.strategies.registry import _StrategyRegistry from lightning.fabric.strategies.strategy import _Sharded from lightning.fabric.utilities.distributed import log +from lightning.fabric.utilities.imports import _lightning_xpu_available from lightning.fabric.utilities.load import _move_state_into from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed @@ -41,6 +42,9 @@ if TYPE_CHECKING: from deepspeed import DeepSpeedEngine +if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") @@ -217,7 +221,8 @@ def __init__( contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory. Not supported by all models. - synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. + synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` or :func:`torch.xpu.synchronize` + at each checkpoint boundary. load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards @@ -493,6 +498,10 @@ def load_checkpoint( optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values()) torch.cuda.empty_cache() + with suppress(AttributeError): + if _lightning_xpu_available(): + XPUAccelerator.teardown() + _, client_state = engine.load_checkpoint( path, tag="checkpoint", @@ -591,10 +600,15 @@ def _initialize_engine( return deepspeed_engine, deepspeed_optimizer def _setup_distributed(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): + ds_support = False + if isinstance(self.accelerator, CUDAAccelerator): + ds_support = True + if _lightning_xpu_available() and isinstance(self.accelerator, XPUAccelerator): + ds_support = True + if not ds_support: raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." + "The DeepSpeed strategy is only supported on CUDA/Intel(R) GPUs but" + " `{self.accelerator.__class__.__name__}` is used." ) assert self.parallel_devices is not None _validate_device_index_selection(self.parallel_devices) diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py index e92853a6a2b20c..c35f2a0d8e8a55 100644 --- a/src/lightning/fabric/strategies/launchers/multiprocessing.py +++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py @@ -27,12 +27,15 @@ from lightning.fabric.accelerators.cpu import CPUAccelerator from lightning.fabric.strategies.launchers.launcher import _Launcher from lightning.fabric.utilities.apply_func import move_data_to_device -from lightning.fabric.utilities.imports import _IS_INTERACTIVE +from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _lightning_xpu_available from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states if TYPE_CHECKING: from lightning.fabric.strategies import ParallelStrategy +if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + class _MultiProcessingLauncher(_Launcher): r"""Launches processes that run a given function in parallel, and joins them all at the end. @@ -92,6 +95,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """ if self._start_method in ("fork", "forkserver"): _check_bad_cuda_fork() + if _lightning_xpu_available() and XPUAccelerator.is_available(): + _check_bad_xpu_fork() if self._start_method == "spawn": _check_missing_main_guard() @@ -241,3 +246,23 @@ def main(): """ ) raise RuntimeError(message) + + +def _check_bad_xpu_fork() -> None: + """Checks whether it is safe to fork and initialize XPU in the new processes, and raises an exception if not. + + The error message replaces PyTorch's 'Cannot re-initialize XPU in forked subprocess' with helpful advice for + Lightning users. + + """ + if not XPUAccelerator.is_xpu_initialized(): + return + + message = ( + "Lightning can't create new processes if XPU is already initialized. Did you manually call" + " `torch.xpu.*` functions, have moved the model to the device, or allocated memory on the GPU any" + " other way? Please remove any such calls, or change the selected strategy." + ) + if _IS_INTERACTIVE: + message += " You will have to restart the Python kernel." + raise RuntimeError(message) diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 2aa8872e878120..524d2e99915735 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -16,6 +16,7 @@ import lightning.fabric.accelerators as accelerators # avoid circular dependency from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment from lightning.fabric.utilities.exceptions import MisconfigurationException +from lightning.fabric.utilities.imports import _lightning_xpu_available from lightning.fabric.utilities.types import _DEVICE @@ -49,6 +50,7 @@ def _parse_gpu_ids( gpus: Optional[Union[int, str, List[int]]], include_cuda: bool = False, include_mps: bool = False, + include_xpu: bool = False, ) -> Optional[List[int]]: """ Parses the GPU IDs given in the format as accepted by the @@ -62,6 +64,7 @@ def _parse_gpu_ids( Any int N > 0 indicates that GPUs [0..N) should be used. include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing. include_mps: A boolean value indicating whether to include MPS devices for GPU parsing. + include_xpu: A boolean value indicating whether to include Intel GPU devices for GPU parsing. Returns: A list of GPUs to be used or ``None`` if no GPUs were requested @@ -71,7 +74,7 @@ def _parse_gpu_ids( If no GPUs are available but the value of gpus variable indicates request for GPUs .. note:: - ``include_cuda`` and ``include_mps`` default to ``False`` so that you only + ``include_cuda``, ``include_mps`` and ``include_xpu`` default to ``False`` so that you only have to specify which device type to use and all other devices are not disabled. """ # Check that gpus param is None, Int, String or Sequence of Ints @@ -84,14 +87,17 @@ def _parse_gpu_ids( # We know the user requested GPUs therefore if some of the # requested GPUs are not available an exception is thrown. gpus = _normalize_parse_gpu_string_input(gpus) - gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) + gpus = _normalize_parse_gpu_input_to_list( + gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu + ) if not gpus: raise MisconfigurationException("GPUs requested but none are available.") if ( TorchElasticEnvironment.detect() and len(gpus) != 1 - and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1 + and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)) + == 1 ): # Omit sanity check on torchelastic because by default it shows one visible GPU per process return gpus @@ -99,7 +105,7 @@ def _parse_gpu_ids( # Check that GPUs are unique. Duplicate GPUs are not supported by the backend. _check_unique(gpus) - return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) + return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu) def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: @@ -112,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _sanitize_gpu_ids( + gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False +) -> List[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -127,9 +135,11 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: If machine has fewer available GPUs than requested. """ - if sum((include_cuda, include_mps)) == 0: + if sum((include_cuda, include_mps, include_xpu)) == 0: raise ValueError("At least one gpu type should be specified!") - all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + all_available_gpus = _get_all_available_gpus( + include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu + ) for gpu in gpus: if gpu not in all_available_gpus: raise MisconfigurationException( @@ -139,7 +149,10 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool + gpus: Union[int, List[int], Tuple[int, ...]], + include_cuda: bool, + include_mps: bool, + include_xpu: bool, ) -> Optional[List[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): @@ -149,19 +162,26 @@ def _normalize_parse_gpu_input_to_list( if not gpus: # gpus==0 return None if gpus == -1: - return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu) return list(range(gpus)) -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _get_all_available_gpus( + include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False +) -> List[int]: """ Returns: A list of all available GPUs """ cuda_gpus = accelerators.cuda._get_all_visible_cuda_devices() if include_cuda else [] mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else [] - return cuda_gpus + mps_gpus + xpu_gpus = [] + if _lightning_xpu_available(): + import lightning_xpu.fabric as accelerator_xpu + + xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else [] + return cuda_gpus + mps_gpus + xpu_gpus def _check_unique(device_ids: List[int]) -> None: diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 0c261a9f099059..74d284e6c9e897 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -229,7 +229,7 @@ def _init_dist_connection( Args: cluster_environment: ``ClusterEnvironment`` instance - torch_distributed_backend: Backend to use (includes `nccl` and `gloo`) + torch_distributed_backend: Backend to use (includes `nccl`, `gloo` and `ccl`) global_rank: Rank of the current process world_size: Number of processes in the group kwargs: Kwargs for ``init_process_group`` @@ -261,7 +261,12 @@ def _init_dist_connection( def _get_default_process_group_backend_for_device(device: torch.device) -> str: - return "nccl" if device.type == "cuda" else "gloo" + if device.type == "cuda": + return "nccl" + elif device.type == "xpu": + return "ccl" + else: + return "gloo" class _DatasetSamplerWrapper(Dataset): diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index a24441bad86c03..eb4f76f00dac81 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities.""" +import functools import operator import platform import sys -from lightning_utilities.core.imports import compare_version +from lightning_utilities.core.imports import compare_version, RequirementCache _IS_WINDOWS = platform.system() == "Windows" @@ -25,12 +26,30 @@ # 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383 _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) -_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0") -_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0") -_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0") +_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0", use_base_version=True) +_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True) +_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True) _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) _TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0", use_base_version=True) _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1 _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) + + +@functools.lru_cache(maxsize=128) +def _try_import_module(module_name: str) -> bool: + try: + __import__(module_name) + return True + # added also AttributeError fro case of impoerts like pl.LightningModule + except (ImportError, AttributeError) as err: + rank_zero_warn(f"Import of {module_name} package failed for some compatibility issues: \n{err}") + return False + + +@functools.lru_cache(maxsize=1) +def _lightning_xpu_available() -> bool: + # This is defined as a function instead of a constant to avoid circular imports, because `lightning_xpu` + # also imports Lightning + return bool(RequirementCache("lightning-xpu")) and _try_import_module("lightning_xpu") diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index 2144addad4f44c..5fb68290882f23 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -14,6 +14,10 @@ max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min +from lightning.fabric.utilities.imports import _lightning_xpu_available + +if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: @@ -57,6 +61,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + if _lightning_xpu_available() and XPUAccelerator.is_available(): + XPUAccelerator.manual_seed_all(seed) os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" @@ -106,8 +112,8 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: random.seed(stdlib_seed) -def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: - r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" +def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) -> Dict[str, Any]: + """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" states = { "torch": torch.get_rng_state(), "numpy": np.random.get_state(), @@ -115,6 +121,8 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: } if include_cuda: states["torch.cuda"] = torch.cuda.get_rng_state_all() + if include_xpu and _lightning_xpu_available() and XPUAccelerator.is_available(): + states["torch.xpu"] = XPUAccelerator._collect_rng_states() return states @@ -125,6 +133,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: # torch.cuda rng_state is only included since v1.8. if "torch.cuda" in rng_state_dict: torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"]) + if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and XPUAccelerator.is_available(): + XPUAccelerator._set_rng_states(rng_state_dict) np.random.set_state(rng_state_dict["numpy"]) version, state, gauss = rng_state_dict["python"] python_set_rng_state((version, tuple(state), gauss)) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 2575a57d45276d..36ae77238a88e1 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -183,10 +183,17 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") - # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() - with ctx: + ctx = None + if self.root_device.type == "cuda": + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + if self.root_device.type == "xpu": + ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext() + if ctx is None: return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) + else: + with ctx: + return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) def setup_distributed(self) -> None: log.debug(f"{self.__class__.__name__}: setting up distributed...") diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 0391bc6386b06e..027a1223392c21 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -241,7 +241,8 @@ def __init__( contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory. Not supported by all models. - synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. + synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` or :func:`torch.xpu.synchronize` + at each checkpoint boundary. load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index 9cd401ff2f85e9..5a4f92cf5de7f7 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -29,6 +29,7 @@ import lightning.pytorch as pl from lightning.fabric.strategies.launchers.multiprocessing import ( _check_bad_cuda_fork, + _check_bad_xpu_fork, _check_missing_main_guard, _disable_module_memory_sharing, ) @@ -39,8 +40,12 @@ from lightning.pytorch.strategies.launchers.launcher import _Launcher from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM from lightning.pytorch.trainer.states import TrainerFn, TrainerState +from lightning.pytorch.utilities.imports import _lightning_xpu_available from lightning.pytorch.utilities.rank_zero import rank_zero_debug +if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + log = logging.getLogger(__name__) @@ -103,6 +108,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] """ if self._start_method in ("fork", "forkserver"): _check_bad_cuda_fork() + if XPUAccelerator.is_available(): + _check_bad_xpu_fork() if self._start_method == "spawn": _check_missing_main_guard() diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 300fcd3c5589b2..733a1a282ce456 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -66,6 +66,7 @@ _LIGHTNING_COLOSSALAI_AVAILABLE, _lightning_graphcore_available, _lightning_habana_available, + _lightning_xpu_available, ) from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn @@ -350,6 +351,11 @@ def _choose_auto_accelerator(self) -> str: if HPUAccelerator.is_available(): return "hpu" + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" if MPSAccelerator.is_available(): return "mps" if CUDAAccelerator.is_available(): @@ -362,6 +368,11 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" raise MisconfigurationException("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -451,6 +462,12 @@ def _choose_strategy(self) -> Union[Strategy, str]: from lightning_habana import SingleHPUStrategy return SingleHPUStrategy(device=torch.device("hpu")) + if self._accelerator_flag == "xpu" and not _lightning_xpu_available(): + raise ImportError( + "You have asked for XPU but you miss install related integration." + " Please run `pip install lightning-xpu` or see for further instructions" + " in https://github.com/Lightning-AI/lightning-XPU/." + ) if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator): if self._parallel_devices and len(self._parallel_devices) > 1: return XLAStrategy.strategy_name @@ -708,6 +725,13 @@ def _register_external_accelerators_and_strategies() -> None: if "hpu_single" not in StrategyRegistry: SingleHPUStrategy.register_strategies(StrategyRegistry) + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + # TODO: Prevent registering multiple times + if "xpu" not in AcceleratorRegistry: + XPUAccelerator.register_accelerators(AcceleratorRegistry) + if _lightning_graphcore_available(): from lightning_graphcore import IPUAccelerator, IPUStrategy diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 7187174708f260..4b363f574aa8cf 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -28,7 +28,11 @@ XLAProfiler, ) from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _lightning_graphcore_available, _lightning_habana_available +from lightning.pytorch.utilities.imports import ( + _lightning_graphcore_available, + _lightning_habana_available, + _lightning_xpu_available, +) from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn @@ -178,11 +182,24 @@ def _log_device_info(trainer: "pl.Trainer") -> None: hpu_available = False rank_zero_info(f"HPU available: {hpu_available}, using: {num_hpus} HPUs") + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + num_xpus = trainer.num_devices if isinstance(trainer.accelerator, XPUAccelerator) else 0 + xpu_available = XPUAccelerator.is_available() + else: + num_xpus = 0 + xpu_available = False + rank_zero_info(f"XPU available: {xpu_available}, using: {num_xpus} XPUs") + if ( CUDAAccelerator.is_available() and not isinstance(trainer.accelerator, CUDAAccelerator) or MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator) + or _lightning_xpu_available() + and XPUAccelerator.is_available() + and not isinstance(trainer.accelerator, XPUAccelerator) ): rank_zero_warn( "GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.", diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 8838262b890bc3..4c67b2efb85532 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -150,7 +150,7 @@ def __init__( precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). - Can be used on CPU, GPU, TPUs, HPUs or IPUs. + Can be used on CPU, GPU, TPUs, HPUs, IPUs or XPUs. Default: ``'32-true'``. logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index e2b40f98ec44b2..e27151bb3ec1a7 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -51,3 +51,10 @@ def _lightning_habana_available() -> bool: # This is defined as a function instead of a constant to avoid circular imports, because `lightning_habana` # also imports Lightning return bool(RequirementCache("lightning-habana")) and _try_import_module("lightning_habana") + + +@functools.lru_cache(maxsize=1) +def _lightning_xpu_available() -> bool: + # This is defined as a function instead of a constant to avoid circular imports, because `lightning_xpu` + # also imports Lightning + return bool(RequirementCache("lightning-xpu")) and _try_import_module("lightning_xpu") diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py index 10badab69c32fe..e2043c8675242c 100644 --- a/src/lightning/pytorch/utilities/seed.py +++ b/src/lightning/pytorch/utilities/seed.py @@ -19,7 +19,7 @@ @contextmanager -def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]: +def isolate_rng(include_cuda: bool = True, include_xpu: bool = True) -> Generator[None, None, None]: """A context manager that resets the global random state on exit to what it was before entering. It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators. @@ -40,6 +40,6 @@ def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]: tensor([0.7576]) """ - states = _collect_rng_states(include_cuda) + states = _collect_rng_states(include_cuda, include_xpu) yield _set_rng_states(states) diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index a2b70478a8863f..4122bba9e5666d 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -62,7 +62,7 @@ Lightning forces the following structure to your code which makes it reusable an - Non-essential research code (logging, etc... this goes in Callbacks). - Data (use PyTorch DataLoaders or organize them into a LightningDataModule). -Once you do this, you can train on multiple-GPUs, TPUs, CPUs, IPUs, HPUs and even in 16-bit precision without changing your code! +Once you do this, you can train on multiple-GPUs, TPUs, CPUs, IPUs, HPUs, XPUs and even in 16-bit precision without changing your code! [Get started in just 15 minutes](https://lightning.ai/docs/pytorch/latest/starter/introduction.html) @@ -70,7 +70,7 @@ ______________________________________________________________________ ## Continuous Integration -Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, and HPUs and against major Python and PyTorch versions. +Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, HPUs and XPUs and against major Python and PyTorch versions.
Current build statuses