From 1ca56ea93c6a4539bce52a3573944c4fa3f2e3df Mon Sep 17 00:00:00 2001
From: junsong
Date: Mon, 6 Jan 2025 00:26:14 -0800
Subject: [PATCH] fix the bug when resuming rng_state with a different number
 of GPUs.

---
 train_scripts/train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/train_scripts/train.py b/train_scripts/train.py
index 72a3ec0..3a56931 100755
--- a/train_scripts/train.py
+++ b/train_scripts/train.py
@@ -947,10 +947,13 @@ def main(cfg: SanaConfig) -> None:
     if rng_state:
         logger.info("resuming randomise")
         torch.set_rng_state(rng_state["torch"])
-        torch.cuda.set_rng_state_all(rng_state["torch_cuda"])
         np.random.set_state(rng_state["numpy"])
         random.setstate(rng_state["python"])
         generator.set_state(rng_state["generator"])  # resume generator status
+        try:
+            torch.cuda.set_rng_state_all(rng_state["torch_cuda"])
+        except Exception:
+            logger.warning("Failed to resume torch_cuda rng state")
 
     # Prepare everything
     # There is no specific order to remember, you just need to unpack the
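For context: torch.cuda.set_rng_state_all restores one saved RNG state per visible CUDA device, so a list checkpointed on more GPUs than are now visible fails at resume time, while a shorter list silently leaves the extra devices untouched. Below is a minimal sketch of a stricter alternative to the blanket try/except; the helper name restore_cuda_rng_states is illustrative and not part of the patch, and it assumes rng_state["torch_cuda"] is the list produced by torch.cuda.get_rng_state_all() at save time.

import torch

def restore_cuda_rng_states(saved_states, logger=None):
    """Restore per-device CUDA RNG states saved on a possibly different GPU count.

    Only the first min(len(saved_states), torch.cuda.device_count()) states
    are restored; any mismatch is logged instead of raising.
    """
    num_devices = torch.cuda.device_count()
    num_saved = len(saved_states)
    if num_saved != num_devices and logger is not None:
        logger.warning(
            f"Checkpoint holds {num_saved} CUDA RNG states but {num_devices} "
            "GPUs are visible; restoring only the overlapping devices."
        )
    # torch.cuda.set_rng_state applies a single state to one device index.
    for device_idx, state in enumerate(saved_states[:num_devices]):
        torch.cuda.set_rng_state(state, device_idx)

Dropped into the hunk above as restore_cuda_rng_states(rng_state["torch_cuda"], logger), this would keep whatever per-device determinism survives the GPU-count change instead of discarding all CUDA RNG state on the first error.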