Skip to content

Commit

Permalink
Simplify pydantic model validation
Browse files Browse the repository at this point in the history
  • Loading branch information
sonicaj committed Jan 15, 2025
1 parent 3137a84 commit 8475085
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
8 changes: 8 additions & 0 deletions src/middlewared/middlewared/api/base/validators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import time
import re

from pydantic import HttpUrl


def match_validator(pattern: re.Pattern, explanation: str | None = None):
def validator(value: str):
Expand All @@ -21,3 +23,9 @@ def time_validator(value: str):
except TypeError:
raise ValueError('Time should be in 24 hour format like "18:00"')
return value


def https_only_check(url: HttpUrl) -> str:
if url.scheme != 'https':
raise ValueError('URL scheme must be https')
return str(url)
30 changes: 13 additions & 17 deletions src/middlewared/middlewared/api/v25_04_0/tn_connect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pydantic import IPvAnyAddress, model_validator
from typing import Annotated

from pydantic import HttpUrl, IPvAnyAddress, AfterValidator

from middlewared.api.base import BaseModel, ForUpdateMetaclass, NonEmptyString, single_argument_args
from middlewared.api.base.validators import https_only_check
from middlewared.utils.lang import undefined


Expand All @@ -10,6 +13,9 @@
]


HttpsURL = Annotated[HttpUrl, AfterValidator(https_only_check)]


class TNCEntry(BaseModel):
id: int
enabled: bool
Expand All @@ -18,28 +24,18 @@ class TNCEntry(BaseModel):
status: NonEmptyString
status_reason: NonEmptyString
certificate: int | None
account_service_base_url: NonEmptyString
leca_service_base_url: NonEmptyString
tnc_base_url: NonEmptyString
account_service_base_url: HttpsURL
leca_service_base_url: HttpsURL
tnc_base_url: HttpsURL


@single_argument_args('tn_connect_update')
class TNCUpdateArgs(BaseModel, metaclass=ForUpdateMetaclass):
enabled: bool
ips: list[IPvAnyAddress]
account_service_base_url: NonEmptyString
leca_service_base_url: NonEmptyString
tnc_base_url: NonEmptyString

@model_validator(mode='after')
def validate_attrs(self):
for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url'):
value = getattr(self, k)
if value != undefined and not value.startswith('https://'):
raise ValueError(f'{k} must start with https://')
if value != undefined and not value.endswith('/'):
setattr(self, k, value + '/')
return self
account_service_base_url: HttpsURL
leca_service_base_url: HttpsURL
tnc_base_url: HttpsURL


class TNCUpdateResult(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def do_update(self, data):
db_payload = {
'enabled': data['enabled'],
'ips': data['ips'],
} | {k: data[k] for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url') if data.get(k)}
} | {k: data[k] for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url')}
if config['enabled'] is False and data['enabled'] is True:
# Finalization registration is triggered when claim token is generated
# We make sure there is no pending claim token
Expand Down

0 comments on commit 8475085

Please sign in to comment.