Skip to content

Commit

Permalink
Add 'gated' search parameter (#2448)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin authored Aug 14, 2024
1 parent c9c39b8 commit 411e378
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,7 @@ def list_models(
# Search-query parameter
filter: Union[str, Iterable[str], None] = None,
author: Optional[str] = None,
gated: Optional[bool] = None,
library: Optional[Union[str, List[str]]] = None,
language: Optional[Union[str, List[str]]] = None,
model_name: Optional[str] = None,
Expand Down Expand Up @@ -1624,6 +1625,10 @@ def list_models(
author (`str`, *optional*):
A string which identifies the author (user or organization) of the
returned models
gated (`bool`, *optional*):
A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
If `gated=True` is passed, only gated models are returned.
If `gated=False` is passed, only non-gated models are returned.
library (`str` or `List`, *optional*):
A string or list of strings of foundational libraries models were
originally trained from, such as pytorch, tensorflow, or allennlp.
Expand Down Expand Up @@ -1749,6 +1754,8 @@ def list_models(
# Handle other query params
if author:
params["author"] = author
if gated is not None:
params["gated"] = gated
if pipeline_tag:
params["pipeline_tag"] = pipeline_tag
search_list = []
Expand Down Expand Up @@ -1795,6 +1802,7 @@ def list_datasets(
author: Optional[str] = None,
benchmark: Optional[Union[str, List[str]]] = None,
dataset_name: Optional[str] = None,
gated: Optional[bool] = None,
language_creators: Optional[Union[str, List[str]]] = None,
language: Optional[Union[str, List[str]]] = None,
multilinguality: Optional[Union[str, List[str]]] = None,
Expand Down Expand Up @@ -1826,6 +1834,10 @@ def list_datasets(
dataset_name (`str`, *optional*):
A string that can be used to identify datasets on
the Hub by their name, such as `SQAC` or `wikineural`
gated (`bool`, *optional*):
A boolean to filter datasets on the Hub that are gated or not. By default, all datasets are returned.
If `gated=True` is passed, only gated datasets are returned.
If `gated=False` is passed, only non-gated datasets are returned.
language_creators (`str` or `List`, *optional*):
A string or list of strings that can be used to identify datasets on
the Hub with how the data was curated, such as `crowdsourced` or
Expand Down Expand Up @@ -1954,6 +1966,8 @@ def list_datasets(
# Handle other query params
if author:
params["author"] = author
if gated is not None:
params["gated"] = gated
search_list = []
if dataset_name:
search_list.append(dataset_name)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,6 +1785,14 @@ def test_list_models_expand_cannot_be_used_with_other_params(self):
with self.assertRaises(ValueError):
next(self._api.list_models(expand=["author"], cardData=True))

def test_list_models_gated_only(self):
    """Listing models with `gated=True` must only yield entries whose `gated` value is "auto" or "manual"."""
    accepted_modes = ("auto", "manual")
    # `expand=["gated"]` requests the `gated` attribute so it can be inspected on each result.
    listing = self._api.list_models(expand=["gated"], gated=True, limit=5)
    for entry in listing:
        assert entry.gated in accepted_modes

def test_list_models_non_gated_only(self):
    """Listing models with `gated=False` must only yield entries whose `gated` value is exactly `False`."""
    query = {"expand": ["gated"], "gated": False, "limit": 5}
    for entry in self._api.list_models(**query):
        # Identity check on purpose: the value must be the boolean `False`, not merely falsy.
        assert entry.gated is False

def test_model_info(self):
model = self._api.model_info(repo_id=DUMMY_MODEL_ID)
self.assertIsInstance(model, ModelInfo)
Expand Down Expand Up @@ -2009,6 +2017,14 @@ def test_list_datasets_expand_cannot_be_used_with_full(self):
with self.assertRaises(ValueError):
next(self._api.list_datasets(expand=["author"], full=True))

def test_list_datasets_gated_only(self):
    """Listing datasets with `gated=True` must only yield entries whose `gated` value is "auto" or "manual"."""
    accepted_modes = ("auto", "manual")
    # `expand=["gated"]` requests the `gated` attribute so it can be inspected on each result.
    listing = self._api.list_datasets(expand=["gated"], gated=True, limit=5)
    for entry in listing:
        assert entry.gated in accepted_modes

def test_list_datasets_non_gated_only(self):
    """Listing datasets with `gated=False` must only yield entries whose `gated` value is exactly `False`."""
    query = {"expand": ["gated"], "gated": False, "limit": 5}
    for entry in self._api.list_datasets(**query):
        # Identity check on purpose: the value must be the boolean `False`, not merely falsy.
        assert entry.gated is False

def test_filter_datasets_with_card_data(self):
assert any(dataset.card_data is not None for dataset in self._api.list_datasets(full=True, limit=50))
assert all(dataset.card_data is None for dataset in self._api.list_datasets(full=False, limit=50))
Expand Down

0 comments on commit 411e378

Please sign in to comment.