Skip to content

Commit

Permalink
[EAGLE-5416] Added Tests for Download checkpoints and Fix download me…
Browse files Browse the repository at this point in the history
…thods (#488)

* also fall back to HF_TOKEN in env

* older version

* rc3

* fix folders with nested files

* Revert "Revert "[EAGLE-5416] Added Tests for Download checkpoints and Fix dow…"

This reverts commit f0d909a.

* improve loader validation

* also try ignoring .cache folder

* use new flag in tests
  • Loading branch information
zeiler authored Jan 17, 2025
1 parent f0d909a commit 540c0eb
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 55 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## [[11.0.4]](https://github.com/Clarifai/clarifai-python/releases/tag/11.0.4) - [PyPI](https://pypi.org/project/clarifai/11.0.4/) - 2025-01-17

### Changed

- Added tests for downloads and various improvements [(#488)](https://github.com/Clarifai/clarifai-python/pull/488)

## [[11.0.3]](https://github.com/Clarifai/clarifai-python/releases/tag/11.0.3) - [PyPI](https://pypi.org/project/clarifai/11.0.3/) - 2025-01-14

### Changed
Expand Down
2 changes: 1 addition & 1 deletion clarifai/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "11.0.3"
__version__ = "11.0.4"
89 changes: 52 additions & 37 deletions clarifai/runners/models/model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,43 @@ def _clear_line(n: int = 1) -> None:

class ModelUploader:

def __init__(self, folder: str, validate_api_ids: bool = True):
def __init__(self, folder: str, validate_api_ids: bool = True, download_validation_only=False):
"""
:param folder: The folder containing the model.py, config.yaml, requirements.txt and
checkpoints.
:param validate_api_ids: Whether to validate the user_id and app_id in the config file.
:param validate_api_ids: Whether to validate the user_id and app_id in the config file. TODO(zeiler):
deprecate in favor of download_validation_only.
:param download_validation_only: Whether to skip the API config validation. Set to True if
just downloading a checkpoint.
"""
self._client = None
self.download_validation_only = download_validation_only
self.folder = self._validate_folder(folder)
self.config = self._load_config(os.path.join(self.folder, 'config.yaml'))
self.validate_api_ids = validate_api_ids
self._validate_config()
self.model_proto = self._get_model_proto()
self.model_id = self.model_proto.id
self.model_version_id = None
self.inference_compute_info = self._get_inference_compute_info()
self.is_v3 = True # Do model build for v3

@staticmethod
def _validate_folder(folder):
def _validate_folder(self, folder):
if folder == ".":
folder = "" # will getcwd() next which ends with /
if not folder.startswith("/"):
folder = os.path.join(os.getcwd(), folder)
logger.info(f"Validating folder: {folder}")
if not os.path.exists(folder):
raise FileNotFoundError(f"Folder {folder} not found, please provide a valid folder path")
files = os.listdir(folder)
assert "requirements.txt" in files, "requirements.txt not found in the folder"
assert "config.yaml" in files, "config.yaml not found in the folder"
# If just downloading we don't need requirements.txt or the python code, we do need the
# 1/ folder to put 1/checkpoints into.
assert "1" in files, "Subfolder '1' not found in the folder"
subfolder_files = os.listdir(os.path.join(folder, '1'))
assert 'model.py' in subfolder_files, "model.py not found in the folder"
if not self.download_validation_only:
assert "requirements.txt" in files, "requirements.txt not found in the folder"
subfolder_files = os.listdir(os.path.join(folder, '1'))
assert 'model.py' in subfolder_files, "model.py not found in the folder"
return folder

@staticmethod
Expand All @@ -68,22 +75,27 @@ def _load_config(config_file: str):
return config

def _validate_config_checkpoints(self):

"""
Validates the checkpoints section in the config file.
:return: loader_type the type of loader or None if no checkpoints.
:return: repo_id location of checkpoint.
:return: hf_token token to access checkpoint.
"""
assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
loader_type = self.config.get("checkpoints").get("type")
if not loader_type:
logger.info("No loader type specified in the config file for checkpoints")
return None, None, None
assert loader_type == "huggingface", "Only huggingface loader supported for now"
if loader_type == "huggingface":
assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
repo_id = self.config.get("checkpoints").get("repo_id")

hf_token = self.config.get("checkpoints").get("hf_token", None)
return repo_id, hf_token
# get from config.yaml otherwise fall back to HF_TOKEN env var.
hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None))
return loader_type, repo_id, hf_token

def _check_app_exists(self):
if not self.validate_api_ids:
return True
resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
if resp.status.code == status_code_pb2.SUCCESS:
return True
Expand Down Expand Up @@ -113,21 +125,19 @@ def _validate_config_model(self):
sys.exit(1)

def _validate_config(self):
self._validate_config_model()
if not self.download_validation_only:
self._validate_config_model()

if self.config.get("checkpoints"):
self._validate_config_checkpoints()
assert "inference_compute_info" in self.config, "inference_compute_info not found in the config file"

assert "inference_compute_info" in self.config, "inference_compute_info not found in the config file"

if self.config.get("concepts"):
model_type_id = self.config.get('model').get('model_type_id')
assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
if self.config.get("concepts"):
model_type_id = self.config.get('model').get('model_type_id')
assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"

if self.config.get("checkpoints"):
_, hf_token = self._validate_config_checkpoints()
loader_type, _, hf_token = self._validate_config_checkpoints()

if hf_token:
if loader_type == "huggingface" and hf_token:
is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
if not is_valid_token:
logger.error(
Expand Down Expand Up @@ -311,16 +321,19 @@ def download_checkpoints(self):
logger.info("No checkpoints specified in the config file")
return True

repo_id, hf_token = self._validate_config_checkpoints()
loader_type, repo_id, hf_token = self._validate_config_checkpoints()

loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
success = loader.download_checkpoints(self.checkpoint_path)
success = True
if loader_type == "huggingface":
loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
success = loader.download_checkpoints(self.checkpoint_path)

if not success:
logger.error(f"Failed to download checkpoints for model {repo_id}")
sys.exit(1)
else:
logger.info(f"Downloaded checkpoints for model {repo_id}")
if loader_type:
if not success:
logger.error(f"Failed to download checkpoints for model {repo_id}")
sys.exit(1)
else:
logger.info(f"Downloaded checkpoints for model {repo_id}")
return success

def _concepts_protos_from_concepts(self, concepts):
Expand Down Expand Up @@ -400,9 +413,10 @@ def upload_model_version(self, download_checkpoints):
input(
"Press Enter to download the HuggingFace model's config.json file to infer the concepts and continue..."
)
repo_id, hf_token = self._validate_config_checkpoints()
loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
loader.download_config(self.checkpoint_path)
loader_type, repo_id, hf_token = self._validate_config_checkpoints()
if loader_type == "huggingface":
loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
loader.download_config(self.checkpoint_path)

else:
logger.error(
Expand All @@ -413,10 +427,10 @@ def upload_model_version(self, download_checkpoints):
model_version_proto = self.get_model_version_proto()

if download_checkpoints:
tar_cmd = f"tar --exclude=*~ -czvf {self.tar_file} -C {self.folder} ."
tar_cmd = f"tar --exclude=*~ --exclude={self.tar_file} -czvf {self.tar_file} -C {self.folder} ."
else: # we don't want to send the checkpoints up even if they are in the folder.
logger.info(f"Skipping {self.checkpoint_path} in the tar file that is uploaded.")
tar_cmd = f"tar --exclude={self.checkpoint_suffix} --exclude=*~ -czvf {self.tar_file} -C {self.folder} ."
tar_cmd = f"tar --exclude={self.checkpoint_suffix} --exclude=*~ --exclude={self.tar_file} -czvf {self.tar_file} -C {self.folder} ."
# Tar the folder
logger.debug(tar_cmd)
os.system(tar_cmd)
Expand Down Expand Up @@ -493,14 +507,15 @@ def init_upload_model_version(self, model_version_proto, file_path):
file_size = os.path.getsize(file_path)
logger.info(f"Uploading model version of model {self.model_proto.id}")
logger.info(f"Using file '{os.path.basename(file_path)}' of size: {file_size} bytes")
return service_pb2.PostModelVersionsUploadRequest(
result = service_pb2.PostModelVersionsUploadRequest(
upload_config=service_pb2.PostModelVersionsUploadConfig(
user_app_id=self.client.user_app_id,
model_id=self.model_proto.id,
model_version=model_version_proto,
total_size=file_size,
is_v3=self.is_v3,
))
return result

def get_model_build_logs(self):
logs_request = service_pb2.ListLogEntriesRequest(
Expand Down
60 changes: 48 additions & 12 deletions clarifai/runners/utils/loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import fnmatch
import importlib.util
import json
import os
import shutil
import subprocess

from clarifai.utils.logging import logger
Expand Down Expand Up @@ -39,7 +41,7 @@ def validate_hftoken(cls, hf_token: str):
def download_checkpoints(self, checkpoint_path: str):
# throw error if huggingface_hub wasn't installed
try:
from huggingface_hub import list_repo_files, snapshot_download
from huggingface_hub import snapshot_download
except ImportError:
raise ImportError(self.HF_DOWNLOAD_TEXT)
if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
Expand All @@ -53,16 +55,17 @@ def download_checkpoints(self, checkpoint_path: str):
logger.error("Model %s not found on Hugging Face" % (self.repo_id))
return False

ignore_patterns = None # Download everything.
repo_files = list_repo_files(repo_id=self.repo_id, token=self.token)
if any(f.endswith(".safetensors") for f in repo_files):
logger.info(f"SafeTensors found in {self.repo_id}, downloading only .safetensors files.")
ignore_patterns = ["**/original/*", "**/*.pth", "**/*.bin"]
self.ignore_patterns = self._get_ignore_patterns()
snapshot_download(
repo_id=self.repo_id,
local_dir=checkpoint_path,
local_dir_use_symlinks=False,
ignore_patterns=ignore_patterns)
ignore_patterns=self.ignore_patterns)
# Remove the `.cache` folder if it exists
cache_path = os.path.join(checkpoint_path, ".cache")
if os.path.exists(cache_path) and os.path.isdir(cache_path):
shutil.rmtree(cache_path)

except Exception as e:
logger.error(f"Error downloading model checkpoints {e}")
return False
Expand Down Expand Up @@ -109,11 +112,44 @@ def validate_download(self, checkpoint_path: str):
from huggingface_hub import list_repo_files
except ImportError:
raise ImportError(self.HF_DOWNLOAD_TEXT)
checkpoint_dir_files = [
f for dp, dn, fn in os.walk(os.path.expanduser(checkpoint_path)) for f in fn
]
return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len(
list_repo_files(self.repo_id)) > 0
# Get the list of files on the repo
repo_files = list_repo_files(self.repo_id, token=self.token)

self.ignore_patterns = self._get_ignore_patterns()
# Get the list of files on the repo that are not ignored
if getattr(self, "ignore_patterns", None):
patterns = self.ignore_patterns

def should_ignore(file_path):
return any(fnmatch.fnmatch(file_path, pattern) for pattern in patterns)

repo_files = [f for f in repo_files if not should_ignore(f)]

# Check if downloaded files match the files we expect (ignoring ignored patterns)
checkpoint_dir_files = []
for dp, dn, fn in os.walk(os.path.expanduser(checkpoint_path)):
checkpoint_dir_files.extend(
[os.path.relpath(os.path.join(dp, f), checkpoint_path) for f in fn])

# Validate by comparing file lists
return len(checkpoint_dir_files) >= len(repo_files) and not (
len(set(repo_files) - set(checkpoint_dir_files)) > 0) and len(repo_files) > 0

def _get_ignore_patterns(self):
    """Decide which files of the repo a checkpoint download should skip.

    Lists the files of ``self.repo_id`` (authenticating with ``self.token``). When at
    least one of them ends in ``.safetensors``, a fixed list of patterns covering
    ``.pth``/``.bin`` files, ``original/`` folders and the ``.cache`` folder is
    selected; otherwise nothing is ignored.

    The outcome is stored on ``self.ignore_patterns`` as well as returned.

    :return: list of ignore patterns, or None when every file should be kept.
    :raises ImportError: when ``huggingface_hub`` cannot be imported.
    """
    # huggingface_hub is imported lazily so a missing install surfaces the
    # loader's own help text instead of a bare import failure.
    try:
        from huggingface_hub import list_repo_files
    except ImportError:
        raise ImportError(self.HF_DOWNLOAD_TEXT)

    remote_files = list_repo_files(self.repo_id, token=self.token)

    # Default: ignore nothing.
    self.ignore_patterns = None
    for remote_name in remote_files:
        if remote_name.endswith(".safetensors"):
            # NOTE(review): the bare "*.pth" / "*.bin" entries sit next to the "**/"
            # forms, presumably so files at the repo root are matched too - confirm.
            self.ignore_patterns = [
                "**/original/*", "**/*.pth", "**/*.bin", "*.pth", "*.bin", "**/.cache/*"
            ]
            break
    return self.ignore_patterns

@staticmethod
def validate_config(checkpoint_path: str):
Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pytest==7.1.2
pytest-cov==5.0.0
pytest-xdist==2.5.0
llama-index-core==0.11.17
huggingface_hub[hf_transfer]==0.27.1
pypdf==3.17.4
seaborn==0.13.2
pycocotools==2.0.6
File renamed without changes.
20 changes: 20 additions & 0 deletions tests/runners/dummy_runner_models/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Sample config file for the dummy runner model used by the checkpoint download tests.

model:
id: "dummy-runner-model"
user_id: "user_id"
app_id: "app_id"
model_type_id: "multimodal-to-text"

build_info:
python_version: "3.11"

inference_compute_info:
cpu_limit: "1"
cpu_memory: "1Gi"
num_accelerators: 0


checkpoints:
type: "huggingface"
repo_id: "timm/mobilenetv3_small_100.lamb_in1k"
Empty file.
62 changes: 62 additions & 0 deletions tests/runners/test_download_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import shutil
import tempfile

import pytest

from clarifai.runners.models.model_upload import ModelUploader
from clarifai.runners.utils.loader import HuggingFaceLoader

# Repo id handed to HuggingFaceLoader by the tests below; checkpoint_dir also derives its
# temp folder name from it.
MODEL_ID = "timm/mobilenetv3_small_100.lamb_in1k"


@pytest.fixture(scope="module")
def checkpoint_dir():
    """Yield a module-wide temporary directory for downloaded checkpoints.

    The directory lives under the system temp dir and is named after ``MODEL_ID`` with
    its first five characters removed. It is created if it does not exist yet and is
    removed, ignoring errors, once the module's tests have finished.
    """
    # MODEL_ID[5:] strips the leading "timm/" so the name contains no path separator.
    download_root = os.path.join(tempfile.gettempdir(), MODEL_ID[5:])
    already_present = os.path.exists(download_root)
    if not already_present:
        os.makedirs(download_root)

    yield download_root

    # Teardown: best-effort removal of everything that was downloaded.
    shutil.rmtree(download_root, ignore_errors=True)


@pytest.fixture(scope="function")
def dummy_runner_models_dir():
    """Yield the dummy model's ``1/checkpoints`` path and delete that folder afterwards.

    The path is resolved relative to this test file, inside ``dummy_runner_models``.
    """
    tests_dir = os.path.dirname(__file__)
    checkpoints_path = os.path.join(tests_dir, "dummy_runner_models", "1", "checkpoints")

    yield checkpoints_path

    # Teardown: drop whatever the test downloaded so the next test starts clean.
    if os.path.exists(checkpoints_path):
        shutil.rmtree(checkpoints_path)


@pytest.fixture(scope="function", autouse=True)
def override_environment_variables():
    """Run each test with ``CLARIFAI_PAT`` unset, then restore the previous state.

    Applied automatically to every test in this module (``autouse=True``).
    After the test the variable is put back exactly as it was found: the
    original value is restored even if it was an empty string, and the variable
    is removed again if it was unset before but a test defined it.
    """
    # pop() both remembers the prior value (None when absent) and unsets the variable.
    saved_pat = os.environ.pop("CLARIFAI_PAT", None)

    yield

    if saved_pat is None:
        # The variable did not exist before the test; make sure it does not exist after.
        os.environ.pop("CLARIFAI_PAT", None)
    else:
        # Compare against None rather than truthiness so an empty value is restored too.
        os.environ["CLARIFAI_PAT"] = saved_pat


def test_loader_download_checkpoints(checkpoint_dir):
    """Download MODEL_ID's checkpoints via HuggingFaceLoader and count the resulting entries."""
    hf_loader = HuggingFaceLoader(repo_id=MODEL_ID)
    hf_loader.download_checkpoints(checkpoint_path=checkpoint_dir)

    # Four top-level entries are expected in the download directory.
    downloaded_entries = os.listdir(checkpoint_dir)
    assert len(downloaded_entries) == 4


def test_validate_download(checkpoint_dir):
    """Check that HuggingFaceLoader.validate_download accepts the checkpoint directory.

    NOTE(review): appears to rely on test_loader_download_checkpoints having already
    populated the module-scoped checkpoint_dir - confirm the test ordering assumption.
    """
    hf_loader = HuggingFaceLoader(repo_id=MODEL_ID)
    is_valid = hf_loader.validate_download(checkpoint_path=checkpoint_dir)
    assert is_valid is True


def test_download_checkpoints(dummy_runner_models_dir):
    """Download checkpoints for the dummy runner model through ModelUploader.

    The dummy_runner_models_dir fixture is requested for its teardown, which removes the
    downloaded checkpoints folder after the test.
    """
    dummy_model_folder = os.path.join(os.path.dirname(__file__), "dummy_runner_models")
    # download_validation_only=True: only the checkpoint download is exercised here.
    uploader = ModelUploader(dummy_model_folder, download_validation_only=True)
    download_ok = uploader.download_checkpoints()
    assert download_ok is True
Loading

0 comments on commit 540c0eb

Please sign in to comment.