Skip to content

Commit

Permalink
Simplify pydantic model validation
Browse files Browse the repository at this point in the history
  • Loading branch information
sonicaj committed Jan 15, 2025
1 parent 86e9beb commit c347ddd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
8 changes: 8 additions & 0 deletions src/middlewared/middlewared/api/base/validators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import time
import re

from pydantic import HttpUrl


def match_validator(pattern: re.Pattern, explanation: str | None = None):
def validator(value: str):
Expand All @@ -21,3 +23,9 @@ def time_validator(value: str):
except TypeError:
raise ValueError('Time should be in 24 hour format like "18:00"')
return value


def https_only_check(url: HttpUrl) -> HttpUrl:
if url.scheme != 'https':
raise ValueError('URL scheme must be https')
return url
24 changes: 10 additions & 14 deletions src/middlewared/middlewared/api/v25_04_0/tn_connect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pydantic import IPvAnyAddress, model_validator
from typing import Annotated

from pydantic import HttpUrl, IPvAnyAddress, AfterValidator

from middlewared.api.base import BaseModel, ForUpdateMetaclass, NonEmptyString, single_argument_args
from middlewared.api.base.validators import https_only_check
from middlewared.utils.lang import undefined


Expand All @@ -10,6 +13,9 @@
]


HttpsURL = Annotated[HttpUrl, AfterValidator(https_only_check)]


class TNCEntry(BaseModel):
id: int
enabled: bool
Expand All @@ -27,19 +33,9 @@ class TNCEntry(BaseModel):
class TNCUpdateArgs(BaseModel, metaclass=ForUpdateMetaclass):
enabled: bool
ips: list[IPvAnyAddress]
account_service_base_url: NonEmptyString
leca_service_base_url: NonEmptyString
tnc_base_url: NonEmptyString

@model_validator(mode='after')
def validate_attrs(self):
for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url'):
value = getattr(self, k)
if value != undefined and not value.startswith('https://'):
raise ValueError(f'{k} must start with https://')
if value != undefined and not value.endswith('/'):
setattr(self, k, value + '/')
return self
account_service_base_url: HttpsURL
leca_service_base_url: HttpsURL
tnc_base_url: HttpsURL


class TNCUpdateResult(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ async def do_update(self, data):
"""
config = await self.config()
data = config | data
# We have to normalize url fields as they are url objects atm
url_fields = {k: str(data[k]) for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url')}
data = data | url_fields

await self.validate_data(config, data)

db_payload = {
'enabled': data['enabled'],
'ips': data['ips'],
} | {k: data[k] for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url') if data.get(k)}
} | url_fields
if config['enabled'] is False and data['enabled'] is True:
# Finalization registration is triggered when claim token is generated
# We make sure there is no pending claim token
Expand Down

0 comments on commit c347ddd

Please sign in to comment.