Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torch.load() frequently warning in PersistentDataset and GDSDataset #8177

Merged
merged 9 commits into from
Oct 25, 2024
11 changes: 9 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import warnings
from collections.abc import Callable, Sequence
from copy import copy, deepcopy
from inspect import signature
from multiprocessing.managers import ListProxy
from multiprocessing.pool import ThreadPool
from pathlib import Path
Expand Down Expand Up @@ -371,7 +372,10 @@ def _cachecheck(self, item_transformed):

if hashfile is not None and hashfile.is_file(): # cache hit
try:
return torch.load(hashfile)
if "weights_only" in signature(torch.load).parameters:
return torch.load(hashfile, weights_only=False)
else:
return torch.load(hashfile)
except PermissionError as e:
if sys.platform != "win32":
raise e
Expand Down Expand Up @@ -1670,4 +1674,7 @@ def _load_meta_cache(self, meta_hash_file_name):
if meta_hash_file_name in self._meta_cache:
return self._meta_cache[meta_hash_file_name]
else:
return torch.load(self.cache_dir / meta_hash_file_name)
if "weights_only" in signature(torch.load).parameters:
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
else:
return torch.load(self.cache_dir / meta_hash_file_name)
Loading