fix circle import issue
jingxu10 committed Oct 5, 2023
1 parent 2c62b3f commit d2abb7f
Showing 2 changed files with 6 additions and 12 deletions.
3 changes: 0 additions & 3 deletions src/lightning/fabric/accelerators/__init__.py
@@ -25,9 +25,6 @@
 
-from lightning.fabric.utilities.imports import _lightning_xpu_available
 
 _ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
 ACCELERATOR_REGISTRY = _AcceleratorRegistry()
 call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
-if _lightning_xpu_available() and "xpu" not in ACCELERATOR_REGISTRY:
-    from lightning_xpu.fabric import XPUAccelerator
 
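
Why this deletion fixes the circular import: `lightning_xpu.fabric` presumably imports from `lightning.fabric` itself (its accelerator plugs into Lightning's registry), so importing `XPUAccelerator` at the top of `lightning/fabric/accelerators/__init__.py` makes each package require the other while the first is still initializing. A minimal sketch of the failure mode, with hypothetical module names standing in for the real ones:

# a.py -- hypothetical stand-in for lightning/fabric/accelerators/__init__.py
from b import Plugin      # starts executing b.py before a.py has finished

REGISTRY = {}

# b.py -- hypothetical stand-in for lightning_xpu/fabric.py
from a import REGISTRY    # a.py is only partially initialized at this point:
                          # REGISTRY is not bound yet, so this raises ImportError

Dropping the module-level import (and, in seed.py below, calling `torch.xpu` directly instead of going through the `XPUAccelerator` class) means neither package needs the other at import time.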
15 changes: 6 additions & 9 deletions src/lightning/fabric/utilities/seed.py
@@ -16,9 +16,6 @@
 min_seed_value = np.iinfo(np.uint32).min
 from lightning.fabric.utilities.imports import _lightning_xpu_available
 
-if _lightning_xpu_available():
-    from lightning_xpu.fabric import XPUAccelerator
-
 
 def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     r"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
@@ -61,8 +58,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
-    if _lightning_xpu_available() and XPUAccelerator.is_available():
-        XPUAccelerator.manual_seed_all(seed)
+    if _lightning_xpu_available() and torch.xpu.is_available():
+        torch.xpu.manual_seed_all(seed)
 
     os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
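
The replacement lines keep `seed_everything` device-agnostic: every RNG is seeded, but an accelerator's generator is touched only when its backend is present and reports a device. A stand-alone sketch of the same guard pattern, assuming only PyTorch and NumPy (the function name is mine, not Lightning's):

import os
import random

import numpy as np
import torch


def seed_all(seed: int, workers: bool = False) -> int:
    # Seed the Python, NumPy, and torch generators with one value.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op on CPU-only machines
    # torch.xpu exists only in XPU-enabled builds, hence the hasattr guard.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.manual_seed_all(seed)
    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
    return seed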

@@ -121,8 +118,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
     }
     if include_cuda:
         states["torch.cuda"] = torch.cuda.get_rng_state_all()
-    if include_xpu and _lightning_xpu_available() and XPUAccelerator.is_available():
-        states["torch.xpu"] = XPUAccelerator._collect_rng_states()
+    if include_xpu and _lightning_xpu_available() and torch.xpu.is_available():
+        states["torch.xpu"] = torch.xpu.get_rng_state_all()
     return states
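
`_collect_rng_states` builds a single mapping from backend name to RNG state; the `*_all` getters return one state tensor per visible device, which is what makes the snapshot multi-device safe. A hedged stand-alone equivalent (the helper name is mine):

import random
from typing import Any, Dict

import numpy as np
import torch


def collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) -> Dict[str, Any]:
    # Snapshot the CPU-side generators unconditionally.
    states: Dict[str, Any] = {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
        "python": random.getstate(),
    }
    if include_cuda and torch.cuda.is_available():
        states["torch.cuda"] = torch.cuda.get_rng_state_all()  # one tensor per GPU
    if include_xpu and hasattr(torch, "xpu") and torch.xpu.is_available():
        states["torch.xpu"] = torch.xpu.get_rng_state_all()  # mirrors the CUDA API
    return states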


@@ -133,8 +130,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
     # torch.cuda rng_state is only included since v1.8.
     if "torch.cuda" in rng_state_dict:
         torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
-    if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and XPUAccelerator.is_available():
-        XPUAccelerator._set_rng_states(rng_state_dict)
+    if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and torch.xpu.is_available():
+        torch.xpu.set_rng_state_all(rng_state_dict["torch.xpu"])
     np.random.set_state(rng_state_dict["numpy"])
     version, state, gauss = rng_state_dict["python"]
     python_set_rng_state((version, tuple(state), gauss))
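
The two helpers are designed to round-trip: snapshot every RNG stream before running code with side effects, then restore the snapshot afterwards. A minimal context-manager sketch of that usage, built on the `collect_rng_states` stand-in above (names are illustrative, not the Lightning API):

import random
from contextlib import contextmanager

import numpy as np
import torch


@contextmanager
def isolated_rng():
    # Snapshot all RNG streams, run the block, then restore the snapshot.
    states = collect_rng_states()
    try:
        yield
    finally:
        torch.set_rng_state(states["torch"])
        if "torch.cuda" in states:
            torch.cuda.set_rng_state_all(states["torch.cuda"])
        if "torch.xpu" in states:
            torch.xpu.set_rng_state_all(states["torch.xpu"])
        np.random.set_state(states["numpy"])
        random.setstate(states["python"])


# Draws inside the block do not advance the surrounding streams:
# with isolated_rng():
#     torch.randn(3)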
