feat: implement RDFProxy-compliance checkers for models #222

Merged: 6 commits, Feb 19, 2025
19 changes: 19 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,19 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.6
hooks:
# Run the linter.
- id: ruff
# Run the formatter.
- id: ruff-format
3 changes: 2 additions & 1 deletion rdfproxy/adapter.py
@@ -9,6 +9,7 @@
from rdfproxy.mapper import _ModelBindingsMapper
from rdfproxy.sparqlwrapper import SPARQLWrapper
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils.checkers.model_checker import check_model
from rdfproxy.utils.checkers.query_checker import check_query
from rdfproxy.utils.models import Page, QueryParameters

@@ -41,7 +42,7 @@ def __init__(
) -> None:
self._target = target
self._query = check_query(query)
self._model = model
self._model = check_model(model)

self.sparqlwrapper = SPARQLWrapper(self._target)

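For illustration, a minimal sketch of adapter construction after this change; it assumes SPARQLModelAdapter is the public adapter class exposed by the rdfproxy package, and the endpoint URL and query are placeholders:

# Illustrative sketch only: endpoint, query and model are placeholders.
from pydantic import BaseModel
from rdfproxy import SPARQLModelAdapter
from rdfproxy.utils._types import ConfigDict


class Work(BaseModel):
    model_config = ConfigDict(group_by="title")
    title: str
    authors: list[str]


adapter = SPARQLModelAdapter(
    target="https://example.org/sparql",  # placeholder endpoint
    query="SELECT ?title ?authors WHERE { ?work ?p ?authors }",  # placeholder query
    model=Work,  # check_model(Work) now runs here; an invalid model raises at construction time
)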
7 changes: 2 additions & 5 deletions rdfproxy/mapper.py
@@ -8,11 +8,8 @@
from pandas.api.typing import DataFrameGroupBy
from pydantic import BaseModel
from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance
from rdfproxy.utils.mapper_utils import (
_is_list_basemodel_type,
_is_list_type,
get_model_bool_predicate,
)
from rdfproxy.utils._typing import _is_list_basemodel_type, _is_list_type
from rdfproxy.utils.mapper_utils import get_model_bool_predicate
from rdfproxy.utils.utils import CurryModel, FieldsBindingsMap


12 changes: 12 additions & 0 deletions rdfproxy/utils/_exceptions.py
@@ -26,3 +26,15 @@ class QueryParseException(Exception):
parseQuery raises a pyparsing.exceptions.ParseException,
which would require introducing pyparsing as a dependency just for testing.
"""


class RDFProxyModelValidationException(Exception):
"""Exception for indicating that a model is invalid according to RDFProxy semantics"""


class RDFProxyGroupByException(RDFProxyModelValidationException):
"""Exception for indicating invalid group_by definitions."""


class RDFProxyModelBoolException(RDFProxyModelValidationException):
"""Exception for indicating invalid model_bool definitions."""
36 changes: 36 additions & 0 deletions rdfproxy/utils/_typing.py
@@ -0,0 +1,36 @@
"""RDFProxy typing utils."""

import types
from typing import Any, TypeGuard, get_args, get_origin

from pydantic import BaseModel


def _is_type(obj: type | None, _type: type) -> bool:
"""Check if an obj is type _type or a GenericAlias with origin _type."""
return (obj is _type) or (get_origin(obj) is _type)


def _is_list_type(obj: type | None) -> bool:
"""Check if obj is a list type."""
return _is_type(obj, list)


def _is_list_basemodel_type(obj: type | None) -> bool:
"""Check if a type is list[pydantic.BaseModel]."""
return (get_origin(obj) is list) and all(
issubclass(cls, BaseModel) for cls in get_args(obj)
)


def _is_pydantic_model_class(obj: Any) -> TypeGuard[type[BaseModel]]:
"""Predicate for checking if an object is a Pydantic model class."""
return isinstance(obj, type) and issubclass(obj, BaseModel)


def _is_union_pydantic_model_type(obj: Any) -> bool:
"""Predicate for checking if a type is union type of a Pydantic model."""
is_union_type: bool = get_origin(obj) is types.UnionType
has_any_model: bool = any(_is_pydantic_model_class(obj) for obj in get_args(obj))

return is_union_type and has_any_model
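For illustration, a short sketch of what these predicates return, using a throwaway Dummy model:

# Illustration of the typing predicates (Dummy is a throwaway model).
from pydantic import BaseModel
from rdfproxy.utils._typing import (
    _is_list_basemodel_type,
    _is_list_type,
    _is_pydantic_model_class,
    _is_union_pydantic_model_type,
)


class Dummy(BaseModel):
    x: int


assert _is_list_type(list[int])
assert _is_list_basemodel_type(list[Dummy])
assert not _is_list_basemodel_type(list[int])
assert _is_pydantic_model_class(Dummy)
assert _is_union_pydantic_model_type(Dummy | None)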
71 changes: 71 additions & 0 deletions rdfproxy/utils/checkers/model_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Functionality for performing RDFProxy-compliance checks on Pydantic models."""

from rdfproxy.utils._exceptions import RDFProxyGroupByException
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils._typing import _is_list_type
from rdfproxy.utils.model_utils import model_traverse
from rdfproxy.utils.utils import compose_left


def _check_group_by_config(model: type[_TModelInstance]) -> type[_TModelInstance]:
"""Model checker for group_by config settings and grouping model semantics."""
model_group_by_value: str | None = model.model_config.get("group_by")
model_has_list_field: bool = any(
_is_list_type(value.annotation) for value in model.model_fields.values()
)

match model_group_by_value, model_has_list_field:
case None, False:
return model

case None, True:
raise RDFProxyGroupByException(
f"Model '{model.__name__}' has a list-annotated field "
"but does not specify 'group_by' in its model_config."
)

case str(), False:
raise RDFProxyGroupByException(
f"Model '{model.__name__}' does not specify "
"a grouping target (i.e. a list-annotated field)."
)

case str(), True:
applicable_keys: list[str] = [
k
for k, v in model.model_fields.items()
if not _is_list_type(v.annotation)
]

if model_group_by_value in applicable_keys:
return model

applicable_fields_message: str = (
"No applicable fields."
if not applicable_keys
else f"Applicable grouping field(s): {', '.join(applicable_keys)}"
)

raise RDFProxyGroupByException(
f"Requested grouping key '{model_group_by_value}' does not denote "
f"an applicable grouping field. {applicable_fields_message}"
)

case _: # pragma: no cover
raise AssertionError("This should never happen.")


def _check_model_bool_config(model: type[_TModelInstance]) -> type[_TModelInstance]:
"""Model checker for model_bool config settings.

This is a stub for now; the model_bool feature is in flux,
see issues #176 and #219.
"""
return model


def check_model(model: type[_TModelInstance]) -> type[_TModelInstance]:
composite = compose_left(_check_group_by_config, _check_model_bool_config)
_model, *_ = model_traverse(model, composite) # exhaust iterator for full traversal

return _model
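check_model applies both checkers to the root model and, via model_traverse, to every nested model, so a model embedding an invalid submodel is rejected as well. A sketch with throwaway models:

# Sketch: check_model returns a valid model and raises
# RDFProxyGroupByException for an invalid (possibly nested) model.
from pydantic import BaseModel
from rdfproxy.utils._exceptions import RDFProxyGroupByException
from rdfproxy.utils._types import ConfigDict
from rdfproxy.utils.checkers.model_checker import check_model


class Author(BaseModel):
    model_config = ConfigDict(group_by="name")
    name: str
    works: list[str]


class Catalogue(BaseModel):
    model_config = ConfigDict(group_by="label")
    label: str
    authors: list[Author]


assert check_model(Catalogue)  # root and nested models pass


class Broken(BaseModel):
    x: list[int]  # list field, but no group_by in model_config


class Wrapper(BaseModel):
    nested: Broken


try:
    check_model(Wrapper)
except RDFProxyGroupByException as e:
    print(e)  # the nested Broken model is rejected during traversal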
19 changes: 1 addition & 18 deletions rdfproxy/utils/mapper_utils.py
@@ -1,27 +1,10 @@
from collections.abc import Iterable
from typing import TypeGuard, get_args, get_origin
from typing import TypeGuard

from pydantic import BaseModel
from rdfproxy.utils._types import ModelBoolPredicate, _TModelBoolValue


def _is_type(obj: type | None, _type: type) -> bool:
"""Check if an obj is type _type or a GenericAlias with origin _type."""
return (obj is _type) or (get_origin(obj) is _type)


def _is_list_type(obj: type | None) -> bool:
"""Check if obj is a list type."""
return _is_type(obj, list)


def _is_list_basemodel_type(obj: type | None) -> bool:
"""Check if a type is list[pydantic.BaseModel]."""
return (get_origin(obj) is list) and all(
issubclass(cls, BaseModel) for cls in get_args(obj)
)


def default_model_bool_predicate(model: BaseModel) -> bool:
"""Default predicate for determining model truthiness.

48 changes: 48 additions & 0 deletions rdfproxy/utils/model_utils.py
@@ -0,0 +1,48 @@
"""RDFProxy model utils."""

from collections.abc import Callable, Iterator
from typing import TypeVar, get_args

from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils._typing import (
_is_list_basemodel_type,
_is_pydantic_model_class,
_is_union_pydantic_model_type,
)


T = TypeVar("T")


def model_traverse(
model: type[_TModelInstance],
f: Callable[[type[_TModelInstance]], T],
_self: bool = True,
) -> Iterator[T]:
"""Recursively traverse a model and apply a callable to all (sub)models.

If the _self flag is set to True, the callable will be applied to the root model as well.
Recursive calls intentionally do not pass on the _self flag.
"""
if _self:
yield f(model)

for _, field_info in model.model_fields.items():
if _is_list_basemodel_type(list_model := field_info.annotation):
nested_model, *_ = get_args(list_model)
yield from model_traverse(nested_model, f)

elif _is_pydantic_model_class(nested_model := field_info.annotation):
yield from model_traverse(nested_model, f)

elif _is_union_pydantic_model_type(union := field_info.annotation):
_model_filter = filter(_is_pydantic_model_class, get_args(union))
nested_model = next(_model_filter)

_multi_model_union = next(_model_filter, False)
assert not _multi_model_union, "Multiple model unions are not supported."

yield from model_traverse(nested_model, f)

else:
continue
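model_traverse yields the result of the callable for the root model and for every nested model reachable through plain, list, or union annotations. A small sketch collecting model names:

# Sketch: collect the class names of all (sub)models reachable from a root model.
from pydantic import BaseModel
from rdfproxy.utils.model_utils import model_traverse


class Leaf(BaseModel):
    value: int


class Branch(BaseModel):
    leaves: list[Leaf]


class Root(BaseModel):
    branch: Branch
    maybe_leaf: Leaf | None


names = list(model_traverse(Root, lambda m: m.__name__))
print(names)  # ['Root', 'Branch', 'Leaf', 'Leaf']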
114 changes: 114 additions & 0 deletions tests/unit/tests_checkers/test_model_checker.py
@@ -0,0 +1,114 @@
"""Unit test for rdfproxy.utils.checkers.model_checker."""

from pydantic import BaseModel, create_model
import pytest
from rdfproxy.utils._exceptions import RDFProxyGroupByException
from rdfproxy.utils._types import ConfigDict
from rdfproxy.utils.checkers.model_checker import check_model


class Invalid1(BaseModel):
"""group_by without corresponding scalar field and without list field."""

model_config = ConfigDict(group_by="x")


class Invalid2(BaseModel):
"""list field but no group_by config at all."""

x: list[int]


class Invalid3(BaseModel):
"""group_by references list field."""

model_config = ConfigDict(group_by="x")
x: list[int]


class Invalid4(BaseModel):
"""group_by references list field + additional fields."""

model_config = ConfigDict(group_by="x")
x: list[int]
y: int
z: int


class Invalid5(BaseModel):
"""legal group_by but no list field."""

model_config = ConfigDict(group_by="x")
x: int


class Invalid6(BaseModel):
"""group_by without corresponding scalar field."""

model_config = ConfigDict(group_by="x")
y: list[int]


class Valid1(BaseModel):
"""Simple valid group_by model."""

model_config = ConfigDict(group_by="x")

x: int
y: list[int]


class Valid2(BaseModel):
"""Simple valid group_by model with additional fields"""

model_config = ConfigDict(group_by="x")

x: int
y: int
z: list[int]


class Valid3(BaseModel):
"""Simple valid group_by model with aggregated nested model."""

model_config = ConfigDict(group_by="x")

x: int
y: int
z: list[Valid2]


invalid_group_by_models = [
Invalid1,
Invalid2,
Invalid3,
Invalid4,
Invalid5,
Invalid6,
]

valid_group_by_models = [Valid1, Valid2, Valid3]


@pytest.mark.parametrize("model", invalid_group_by_models)
def test_check_invalid_group_by_models(model):
with pytest.raises(RDFProxyGroupByException):
check_model(model)


@pytest.mark.parametrize("model", valid_group_by_models)
def test_check_valid_group_by_models(model):
assert check_model(model)


@pytest.mark.parametrize("model", invalid_group_by_models)
def test_check_nested_invalid_group_by_models(model):
nested_model = create_model("NestedModel", nested=(model, ...))
with pytest.raises(RDFProxyGroupByException):
check_model(nested_model)


@pytest.mark.parametrize("model", valid_group_by_models)
def test_check_nested_valid_group_by_models(model):
nested_model = create_model("NestedModel", nested=(model, ...))
assert check_model(nested_model)