diff --git a/src/middlewared/middlewared/api/base/validators.py b/src/middlewared/middlewared/api/base/validators.py index efe4df49e8f89..ff6b5fad2dab7 100644 --- a/src/middlewared/middlewared/api/base/validators.py +++ b/src/middlewared/middlewared/api/base/validators.py @@ -1,6 +1,8 @@ from datetime import time import re +from pydantic import HttpUrl + def match_validator(pattern: re.Pattern, explanation: str | None = None): def validator(value: str): @@ -21,3 +23,9 @@ def time_validator(value: str): except TypeError: raise ValueError('Time should be in 24 hour format like "18:00"') return value + + +def https_only_check(url: HttpUrl) -> HttpUrl: + if url.scheme != 'https': + raise ValueError('URL scheme must be https') + return url diff --git a/src/middlewared/middlewared/api/v25_04_0/tn_connect.py b/src/middlewared/middlewared/api/v25_04_0/tn_connect.py index 03d578a93bc34..75e2443d04bc4 100644 --- a/src/middlewared/middlewared/api/v25_04_0/tn_connect.py +++ b/src/middlewared/middlewared/api/v25_04_0/tn_connect.py @@ -1,6 +1,9 @@ -from pydantic import IPvAnyAddress, model_validator +from typing import Annotated + +from pydantic import HttpUrl, IPvAnyAddress, AfterValidator from middlewared.api.base import BaseModel, ForUpdateMetaclass, NonEmptyString, single_argument_args +from middlewared.api.base.validators import https_only_check from middlewared.utils.lang import undefined @@ -10,6 +13,9 @@ ] +HttpsURL = Annotated[HttpUrl, AfterValidator(https_only_check)] + + class TNCEntry(BaseModel): id: int enabled: bool @@ -27,19 +33,9 @@ class TNCEntry(BaseModel): class TNCUpdateArgs(BaseModel, metaclass=ForUpdateMetaclass): enabled: bool ips: list[IPvAnyAddress] - account_service_base_url: NonEmptyString - leca_service_base_url: NonEmptyString - tnc_base_url: NonEmptyString - - @model_validator(mode='after') - def validate_attrs(self): - for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url'): - value = getattr(self, k) - if value != undefined and not value.startswith('https://'): - raise ValueError(f'{k} must start with https://') - if value != undefined and not value.endswith('/'): - setattr(self, k, value + '/') - return self + account_service_base_url: HttpsURL + leca_service_base_url: HttpsURL + tnc_base_url: HttpsURL class TNCUpdateResult(BaseModel): diff --git a/src/middlewared/middlewared/plugins/truenas_connect/update.py b/src/middlewared/middlewared/plugins/truenas_connect/update.py index b1756ad9c8b26..b95e26cf057b9 100644 --- a/src/middlewared/middlewared/plugins/truenas_connect/update.py +++ b/src/middlewared/middlewared/plugins/truenas_connect/update.py @@ -95,13 +95,16 @@ async def do_update(self, data): """ config = await self.config() data = config | data + # We have to normalize url fields as they are url objects atm + url_fields = {k: str(data[k]) for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url')} + data = data | url_fields await self.validate_data(config, data) db_payload = { 'enabled': data['enabled'], 'ips': data['ips'], - } | {k: data[k] for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url') if data.get(k)} + } | url_fields if config['enabled'] is False and data['enabled'] is True: # Finalization registration is triggered when claim token is generated # We make sure there is no pending claim token