Skip to content

Commit

Permalink
fixed filter by null condition
Browse files Browse the repository at this point in the history
  • Loading branch information
CosmoV committed Dec 8, 2023
1 parent fdd12b7 commit 8c93bfa
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 41 deletions.
32 changes: 23 additions & 9 deletions fastapi_jsonapi/data_layers/filtering/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def __init__(self, model: Type[TypeModel], filter_: dict, schema: Type[TypeSchem
self.filter_ = filter_
self.schema = schema

def _check_can_to_be_none(self, fields: list[ModelField]) -> bool:
"""
Return True if None is possible value for target field
"""
return not any(field_item.required for field_item in fields)

def create_filter(self, schema_field: ModelField, model_column, operator, value):
"""
Create sqlalchemy filter
Expand Down Expand Up @@ -78,17 +84,25 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
types = [i.type_ for i in fields]
clear_value = None
errors: List[str] = []
for i_type in types:
try:
if isinstance(value, list): # noqa: SIM108
clear_value = [i_type(item) for item in value]
else:
clear_value = i_type(value)
except (TypeError, ValueError) as ex:
errors.append(str(ex))

can_to_be_none = self._check_can_to_be_none(fields)

if value is None and can_to_be_none:
clear_value = None
else:
for i_type in types:
try:
if isinstance(value, list): # noqa: SIM108
clear_value = [i_type(item) for item in value]
else:
clear_value = i_type(value)
except (TypeError, ValueError) as ex:
errors.append(str(ex))

# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
if clear_value is None and not any(not i_f.required for i_f in fields):
if clear_value is None and not can_to_be_none:
raise InvalidType(detail=", ".join(errors))

return getattr(model_column, self.operator)(clear_value)

def resolve(self) -> FilterAndJoins: # noqa: PLR0911
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
user_2_comment_for_one_u1_post,
user_2_posts,
user_3,
user_4,
workplace_1,
workplace_2,
)
Expand Down
13 changes: 0 additions & 13 deletions tests/fixtures/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,6 @@ async def user_3(async_session: AsyncSession):
await async_session.commit()


@async_fixture()
async def user_4(async_session: AsyncSession):
user = build_user(
email=None
)
async_session.add(user)
await async_session.commit()
await async_session.refresh(user)
yield user
await async_session.delete(user)
await async_session.commit()


async def build_user_bio(async_session: AsyncSession, user: User, **fields):
bio = UserBio(user=user, **fields)
async_session.add(bio)
Expand Down
42 changes: 24 additions & 18 deletions tests/test_api/test_api_sqla_with_includes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,33 +1751,39 @@ async def test_field_filters_with_values_from_different_models(
"meta": {"count": 0, "totalPages": 1},
}

@mark.parametrize("filter_dict, expected_email_is_null", [
param([{"name": "email", "op": "is_", "val": None}], True),
param([{"name": "email", "op": "isnot", "val": None}], False)
])
@mark.parametrize(
("filter_dict", "expected_email_is_null"),
[
param([{"name": "email", "op": "is_", "val": None}], True),
param([{"name": "email", "op": "isnot", "val": None}], False),
],
)
async def test_filter_by_null(
self,
app: FastAPI,
client: AsyncClient,
user_1: User,
user_4: User,
filter_dict,
expected_email_is_null
self,
app: FastAPI,
async_session: AsyncSession,
client: AsyncClient,
user_1: User,
user_2: User,
filter_dict: dict,
expected_email_is_null: bool,
):
assert user_1.email is not None
assert user_4.email is None
user_2.email = None
await async_session.commit()

target_user = user_2 if expected_email_is_null else user_1

url = app.url_path_for("get_user_list")
params = {"filter": dumps(filter_dict)}

response = await client.get(url, params=params)
assert response.status_code == 200, response.text

data = response.json()
assert response.status_code == status.HTTP_200_OK, response.text

assert len(data['data']) == 1
assert (data['data'][0]['attributes']['email'] is None) == expected_email_is_null
response_json = response.json()

assert len(data := response_json["data"]) == 1
assert data[0]["id"] == str(target_user.id)
assert data[0]["attributes"]["email"] == target_user.email

async def test_composite_filter_by_one_field(
self,
Expand Down

0 comments on commit 8c93bfa

Please sign in to comment.