From 804c741373b15b79a13b9ac17bdda8dd41e15db3 Mon Sep 17 00:00:00 2001 From: Jonathan de Bruin Date: Sun, 22 Dec 2024 15:38:19 +0100 Subject: [PATCH] Add native support for expressions via filters (#50) --- pyalex/api.py | 99 +++++++++++++++++++++++++++++++++++++------- tests/test_pyalex.py | 73 +++++++++++++++++++++++++++----- 2 files changed, 147 insertions(+), 25 deletions(-) diff --git a/pyalex/api.py b/pyalex/api.py index 73afb01..61ef20d 100644 --- a/pyalex/api.py +++ b/pyalex/api.py @@ -31,6 +31,32 @@ def __setattr__(self, key, value): ) +class or_(dict): + pass + + +class _LogicalExpression: + token = None + + def __init__(self, value): + self.value = value + + def __str__(self) -> str: + return f"{self.token}{self.value}" + + +class not_(_LogicalExpression): + token = "!" + + +class gt_(_LogicalExpression): + token = ">" + + +class lt_(_LogicalExpression): + token = "<" + + def _quote_oa_value(v): """Prepare a value for the OpenAlex API. @@ -41,30 +67,40 @@ def _quote_oa_value(v): if isinstance(v, bool): return str(v).lower() + if isinstance(v, _LogicalExpression) and isinstance(v.value, str): + v.value = quote_plus(v.value) + return v + if isinstance(v, str): return quote_plus(v) return v -def _flatten_kv(d, prefix=""): +def _flatten_kv(d, prefix=None, logical="+"): + if prefix is None and not isinstance(d, dict): + raise ValueError("prefix should be set if d is not a dict") + if isinstance(d, dict): + logical_subd = "|" if isinstance(d, or_) else logical + t = [] for k, v in d.items(): - if isinstance(v, list): - t.extend([f"{prefix}.{k}:{_quote_oa_value(i)}" for i in v]) - else: - new_prefix = f"{prefix}.{k}" if prefix else f"{k}" - x = _flatten_kv(v, prefix=new_prefix) - t.append(x) + x = _flatten_kv( + v, prefix=f"{prefix}.{k}" if prefix else f"{k}", logical=logical_subd + ) + t.append(x) return ",".join(t) + elif isinstance(d, list): + list_str = logical.join([f"{_quote_oa_value(i)}" for i in d]) + return f"{prefix}:{list_str}" else: return f"{prefix}:{_quote_oa_value(d)}" def _params_merge(params, add_params): - for k, _v in add_params.items(): + for k in add_params.keys(): if ( k in params and isinstance(params[k], dict) @@ -113,6 +149,18 @@ def invert_abstract(inv_index): return " ".join(map(lambda x: x[0], sorted(l_inv, key=lambda x: x[1]))) +def _wrap_values_nested_dict(d, func): + for k, v in d.items(): + if isinstance(v, dict): + d[k] = _wrap_values_nested_dict(v, func) + elif isinstance(v, list): + d[k] = [func(i) for i in v] + else: + d[k] = func(v) + + return d + + class QueryError(ValueError): pass @@ -207,9 +255,6 @@ class BaseOpenAlex: def __init__(self, params=None): self.params = params - def _get_multi_items(self, record_list): - return self.filter(openalex_id="|".join(record_list)).get() - def _full_collection_name(self): if self.params is not None and "q" in self.params.keys(): return ( @@ -234,10 +279,14 @@ def __getattr__(self, key): def __getitem__(self, record_id): if isinstance(record_id, list): - return self._get_multi_items(record_id) + if len(record_id) > 100: + raise ValueError("OpenAlex does not support more than 100 ids") + + return self.filter_or(openalex_id=record_id).get(per_page=len(record_id)) return self._get_from_url( - f"{self._full_collection_name()}/{record_id}", return_meta=False + f"{self._full_collection_name()}/{_quote_oa_value(record_id)}", + return_meta=False, ) @property @@ -322,7 +371,10 @@ def paginate(self, method="cursor", page=1, per_page=None, cursor="*", n_max=100 def random(self): return self.__getitem__("random") - def _add_params(self, argument, new_params): + def _add_params(self, argument, new_params, raise_if_exists=False): + if raise_if_exists: + raise NotImplementedError("raise_if_exists is not implemented") + if self.params is None: self.params = {argument: new_params} elif argument in self.params and isinstance(self.params[argument], dict): @@ -336,6 +388,25 @@ def filter(self, **kwargs): self._add_params("filter", kwargs) return self + def filter_and(self, **kwargs): + return self.filter(**kwargs) + + def filter_or(self, **kwargs): + self._add_params("filter", or_(kwargs), raise_if_exists=False) + return self + + def filter_not(self, **kwargs): + self._add_params("filter", _wrap_values_nested_dict(kwargs, not_)) + return self + + def filter_gt(self, **kwargs): + self._add_params("filter", _wrap_values_nested_dict(kwargs, gt_)) + return self + + def filter_lt(self, **kwargs): + self._add_params("filter", _wrap_values_nested_dict(kwargs, lt_)) + return self + def search_filter(self, **kwargs): self._add_params("filter", {f"{k}.search": v for k, v in kwargs.items()}) return self diff --git a/tests/test_pyalex.py b/tests/test_pyalex.py index 9b3855f..d4cc772 100644 --- a/tests/test_pyalex.py +++ b/tests/test_pyalex.py @@ -117,7 +117,12 @@ def test_multi_works(): # the work to extract the referenced works of w = Works()["W2741809807"] - assert len(Works()[w["referenced_works"]]) == 25 + assert len(Works()[w["referenced_works"]]) >= 38 + + assert ( + len(Works().filter_or(openalex_id=w["referenced_works"]).get(per_page=100)) + >= 38 + ) def test_works_multifilter(): @@ -278,33 +283,80 @@ def test_random_publishers(): def test_and_operator(): - # https://github.com/J535D165/pyalex/issues/11 - url = "https://api.openalex.org/works?filter=institutions.country_code:tw,institutions.country_code:hk,institutions.country_code:us,publication_year:2022" + urls = [ + "https://api.openalex.org/works?filter=institutions.country_code:tw,institutions.country_code:hk,institutions.country_code:us,publication_year:2022", + "https://api.openalex.org/works?filter=institutions.country_code:tw+hk+us,publication_year:2022", + ] assert ( - url - == Works() + Works() .filter( institutions={"country_code": ["tw", "hk", "us"]}, publication_year=2022 ) .url + in urls ) assert ( - url - == Works() + Works() .filter(institutions={"country_code": "tw"}) .filter(institutions={"country_code": "hk"}) .filter(institutions={"country_code": "us"}) .filter(publication_year=2022) .url + in urls ) assert ( - url - == Works() + Works() .filter(institutions={"country_code": ["tw", "hk"]}) .filter(institutions={"country_code": "us"}) .filter(publication_year=2022) .url + in urls + ) + + +def test_or_operator(): + assert ( + Works() + .filter_or( + institutions={"country_code": ["tw", "hk", "us"]}, publication_year=2022 + ) + .url + == "https://api.openalex.org/works?filter=institutions.country_code:tw|hk|us,publication_year:2022" + ) + + +def test_not_operator(): + assert ( + Works() + .filter_not(institutions={"country_code": "us"}) + .filter(publication_year=2022) + .url + == "https://api.openalex.org/works?filter=institutions.country_code:!us,publication_year:2022" + ) + + +def test_not_operator_list(): + assert ( + Works() + .filter_not(institutions={"country_code": ["tw", "hk", "us"]}) + .filter(publication_year=2022) + .url + == "https://api.openalex.org/works?filter=institutions.country_code:!tw+!hk+!us,publication_year:2022" + ) + + +@pytest.mark.skip("Wait for feedback on issue by OpenAlex") +def test_combined_operators(): + # works: + # https://api.openalex.org/works?filter=publication_year:>2022,publication_year:!2023 + + # doesn't work + # https://api.openalex.org/works?filter=publication_year:>2022+!2023 + + assert ( + Works().filter_gt(publication_year=2022).filter_not(publication_year=2023).url + == "https://api.openalex.org/works?filter=publication_year:>2022+!2023" ) @@ -359,11 +411,10 @@ def test_filter_urlencoding(): ) -@pytest.mark.skip("This test is not working due to inconsistencies in the API.") def test_urlencoding_list(): assert ( Works() - .filter( + .filter_or( doi=[ "https://doi.org/10.1207/s15327809jls0703&4_2", "https://doi.org/10.1001/jama.264.8.944b",