Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NAS-133728 / 25.10 / Add NotRequired default value to our API #15462

Merged
merged 15 commits into from
Jan 30, 2025
58 changes: 53 additions & 5 deletions src/middlewared/middlewared/api/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,58 @@


__all__ = ["BaseModel", "ForUpdateMetaclass", "query_result", "query_result_item", "added_event_model",
"changed_event_model", "removed_event_model", "single_argument_args", "single_argument_result"]
"changed_event_model", "removed_event_model", "single_argument_args", "single_argument_result",
"NotRequired"]


class BaseModel(PydanticBaseModel):
class _NotRequiredMixin(PydanticBaseModel):
@model_serializer(mode="wrap")
def serialize_basemodel(self, serializer):
obj = serializer(self)
if isinstance(obj, dict):
return {
k: v
for k, v in obj.items()
if v is not undefined
}
return obj


NotRequired = undefined
"""Use as the default value for fields that may be excluded from the model."""


class _BaseModelMetaclass(ModelMetaclass):
"""Any BaseModel subclass that uses the NotRequired default value on any of its fields receives the appropriate
model serializer."""
# FIXME: In the future we want to set defaults on all fields that are not required. Remove this metaclass,
# `_NotRequiredMixin`, and `NotRequired` at that time.

def __new__(mcls, name, bases, namespaces, **kwargs):
skip_patching = kwargs.pop("__BaseModelMetaclass_skip_patching", False)

cls = super().__new__(mcls, name, bases, namespaces, **kwargs)

if skip_patching or name == "BaseModel":
return cls

for field in cls.model_fields.values():
if getattr(field, "default", None) is undefined:
return create_model(
cls.__name__,
__base__=(cls, _NotRequiredMixin),
__module__=cls.__module__,
__cls_kwargs__={"__BaseModelMetaclass_skip_patching": True},
**{
k: (v.annotation, v)
for k, v in cls.model_fields.items()
}
)
else:
return cls


