Skip to content
This repository has been archived by the owner on Oct 10, 2023. It is now read-only.

Commit

Permalink
Overrride HF_HOME env var within flojoy-decorated functions (#97)
Browse files Browse the repository at this point in the history
* override hfhome within flojoy-decorated functions

* Update flojoy/utils.py

Co-authored-by: Julien Jerphanion <[email protected]>

* add hf_home to test cache_huggingface_to_flojoy

---------

Co-authored-by: Julien Jerphanion <[email protected]>
  • Loading branch information
Roulbac and jjerphan authored Sep 4, 2023
1 parent 9a95e7f commit 1a58240
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 296 deletions.
269 changes: 0 additions & 269 deletions flojoy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,275 +18,6 @@ from .data_container import *
from .config import *
from .flojoy_cloud import *

def hf_hub_download(
repo_id: str,
filename: str,
*,
subfolder: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
local_dir: Union[str, Path, None] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
user_agent: Union[dict, str, None] = None,
force_download: bool = False,
force_filename: Optional[str] = None,
proxies: Optional[dict] = None,
etag_timeout: float = 10,
resume_download: bool = False,
token: Union[bool, str, None] = None,
local_files_only: bool = False,
legacy_cache_layout: bool = False,
) -> str:
"""Download a given file if it's not already present in the local cache.
The new cache file layout looks like this:
- The cache directory contains one subfolder per repo_id (namespaced by repo type)
- inside each repo folder:
- refs is a list of the latest known revision => commit_hash pairs
- blobs contains the actual file blobs (identified by their git-sha or sha256, depending on
whether they're LFS files or not)
- snapshots contains one subfolder per commit, each "commit" contains the subset of the files
that have been resolved at that particular commit. Each filename is a symlink to the blob
at that particular commit.
If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure
how you want to move those files:
- If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob
files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal
is to be able to manually edit and save small files without corrupting the cache while saving disk space for
binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD`
environment variable.
- If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`.
This is optimal in term of disk usage but files must not be manually edited.
- If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the
local dir. This means disk usage is not optimized.
- Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the
files are downloaded and directly placed under `local_dir`. This means if you need to download them again later,
they will be re-downloaded entirely.
```
[ 96] .
└── [ 160] models--julien-c--EsperBERTo-small
├── [ 160] blobs
│ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
│ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e
│ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812
├── [ 96] refs
│ └── [ 40] main
└── [ 128] snapshots
├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f
│ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
│ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
└── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48
├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
└── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
```
Args:
repo_id (`str`):
A user or an organization name and a repo name separated by a `/`.
filename (`str`):
The name of the file in the repo.
subfolder (`str`, *optional*):
An optional value corresponding to a folder inside the model repo.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if downloading from a dataset or space,
`None` or `"model"` if downloading from a model. Default is `None`.
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
library_name (`str`, *optional*):
The name of the library to which the object corresponds.
library_version (`str`, *optional*):
The version of the library.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_dir (`str` or `Path`, *optional*):
If provided, the downloaded file will be placed under this directory, either as a symlink (default) or
a regular file (see description for more details).
local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`):
To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either
duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be
created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if
already exists) or downloaded from the Hub and not cached. See description for more details.
user_agent (`dict`, `str`, *optional*):
The user-agent info in the form of a dictionary or a string.
force_download (`bool`, *optional*, defaults to `False`):
Whether the file should be downloaded even if it already exists in
the local cache.
proxies (`dict`, *optional*):
Dictionary mapping protocol to the URL of the proxy passed to
`requests.request`.
etag_timeout (`float`, *optional*, defaults to `10`):
When fetching ETag, how many seconds to wait for the server to send
data before giving up which is passed to `requests.request`.
resume_download (`bool`, *optional*, defaults to `False`):
If `True`, resume a previously interrupted download.
token (`str`, `bool`, *optional*):
A token to be used for the download.
- If `True`, the token is read from the HuggingFace config
folder.
- If a string, it's used as the authentication token.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
legacy_cache_layout (`bool`, *optional*, defaults to `False`):
If `True`, uses the legacy file cache layout i.e. just call [`hf_hub_url`]
then `cached_download`. This is deprecated as the new cache layout is
more powerful.
Returns:
Local path (string) of file or if networking is off, last version of
file cached on disk.
<Tip>
Raises the following errors:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
if `token=True` and the token cannot be found.
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
if ETag cannot be determined.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
- [`~utils.RepositoryNotFoundError`]
If the repository to download from cannot be found. This may be because it doesn't exist,
or because it is set to `private` and you do not have access.
- [`~utils.RevisionNotFoundError`]
If the revision to download from cannot be found.
- [`~utils.EntryNotFoundError`]
If the file to download cannot be found.
- [`~utils.LocalEntryNotFoundError`]
If network is disabled or unavailable and file is not found in cache.
</Tip>
"""
...

def snapshot_download(
repo_id: str,
*,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
local_dir: Union[str, Path, None] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Optional[Union[dict, str]] = None,
proxies: Optional[dict] = None,
etag_timeout: float = 10,
resume_download: bool = False,
force_download: bool = False,
token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
allow_patterns: Optional[Union[list[str], str]] = None,
ignore_patterns: Optional[Union[list[str], str]] = None,
max_workers: int = 8,
tqdm_class: Optional[Any] = None,
) -> str:
"""Download repo files.
Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
to keep their actual filename relative to that folder. You can also filter which files to download using
`allow_patterns` and `ignore_patterns`.
If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure
how you want to move those files:
- If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob
files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal
is to be able to manually edit and save small files without corrupting the cache while saving disk space for
binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD`
environment variable.
- If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`.
This is optimal in term of disk usage but files must not be manually edited.
- If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the
local dir. This means disk usage is not optimized.
- Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the
files are downloaded and directly placed under `local_dir`. This means if you need to download them again later,
they will be re-downloaded entirely.
An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
configured. It is also not possible to filter which files to download when cloning a repository using git.
Args:
repo_id (`str`):
A user or an organization name and a repo name separated by a `/`.
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if downloading from a dataset or space,
`None` or `"model"` if downloading from a model. Default is `None`.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_dir (`str` or `Path`, *optional*:
If provided, the downloaded files will be placed under this directory, either as symlinks (default) or
regular files (see description for more details).
local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`):
To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either
duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be
created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if
already exists) or downloaded from the Hub and not cached. See description for more details.
library_name (`str`, *optional*):
The name of the library to which the object corresponds.
library_version (`str`, *optional*):
The version of the library.
user_agent (`str`, `dict`, *optional*):
The user-agent info in the form of a dictionary or a string.
proxies (`dict`, *optional*):
Dictionary mapping protocol to the URL of the proxy passed to
`requests.request`.
etag_timeout (`float`, *optional*, defaults to `10`):
When fetching ETag, how many seconds to wait for the server to send
data before giving up which is passed to `requests.request`.
resume_download (`bool`, *optional*, defaults to `False):
If `True`, resume a previously interrupted download.
force_download (`bool`, *optional*, defaults to `False`):
Whether the file should be downloaded even if it already exists in the local cache.
token (`str`, `bool`, *optional*):
A token to be used for the download.
- If `True`, the token is read from the HuggingFace config
folder.
- If a string, it's used as the authentication token.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
allow_patterns (`List[str]` or `str`, *optional*):
If provided, only files matching at least one pattern are downloaded.
ignore_patterns (`List[str]` or `str`, *optional*):
If provided, files matching any of the patterns are not downloaded.
max_workers (`int`, *optional*):
Number of concurrent threads to download files (1 thread = 1 file download).
Defaults to 8.
tqdm_class (`tqdm`, *optional*):
If provided, overwrites the default behavior for the progress bar. Passed
argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
Note that the `tqdm_class` is not passed to each individual download.
Defaults to the custom HF progress bar that can be disabled by setting
`HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
Returns:
Local folder path (string) of repo snapshot
<Tip>
Raises the following errors:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
if `token=True` and the token cannot be found.
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
ETag cannot be determined.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
</Tip>
"""
...

def flojoy(
original_function: Callable[..., DataContainer | dict[str, Any] | TypedDict | None]
Expand Down
20 changes: 19 additions & 1 deletion flojoy/flojoy_python.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from contextlib import ContextDecorator
import json
import os
import traceback
from functools import wraps

Expand All @@ -7,7 +9,7 @@
from .utils import PlotlyJSONEncoder
from typing import Callable, Any, Optional
from .job_result_utils import get_frontend_res_obj_from_result, get_dc_from_result
from .utils import send_to_socket
from .utils import send_to_socket, get_hf_hub_cache_path
from .config import logger
from .parameter_types import format_param_value
from inspect import signature
Expand Down Expand Up @@ -84,6 +86,20 @@ def __init__(
self.jobset_id = jobset_id
self.node_type = node_type

class cache_huggingface_to_flojoy(ContextDecorator):
""" Context manager to override the HF_HOME env var """
def __enter__(self):
self.old_env_var = os.environ.get("HF_HOME")
os.environ["HF_HOME"] = get_hf_hub_cache_path()
return self

def __exit__(self, *exc):
if(self.old_env_var is None):
del os.environ["HF_HOME"]
else:
os.environ["HF_HOME"] = self.old_env_var
return False


def display(
original_function: Callable[..., DataContainer | dict[str, Any]] | None = None
Expand Down Expand Up @@ -143,6 +159,8 @@ def SINE(dc_inputs:list[DataContainer], params:dict[str, Any]):
"""

def decorator(func: Callable[..., Optional[DataContainer | dict[str, Any]]]):
# Wrap func here to override the HF_HOME env var
func = cache_huggingface_to_flojoy()(func)
@wraps(func)
def wrapper(
node_id: str,
Expand Down
34 changes: 8 additions & 26 deletions flojoy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
import yaml
import requests
from dotenv import dotenv_values # type:ignore
from huggingface_hub import hf_hub_download as _hf_hub_download
from huggingface_hub import snapshot_download as _snapshot_download
# TODO(roulbac): Remove these imports once the nodes using them have been
# tested and updated to use huggingface_hub directly
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
from .dao import Dao
from .config import FlojoyConfig, logger

Expand Down Expand Up @@ -45,30 +47,10 @@


# Make as a function to mock at test-time
def _get_hf_hub_cache_path() -> str:
return os.path.join(FLOJOY_CACHE_DIR, "cache", "hf_hub")


def hf_hub_download(*args, **kwargs):
if "cache_dir" not in kwargs:
kwargs["cache_dir"] = _get_hf_hub_cache_path()
else:
if kwargs["cache_dir"] != _get_hf_hub_cache_path():
raise ValueError(
f"Attempted to override cache_dir parameter, received {kwargs['cache_dir']} while the only alloed value is {_get_hf_hub_cache_path()}"
)
return _hf_hub_download(*args, **kwargs)


def snapshot_download(*args, **kwargs):
if "cache_dir" not in kwargs:
kwargs["cache_dir"] = _get_hf_hub_cache_path()
else:
if kwargs["cache_dir"] != _get_hf_hub_cache_path():
raise ValueError(
f"Attempted to override cache_dir parameter, received {kwargs['cache_dir']} while the only alloed value is {_get_hf_hub_cache_path()}"
)
return _snapshot_download(*args, **kwargs)
def get_hf_hub_cache_path() -> str:
"""Returns the path to the HuggingFace home directory (HF_HOME) within the Flojoy cache directory
This is used to cache huggingface artifacts within the Flojoy cache directory. """
return os.path.join(FLOJOY_CACHE_DIR, "cache", "huggingface")


env_vars = dotenv_values("../.env")
Expand Down
17 changes: 17 additions & 0 deletions tests/flojoy_python_test_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import os
from flojoy.flojoy_python import cache_huggingface_to_flojoy
from flojoy.utils import get_hf_hub_cache_path


def test_cache_huggingface_to_flojoy_decorator():

os.environ["HF_HOME"] = "test"

def test_func():
return os.environ.get("HF_HOME")

test_func = cache_huggingface_to_flojoy()(test_func)
assert os.environ.get("HF_HOME") == "test"
assert test_func() == get_hf_hub_cache_path()
assert os.environ.get("HF_HOME") == "test"

0 comments on commit 1a58240

Please sign in to comment.