From 7c407f2906e6a2ceb06f422519f9a318f9b24f27 Mon Sep 17 00:00:00 2001 From: luv-bansal Date: Mon, 13 Jan 2025 15:41:58 +0000 Subject: [PATCH] HF token Validation --- clarifai/runners/models/model_upload.py | 17 +++++++++---- clarifai/runners/utils/loader.py | 32 +++++++++++++++---------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/clarifai/runners/models/model_upload.py b/clarifai/runners/models/model_upload.py index 5ffdee63..f64e9743 100644 --- a/clarifai/runners/models/model_upload.py +++ b/clarifai/runners/models/model_upload.py @@ -72,11 +72,7 @@ def _validate_config_checkpoints(self): assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file" repo_id = self.config.get("checkpoints").get("repo_id") - # prefer env var for HF_TOKEN but if not provided then use the one from config.yaml if any. - if 'HF_TOKEN' in os.environ: - hf_token = os.environ['HF_TOKEN'] - else: - hf_token = self.config.get("checkpoints").get("hf_token", None) + hf_token = self.config.get("checkpoints").get("hf_token", None) return repo_id, hf_token def _check_app_exists(self): @@ -120,6 +116,17 @@ def _validate_config(self): model_type_id = self.config.get('model').get('model_type_id') assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts" + if self.config.get("checkpoints"): + _, hf_token = self._validate_config_checkpoints() + + if hf_token: + is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token) + if not is_valid_token: + logger.error( + "Invalid Hugging Face token provided in the config file, this might cause issues with downloading the restricted model checkpoints." + ) + logger.info("Continuing without Hugging Face token") + @property def client(self): if self._client is None: diff --git a/clarifai/runners/utils/loader.py b/clarifai/runners/utils/loader.py index dac623d5..e11e8353 100644 --- a/clarifai/runners/utils/loader.py +++ b/clarifai/runners/utils/loader.py @@ -14,22 +14,28 @@ def __init__(self, repo_id=None, token=None): self.repo_id = repo_id self.token = token if token: - try: - if importlib.util.find_spec("huggingface_hub") is None: - raise ImportError(self.HF_DOWNLOAD_TEXT) - os.environ['HF_TOKEN'] = token - from huggingface_hub import HfApi - - api = HfApi() - api.whoami(token=token) - + if self.validate_hftoken(token): subprocess.run(f'huggingface-cli login --token={os.environ["HF_TOKEN"]}', shell=True) - except Exception as e: - logger.error( - f"Error setting up Hugging Face token, please make sure you have the correct token: {e}" - ) + logger.info("Hugging Face token validated") + else: logger.info("Continuing without Hugging Face token") + @classmethod + def validate_hftoken(cls, hf_token: str): + try: + if importlib.util.find_spec("huggingface_hub") is None: + raise ImportError(cls.HF_DOWNLOAD_TEXT) + os.environ['HF_TOKEN'] = hf_token + from huggingface_hub import HfApi + + api = HfApi() + api.whoami(token=hf_token) + return True + except Exception as e: + logger.error( + f"Error setting up Hugging Face token, please make sure you have the correct token: {e}") + return False + def download_checkpoints(self, checkpoint_path: str): # throw error if huggingface_hub wasn't installed try: