From 411e378861ae3389bfc0388ce2cd5818f7ffbe89 Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 14 Aug 2024 12:18:30 +0200 Subject: [PATCH] Add 'gated' search parameter (#2448) --- src/huggingface_hub/hf_api.py | 14 ++++++++++++++ tests/test_hf_api.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 932b8ddf69..3ebf7c387c 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -1595,6 +1595,7 @@ def list_models( # Search-query parameter filter: Union[str, Iterable[str], None] = None, author: Optional[str] = None, + gated: Optional[bool] = None, library: Optional[Union[str, List[str]]] = None, language: Optional[Union[str, List[str]]] = None, model_name: Optional[str] = None, @@ -1624,6 +1625,10 @@ def list_models( author (`str`, *optional*): A string which identify the author (user or organization) of the returned models + gated (`bool`, *optional*): + A boolean to filter models on the Hub that are gated or not. By default, all models are returned. + If `gated=True` is passed, only gated models are returned. + If `gated=False` is passed, only non-gated models are returned. library (`str` or `List`, *optional*): A string or list of strings of foundational libraries models were originally trained from, such as pytorch, tensorflow, or allennlp. @@ -1749,6 +1754,8 @@ def list_models( # Handle other query params if author: params["author"] = author + if gated is not None: + params["gated"] = gated if pipeline_tag: params["pipeline_tag"] = pipeline_tag search_list = [] @@ -1795,6 +1802,7 @@ def list_datasets( author: Optional[str] = None, benchmark: Optional[Union[str, List[str]]] = None, dataset_name: Optional[str] = None, + gated: Optional[bool] = None, language_creators: Optional[Union[str, List[str]]] = None, language: Optional[Union[str, List[str]]] = None, multilinguality: Optional[Union[str, List[str]]] = None, @@ -1826,6 +1834,10 @@ def list_datasets( dataset_name (`str`, *optional*): A string or list of strings that can be used to identify datasets on the Hub by its name, such as `SQAC` or `wikineural` + gated (`bool`, *optional*): + A boolean to filter datasets on the Hub that are gated or not. By default, all datasets are returned. + If `gated=True` is passed, only gated datasets are returned. + If `gated=False` is passed, only non-gated datasets are returned. language_creators (`str` or `List`, *optional*): A string or list of strings that can be used to identify datasets on the Hub with how the data was curated, such as `crowdsourced` or @@ -1954,6 +1966,8 @@ def list_datasets( # Handle other query params if author: params["author"] = author + if gated is not None: + params["gated"] = gated search_list = [] if dataset_name: search_list.append(dataset_name) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index b7d91992cc..07a6dd68b5 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -1785,6 +1785,14 @@ def test_list_models_expand_cannot_be_used_with_other_params(self): with self.assertRaises(ValueError): next(self._api.list_models(expand=["author"], cardData=True)) + def test_list_models_gated_only(self): + for model in self._api.list_models(expand=["gated"], gated=True, limit=5): + assert model.gated in ("auto", "manual") + + def test_list_models_non_gated_only(self): + for model in self._api.list_models(expand=["gated"], gated=False, limit=5): + assert model.gated is False + def test_model_info(self): model = self._api.model_info(repo_id=DUMMY_MODEL_ID) self.assertIsInstance(model, ModelInfo) @@ -2009,6 +2017,14 @@ def test_list_datasets_expand_cannot_be_used_with_full(self): with self.assertRaises(ValueError): next(self._api.list_datasets(expand=["author"], full=True)) + def test_list_datasets_gated_only(self): + for dataset in self._api.list_datasets(expand=["gated"], gated=True, limit=5): + assert dataset.gated in ("auto", "manual") + + def test_list_datasets_non_gated_only(self): + for dataset in self._api.list_datasets(expand=["gated"], gated=False, limit=5): + assert dataset.gated is False + def test_filter_datasets_with_card_data(self): assert any(dataset.card_data is not None for dataset in self._api.list_datasets(full=True, limit=50)) assert all(dataset.card_data is None for dataset in self._api.list_datasets(full=False, limit=50))