Skip to content

Commit

Permalink
HF token Validation
Browse files Browse the repository at this point in the history
  • Loading branch information
luv-bansal committed Jan 13, 2025
1 parent 1e6930b commit 7c407f2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
17 changes: 12 additions & 5 deletions clarifai/runners/models/model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,7 @@ def _validate_config_checkpoints(self):
assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
repo_id = self.config.get("checkpoints").get("repo_id")

# prefer env var for HF_TOKEN but if not provided then use the one from config.yaml if any.
if 'HF_TOKEN' in os.environ:
hf_token = os.environ['HF_TOKEN']
else:
hf_token = self.config.get("checkpoints").get("hf_token", None)
hf_token = self.config.get("checkpoints").get("hf_token", None)
return repo_id, hf_token

def _check_app_exists(self):
Expand Down Expand Up @@ -120,6 +116,17 @@ def _validate_config(self):
model_type_id = self.config.get('model').get('model_type_id')
assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"

if self.config.get("checkpoints"):
_, hf_token = self._validate_config_checkpoints()

if hf_token:
is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
if not is_valid_token:
logger.error(
"Invalid Hugging Face token provided in the config file, this might cause issues with downloading the restricted model checkpoints."
)
logger.info("Continuing without Hugging Face token")

@property
def client(self):
if self._client is None:
Expand Down
32 changes: 19 additions & 13 deletions clarifai/runners/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,28 @@ def __init__(self, repo_id=None, token=None):
self.repo_id = repo_id
self.token = token
if token:
try:
if importlib.util.find_spec("huggingface_hub") is None:
raise ImportError(self.HF_DOWNLOAD_TEXT)
os.environ['HF_TOKEN'] = token
from huggingface_hub import HfApi

api = HfApi()
api.whoami(token=token)

if self.validate_hftoken(token):
subprocess.run(f'huggingface-cli login --token={os.environ["HF_TOKEN"]}', shell=True)
except Exception as e:
logger.error(
f"Error setting up Hugging Face token, please make sure you have the correct token: {e}"
)
logger.info("Hugging Face token validated")
else:
logger.info("Continuing without Hugging Face token")

@classmethod
def validate_hftoken(cls, hf_token: str):
try:
if importlib.util.find_spec("huggingface_hub") is None:
raise ImportError(cls.HF_DOWNLOAD_TEXT)
os.environ['HF_TOKEN'] = hf_token
from huggingface_hub import HfApi

api = HfApi()
api.whoami(token=hf_token)
return True
except Exception as e:
logger.error(
f"Error setting up Hugging Face token, please make sure you have the correct token: {e}")
return False

def download_checkpoints(self, checkpoint_path: str):
# throw error if huggingface_hub wasn't installed
try:
Expand Down

0 comments on commit 7c407f2

Please sign in to comment.