diff --git a/examples/text-generation/checkpoint_utils.py b/examples/text-generation/checkpoint_utils.py
index 067ad4a36e..7b60f4a931 100644
--- a/examples/text-generation/checkpoint_utils.py
+++ b/examples/text-generation/checkpoint_utils.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 
 import torch
-from huggingface_hub import snapshot_download
+from huggingface_hub import list_repo_files, snapshot_download
 from transformers.utils import is_offline_mode, is_safetensors_available
 
 
@@ -20,13 +20,21 @@ def get_repo_root(model_name_or_path, local_rank=-1):
         if local_rank == 0:
             print("Offline mode: forcing local_files_only=True")
 
+    # Only download PyTorch weights by default
+    allow_patterns = ["*.bin"]
+    # If the model repo contains any .safetensors file and
+    # safetensors is installed, only download safetensors weights
+    if is_safetensors_available():
+        if any(".safetensors" in filename for filename in list_repo_files(model_name_or_path)):
+            allow_patterns = ["*.safetensors"]
+
     # Download only on first process
     if local_rank in [-1, 0]:
         cache_dir = snapshot_download(
             model_name_or_path,
             local_files_only=is_offline_mode(),
             cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
-            allow_patterns=["*.safetensors"] if is_safetensors_available() else ["*.bin"],
+            allow_patterns=allow_patterns,
             max_workers=16,
         )
         if local_rank == -1:
@@ -40,7 +48,7 @@ def get_repo_root(model_name_or_path, local_rank=-1):
         model_name_or_path,
         local_files_only=is_offline_mode(),
         cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
-        allow_patterns=["*.safetensors"] if is_safetensors_available() else ["*.bin"],
+        allow_patterns=allow_patterns,
     )
 
 
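
For reference, below is a minimal standalone sketch of the weight-selection logic this patch introduces (not part of the diff). It assumes `huggingface_hub` and `transformers` are installed and that the model ID is reachable; "bigscience/bloom-560m" is only an illustrative placeholder.

```python
from huggingface_hub import list_repo_files, snapshot_download
from transformers.utils import is_safetensors_available

model_name_or_path = "bigscience/bloom-560m"  # placeholder repo ID for illustration

# Default to PyTorch .bin weights; switch to .safetensors only when the repo
# actually ships safetensors files and the safetensors package is installed.
allow_patterns = ["*.bin"]
if is_safetensors_available() and any(
    ".safetensors" in filename for filename in list_repo_files(model_name_or_path)
):
    allow_patterns = ["*.safetensors"]

# Only files matching allow_patterns are fetched, so a repo that ships both
# formats no longer downloads its .bin weights needlessly.
cache_dir = snapshot_download(model_name_or_path, allow_patterns=allow_patterns)
print(cache_dir, allow_patterns)
```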