Skip to content

Commit

Permalink
Simplify pydantic model validation
Browse files Browse the repository at this point in the history
  • Loading branch information
sonicaj committed Jan 15, 2025
1 parent 3137a84 commit 10e9348
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/middlewared/middlewared/api/base/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from .filesystem import * # noqa
from .iscsi import * # noqa
from .string import * # noqa
from .urls import * # noqa
from .user import * # noqa
8 changes: 8 additions & 0 deletions src/middlewared/middlewared/api/base/types/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Annotated

from pydantic import AfterValidator, HttpUrl

from middlewared.api.base.validators import https_only_check


HttpsOnlyURL = Annotated[HttpUrl, AfterValidator(https_only_check)]
8 changes: 8 additions & 0 deletions src/middlewared/middlewared/api/base/validators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import time
import re

from pydantic import HttpUrl


def match_validator(pattern: re.Pattern, explanation: str | None = None):
def validator(value: str):
Expand All @@ -21,3 +23,9 @@ def time_validator(value: str):
except TypeError:
raise ValueError('Time should be in 24 hour format like "18:00"')
return value


def https_only_check(url: HttpUrl) -> str:
if url.scheme != 'https':
raise ValueError('URL scheme must be https')
return str(url)
26 changes: 8 additions & 18 deletions src/middlewared/middlewared/api/v25_04_0/tn_connect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import IPvAnyAddress, model_validator
from pydantic import IPvAnyAddress

from middlewared.api.base import BaseModel, ForUpdateMetaclass, NonEmptyString, single_argument_args
from middlewared.utils.lang import undefined
from middlewared.api.base.types import HttpsOnlyURL


__all__ = [
Expand All @@ -18,28 +18,18 @@ class TNCEntry(BaseModel):
status: NonEmptyString
status_reason: NonEmptyString
certificate: int | None
account_service_base_url: NonEmptyString
leca_service_base_url: NonEmptyString
tnc_base_url: NonEmptyString
account_service_base_url: HttpsOnlyURL
leca_service_base_url: HttpsOnlyURL
tnc_base_url: HttpsOnlyURL


@single_argument_args('tn_connect_update')
class TNCUpdateArgs(BaseModel, metaclass=ForUpdateMetaclass):
enabled: bool
ips: list[IPvAnyAddress]
account_service_base_url: NonEmptyString
leca_service_base_url: NonEmptyString
tnc_base_url: NonEmptyString

@model_validator(mode='after')
def validate_attrs(self):
for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url'):
value = getattr(self, k)
if value != undefined and not value.startswith('https://'):
raise ValueError(f'{k} must start with https://')
if value != undefined and not value.endswith('/'):
setattr(self, k, value + '/')
return self
account_service_base_url: HttpsOnlyURL
leca_service_base_url: HttpsOnlyURL
tnc_base_url: HttpsOnlyURL


class TNCUpdateResult(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def do_update(self, data):
db_payload = {
'enabled': data['enabled'],
'ips': data['ips'],
} | {k: data[k] for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url') if data.get(k)}
} | {k: data[k] for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url')}
if config['enabled'] is False and data['enabled'] is True:
# Finalization registration is triggered when claim token is generated
# We make sure there is no pending claim token
Expand Down

0 comments on commit 10e9348

Please sign in to comment.