diff --git a/cyto_dl/utils/checkpoint.py b/cyto_dl/utils/checkpoint.py index 7e6e0043..e6f98aae 100644 --- a/cyto_dl/utils/checkpoint.py +++ b/cyto_dl/utils/checkpoint.py @@ -7,7 +7,7 @@ def load_checkpoint(model, load_params): "ckpt_path" ), "ckpt_path must be provided to with argument weights_only=True" # load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights - state_dict = torch.load(load_params["ckpt_path"])["state_dict"] + state_dict = torch.load(load_params["ckpt_path"], map_location="cpu")["state_dict"] model.load_state_dict(state_dict, strict=load_params.get("strict", True)) # set ckpt_path to None to avoid loading checkpoint again with model.fit/model.test load_params["ckpt_path"] = None