
Commit

rebase
CosmoV committed Jan 16, 2024
1 parent 5e25556 commit 79fb3b0
Showing 2 changed files with 198 additions and 46 deletions.
196 changes: 180 additions & 16 deletions fastapi_jsonapi/data_layers/filtering/sqlalchemy.py
@@ -1,16 +1,21 @@
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
import inspect
import logging
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
)

from pydantic import BaseModel
from pydantic import BaseConfig, BaseModel
from pydantic.fields import ModelField
from pydantic.validators import _VALIDATORS, find_validators
from sqlalchemy import and_, not_, or_
from sqlalchemy.orm import aliased
from sqlalchemy.orm.attributes import InstrumentedAttribute
@@ -19,14 +24,22 @@

from fastapi_jsonapi.data_typing import TypeModel, TypeSchema
from fastapi_jsonapi.exceptions import InvalidFilters, InvalidType
from fastapi_jsonapi.exceptions.json_api import HTTPException
from fastapi_jsonapi.schema import get_model_field, get_relationships

log = logging.getLogger(__name__)

RELATIONSHIP_SPLITTER = "."

# Mapping of types to the pydantic validators used to cast raw values to instances of the target type
REGISTERED_PYDANTIC_TYPES: Dict[Type, List[Callable]] = dict(_VALIDATORS)

cast_failed = object()

RelationshipPath = str


class RelationshipInfo(BaseModel):
class RelationshipFilteringInfo(BaseModel):
target_schema: Type[TypeSchema]
model: Type[TypeModel]
aliased_model: AliasedClass
@@ -36,6 +49,129 @@ class Config:
arbitrary_types_allowed = True


def check_can_be_none(fields: list[ModelField]) -> bool:
"""
Return True if None is a possible value for the target field
"""
return any(field_item.allow_none for field_item in fields)


def separate_types(types: List[Type]) -> Tuple[List[Type], List[Type]]:
"""
Separate the given types into two kinds.
The first kind are types for which pydantic already defines validators
(str, int, datetime and some other built-in types). The second kind are
all other types, for which the `arbitrary_types_allowed` config is applied
when defining the pydantic model.
"""
pydantic_types = [
# skip format
type_
for type_ in types
if type_ in REGISTERED_PYDANTIC_TYPES
]
userspace_types = [
# skip format
type_
for type_ in types
if type_ not in REGISTERED_PYDANTIC_TYPES
]
return pydantic_types, userspace_types
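
For illustration only, a hedged sketch of how the split behaves; the Money class below is a hypothetical arbitrary type, not part of this change:

    class Money:  # no validator registered with pydantic
        def __init__(self, raw):
            self.raw = raw

    pydantic_types, userspace_types = separate_types([str, int, Money])
    # pydantic_types == [str, int]  -> cast via pydantic's registered validators
    # userspace_types == [Money]    -> cast by calling the type directly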


def validator_requires_model_field(validator: Callable) -> bool:
"""
Check if the validator accepts the `field` parameter
:param validator:
:return:
"""
signature = inspect.signature(validator)
parameters = signature.parameters

if "field" not in parameters:
return False

field_param = parameters["field"]
field_type = field_param.annotation

return field_type == "ModelField" or field_type is ModelField
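
A minimal sketch of what this check detects; both validators below are hypothetical and only illustrate the two signature shapes (ModelField is the pydantic.fields class imported above):

    def plain_validator(v):
        return v

    def field_aware_validator(v, field: ModelField):
        return v

    assert validator_requires_model_field(plain_validator) is False
    assert validator_requires_model_field(field_aware_validator) is True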


def cast_value_with_pydantic(
types: List[Type],
value: Any,
schema_field: ModelField,
) -> Tuple[Optional[Any], List[str]]:
result_value, errors = None, []

for type_to_cast in types:
for validator in find_validators(type_to_cast, BaseConfig):
args = [value]
# TODO: some other way to get all the validator's dependencies?
if validator_requires_model_field(validator):
args.append(schema_field)
try:
result_value = validator(*args)
except Exception as ex:
errors.append(str(ex))
else:
return result_value, errors

return None, errors
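
A hedged usage sketch, assuming a schema with an int field named `age` (some_schema is hypothetical):

    casted, errors = cast_value_with_pydantic(
        types=[int],
        value="10",
        schema_field=some_schema.__fields__["age"],
    )
    # casted == 10 on success; on failure casted is None and `errors`
    # collects the messages raised by the validators.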


def cast_iterable_with_pydantic(
types: List[Type],
values: List,
schema_field: ModelField,
) -> Tuple[List, List[str]]:
type_cast_failed = False
failed_values = []

result_values: List[Any] = []
errors: List[str] = []

for value in values:
casted_value, cast_errors = cast_value_with_pydantic(
types,
value,
schema_field,
)
errors.extend(cast_errors)

if casted_value is None:
type_cast_failed = True
failed_values.append(value)

continue

result_values.append(casted_value)

if type_cast_failed:
msg = f"Can't parse items {failed_values} of value {values}"
raise InvalidFilters(msg, pointer=schema_field.name)

return result_values, errors
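
The list variant applies the same casting element-wise, e.g. for `in`-style filters; a sketch under the same hypothetical schema assumption:

    values, errors = cast_iterable_with_pydantic(
        types=[int],
        values=["1", "2", "3"],
        schema_field=some_schema.__fields__["id"],
    )
    # values == [1, 2, 3]; an item that cannot be cast (e.g. "abc")
    # makes the function raise InvalidFilters listing the failed items.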


def cast_value_with_scheme(field_types: List[Type], value: Any) -> Tuple[Any, List[str]]:
errors: List[str] = []
casted_value = cast_failed

for field_type in field_types:
try:
if isinstance(value, list): # noqa: SIM108
casted_value = [field_type(item) for item in value]
else:
casted_value = field_type(value)
except (TypeError, ValueError) as ex:
errors.append(str(ex))

return casted_value, errors
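
This is the fallback for arbitrary user types: the raw value (or each list item) is passed to the type's constructor. A sketch reusing the hypothetical Money type from above:

    casted, errors = cast_value_with_scheme([Money], "100.50")
    # casted is a Money instance on success; if every constructor call raises
    # TypeError/ValueError, casted stays equal to the cast_failed sentinel.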


def build_filter_expression(
schema_field: ModelField,
model_column: InstrumentedAttribute,
@@ -61,26 +197,51 @@ def build_filter_expression(
if schema_field.sub_fields:
fields = list(schema_field.sub_fields)

can_be_none = check_can_be_none(fields)

if value is None:
if can_be_none:
return getattr(model_column, operator)(value)

raise InvalidFilters(detail=f"The field `{schema_field.name}` can't be null")

types = [i.type_ for i in fields]
casted_value = None
errors: List[str] = []

for cast_type in [field.type_ for field in fields]:
try:
casted_value = [cast_type(item) for item in value] if isinstance(value, list) else cast_type(value)
except (TypeError, ValueError) as ex:
errors.append(str(ex))
pydantic_types, userspace_types = separate_types(types)

if pydantic_types:
func = cast_value_with_pydantic
if isinstance(value, list):
func = cast_iterable_with_pydantic
casted_value, errors = func(pydantic_types, value, schema_field)

all_fields_required = all(field.required for field in fields)
if casted_value is None and userspace_types:
log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.")

if casted_value is None and all_fields_required:
raise InvalidType(detail=", ".join(errors))
casted_value, errors = cast_value_with_scheme(types, value)

if casted_value is cast_failed:
raise InvalidType(
detail=f"Can't cast filter value `{value}` to arbitrary type.",
errors=[HTTPException(status_code=InvalidType.status_code, detail=str(err)) for err in errors],
)

# If the value is None while the field is required (None is not among the annotated types), raise an error
if casted_value is None and not can_be_none:
raise InvalidType(
detail=", ".join(errors),
pointer=schema_field.name,
)

return getattr(model_column, operator)(casted_value)
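
A hedged end-to-end sketch of how the returned clause is meant to be used; UserSchema and the User model are assumptions for illustration:

    expression = build_filter_expression(
        schema_field=UserSchema.__fields__["age"],
        model_column=User.age,
        operator="__ge__",
        value="18",  # raw querystring value, cast to int by the field's validators
    )
    query = select(User).where(expression)  # select() from sqlalchemy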


def is_terminal_node(filter_item: dict) -> bool:
"""
If node shape is:
{
"name: ...,
"op: ...,
@@ -166,7 +327,7 @@ def gather_relationships_info(
relationship_path: List[str],
collected_info: dict,
target_relationship_idx: int = 0,
) -> dict[RelationshipPath, RelationshipInfo]:
) -> dict[RelationshipPath, RelationshipFilteringInfo]:
is_last_relationship = target_relationship_idx == len(relationship_path) - 1
target_relationship_path = RELATIONSHIP_SPLITTER.join(
relationship_path[: target_relationship_idx + 1],
@@ -184,7 +345,7 @@ def gather_relationships_info(
schema,
target_relationship_name,
)
collected_info[target_relationship_path] = RelationshipInfo(
collected_info[target_relationship_path] = RelationshipFilteringInfo(
target_schema=target_schema,
model=target_model,
aliased_model=aliased(target_model),
@@ -207,7 +368,7 @@ def gather_relationships(
entrypoint_model: Type[TypeModel],
schema: Type[TypeSchema],
relationship_paths: Set[str],
) -> dict[RelationshipPath, RelationshipInfo]:
) -> dict[RelationshipPath, RelationshipFilteringInfo]:
collected_info = {}
for relationship_path in sorted(relationship_paths):
gather_relationships_info(
@@ -238,19 +399,22 @@ def build_filter_expressions(
filter_item: Union[dict, list],
target_schema: Type[TypeSchema],
target_model: Type[TypeModel],
relationships_info: dict[RelationshipPath, RelationshipInfo],
relationships_info: dict[RelationshipPath, RelationshipFilteringInfo],
) -> Union[BinaryExpression, BooleanClauseList]:
"""
Return a sqlalchemy expression.
Builds a sqlalchemy expression which can be used
in a where condition: query(Model).where(build_filter_expressions(...))
"""
if is_terminal_node(filter_item):
name = filter_item["name"]
target_schema = target_schema

if is_relationship_filter(name):
*relationship_path, field_name = name.split(RELATIONSHIP_SPLITTER)
relationship_info: RelationshipInfo = relationships_info[RELATIONSHIP_SPLITTER.join(relationship_path)]
relationship_info: RelationshipFilteringInfo = relationships_info[
RELATIONSHIP_SPLITTER.join(relationship_path)
]
model_column = get_model_column(
model=relationship_info.aliased_model,
schema=relationship_info.target_schema,
48 changes: 18 additions & 30 deletions tests/test_data_layers/test_filtering/test_sqlalchemy.py
@@ -1,47 +1,41 @@
from typing import Any
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock

from fastapi import status
from pydantic import BaseModel
from pytest import raises # noqa PT013

from fastapi_jsonapi.data_layers.filtering.sqlalchemy import Node
from fastapi_jsonapi.exceptions.json_api import InvalidType
from fastapi_jsonapi.data_layers.filtering.sqlalchemy import (
build_filter_expression,
)
from fastapi_jsonapi.exceptions import InvalidType


class TestNode:
class TestFilteringFuncs:
def test_user_type_cast_success(self):
class UserType:
def __init__(self, *args, **kwargs):
self.value = "success"
pass

class ModelSchema(BaseModel):
user_type: UserType
value: UserType

class Config:
arbitrary_types_allowed = True

node = Node(
model=Mock(),
filter_={
"name": "user_type",
"op": "eq",
"val": Any,
},
schema=ModelSchema,
)

model_column_mock = Mock()
model_column_mock.eq = lambda clear_value: clear_value
model_column_mock = MagicMock()

clear_value = node.create_filter(
schema_field=ModelSchema.__fields__["user_type"],
build_filter_expression(
schema_field=ModelSchema.__fields__["value"],
model_column=model_column_mock,
operator=Mock(),
operator="__eq__",
value=Any,
)
assert isinstance(clear_value, UserType)
assert clear_value.value == "success"

model_column_mock.__eq__.assert_called_once()

call_arg = model_column_mock.__eq__.call_args[0][0]
assert isinstance(call_arg, UserType)

def test_user_type_cast_fail(self):
class UserType:
@@ -55,14 +49,8 @@ class ModelSchema(BaseModel):
class Config:
arbitrary_types_allowed = True

node = Node(
model=Mock(),
filter_=Mock(),
schema=ModelSchema,
)

with raises(InvalidType) as exc_info:
node.create_filter(
build_filter_expression(
schema_field=ModelSchema.__fields__["user_type"],
model_column=Mock(),
operator=Mock(),
