From 522739815d8f3abde3afd5f42cf96a8109e332a1 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Tue, 5 Sep 2023 00:25:07 +0300 Subject: [PATCH] Add hook func for apply cond --- fastapi_filters/ext/sqlalchemy.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/fastapi_filters/ext/sqlalchemy.py b/fastapi_filters/ext/sqlalchemy.py index 040099b..abbbcc1 100644 --- a/fastapi_filters/ext/sqlalchemy.py +++ b/fastapi_filters/ext/sqlalchemy.py @@ -91,6 +91,7 @@ def _default_apply_filter(*_: Any) -> Any: ApplyFilterFunc: TypeAlias = Callable[[TSelectable, EntityNamespace, str, AbstractFilterOperator, Any], TSelectable] +AddFilterConditionFunc: TypeAlias = Callable[[TSelectable, str, Any], TSelectable] custom_apply_filter: ConfigVar[ApplyFilterFunc[Any]] = ConfigVar( "apply_filter", @@ -105,6 +106,7 @@ def _apply_filter( op: AbstractFilterOperator, val: Any, apply_filter: Optional[ApplyFilterFunc[TSelectable]] = None, + add_condition: Optional[AddFilterConditionFunc[TSelectable]] = None, ) -> TSelectable: custom_apply_filter_impl = custom_apply_filter.get() @@ -124,6 +126,12 @@ def _apply_filter( except KeyError: raise NotImplementedError(f"Operator {op} is not implemented") from None + if add_condition: + try: + return add_condition(stmt, field, cond) + except NotImplementedError: + pass + return stmt.where(cond) # type: ignore[arg-type] @@ -134,6 +142,7 @@ def apply_filters( remapping: Optional[Mapping[str, str]] = None, additional: Optional[EntityNamespace] = None, apply_filter: Optional[ApplyFilterFunc[TSelectable]] = None, + add_condition: Optional[AddFilterConditionFunc[TSelectable]] = None, ) -> TSelectable: if isinstance(filters, FilterSet): filters = filters.filter_values @@ -145,7 +154,7 @@ def apply_filters( field = remapping.get(field, field) for op, val in field_filters.items(): - stmt = _apply_filter(stmt, ns, field, op, val, apply_filter) + stmt = _apply_filter(stmt, ns, field, op, val, apply_filter, add_condition) return stmt @@ -184,6 +193,7 @@ def apply_filters_and_sorting( remapping: Optional[Mapping[str, str]] = None, additional: Optional[EntityNamespace] = None, apply_filter: Optional[ApplyFilterFunc[TSelectable]] = None, + add_condition: Optional[AddFilterConditionFunc[TSelectable]] = None, ) -> TSelectable: stmt = apply_filters(stmt, filters, remapping=remapping, additional=additional, apply_filter=apply_filter) stmt = apply_sorting(stmt, sorting, remapping=remapping, additional=additional)