class BaseModel(PydanticBaseModel, metaclass=_BaseModelMetaclass):
model_config = ConfigDict(
extra="forbid",
strict=True,
Expand Down Expand Up @@ -51,7 +99,7 @@ def model_dump(
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | typing.Literal['none', 'warn', 'error'] = True,
serialize_as_any: bool = False,
serialize_as_any: bool = True, # so that nested models set to `NotRequired` do not serialize
) -> dict[str, typing.Any]:
return self.__pydantic_serializer__.to_python(
self,
Expand Down Expand Up @@ -102,7 +150,7 @@ def to_previous(cls, value):
return value


class ForUpdateMetaclass(ModelMetaclass):
class ForUpdateMetaclass(_BaseModelMetaclass):
"""
Using this metaclass on a model will change all of its fields default values to `undefined`.
Such a model might be instantiated with any subset of its fields, which can be useful to validate request bodies
Expand All @@ -112,7 +160,7 @@ class ForUpdateMetaclass(ModelMetaclass):
def __new__(mcls, name, bases, namespaces, **kwargs):
skip_patching = kwargs.pop("__ForUpdateMetaclass_skip_patching", False)

cls = super().__new__(mcls, name, bases, namespaces, **kwargs)
cls = ModelMetaclass.__new__(mcls, name, bases, namespaces, **kwargs)

if skip_patching:
return cls
Expand Down
114 changes: 110 additions & 4 deletions src/middlewared/middlewared/pytest/unit/api/base/test_excluded.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pydantic import Field, Secret
import pytest

from middlewared.api.base import BaseModel, Excluded, excluded_field
from middlewared.api.base.handler.accept import accept_params
from middlewared.api.base import BaseModel, Excluded, excluded_field, ForUpdateMetaclass, NotRequired
from middlewared.api.base.handler.accept import accept_params, validate_model
from middlewared.service_exception import ValidationErrors


Expand All @@ -18,9 +19,114 @@ class CreateArgs(BaseModel):
data: CreateObject


def check_serialization(test_model, test_cases):
for args, dump in test_cases:
result = validate_model(test_model, args)
assert result == dump, (args, dump, result)


def test_excluded_field():
with pytest.raises(ValidationErrors) as ve:
accept_params(CreateObject, [{"id": 1, "name": "Ivan"}])
accept_params(CreateArgs, [{"id": 1, "name": "Ivan"}])

assert ve.value.errors[0].attribute == "id"
assert ve.value.errors[0].attribute == "data.id"
assert ve.value.errors[0].errmsg == "Extra inputs are not permitted"


def test_not_required():
class NestedModel(BaseModel):
a: int = NotRequired

class NotRequiredTestModel(BaseModel):
b: int
c: int = 3
d: int = NotRequired
e: NestedModel
f: NestedModel = Field(default_factory=NestedModel)
# default_factory must be used here
g: NestedModel = NotRequired
h: list[NestedModel] = NotRequired
i_: int = Field(alias="i", default=NotRequired)
j: Secret[int] = NotRequired

test_cases = (
(
{"b": 2, "e": {}},
{"b": 2, "c": 3, "e": {}, "f": {}}
),
(
{"b": 2, "e": {"a": 1}},
{"b": 2, "c": 3, "e": {"a": 1}, "f": {}}
),
(
{"b": 2, "c": -3, "e": {}},
{"b": 2, "c": -3, "e": {}, "f": {}}
),
(
{"b": 2, "d": 4, "e": {}},
{"b": 2, "c": 3, "d": 4, "e": {}, "f": {}}
),
(
{"b": 2, "e": {}, "f": {}},
{"b": 2, "c": 3, "e": {}, "f": {}}
),
(
{"b": 2, "e": {}, "f": {"a": 1}},
{"b": 2, "c": 3, "e": {}, "f": {"a": 1}}
),
(
{"b": 2, "e": {}, "g": {}},
{"b": 2, "c": 3, "e": {}, "f": {}, "g": {}}
),
(
{"b": 2, "e": {}, "g": {"a": 1}},
{"b": 2, "c": 3, "e": {}, "f": {}, "g": {"a": 1}}
),
(
{"b": 2, "e": {}, "h": []},
{"b": 2, "c": 3, "e": {}, "f": {}, "h": []}
),
(
{"b": 2, "e": {}, "h": [{}]},
{"b": 2, "c": 3, "e": {}, "f": {}, "h": [{}]}
),
(
{"b": 2, "e": {}, "h": [{"a": 1}]},
{"b": 2, "c": 3, "e": {}, "f": {}, "h": [{"a": 1}]}
),
(
{"b": 2, "e": {}, "h": [{"a": 1}, {}]},
{"b": 2, "c": 3, "e": {}, "f": {}, "h": [{"a": 1}, {}]}
),
(
{"b": 2, "e": {}, "i": 4},
{"b": 2, "c": 3, "e": {}, "f": {}, "i": 4}
),
(
{"b": 2, "e": {}, "j": 4},
{"b": 2, "c": 3, "e": {}, "f": {}, "j": 4}
),
)
check_serialization(NotRequiredTestModel, test_cases)


def test_update_metaclass():
class NestedModel(BaseModel):
a: int

class UpdateModel(BaseModel, metaclass=ForUpdateMetaclass):
b: int
c: NestedModel

test_cases = (
(
{}, {}
),
(
{"b": 2}, {"b": 2}
),
(
{"c": {"a": 1}}, {"c": {"a": 1}}
),
)
check_serialization(UpdateModel, test_cases)
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@

def test_dump_by_alias():
class AliasModel(BaseModel):
field1_: int = Field(..., alias='field1')
field1_: int = Field(alias='field1')
field2: str
field3_: bool = Field(alias='field3', default=False)

class AliasModelResult(BaseModel):
result: AliasModel

result = {'field1': 1, 'field2': 'two'}
dump = serialize_result(AliasModelResult, result, False)

assert dump == {'field1': 1, 'field2': 'two', 'field3': False}
result = serialize_result(AliasModelResult, {'field1': 1, 'field2': 'two'}, True)
assert result == {'field1': 1, 'field2': 'two', 'field3': False}