diff --git a/src/middlewared/middlewared/api/base/validators.py b/src/middlewared/middlewared/api/base/validators.py index efe4df49e8f89..dc4ace37624cc 100644 --- a/src/middlewared/middlewared/api/base/validators.py +++ b/src/middlewared/middlewared/api/base/validators.py @@ -1,6 +1,8 @@ from datetime import time import re +from pydantic import HttpUrl + def match_validator(pattern: re.Pattern, explanation: str | None = None): def validator(value: str): @@ -21,3 +23,9 @@ def time_validator(value: str): except TypeError: raise ValueError('Time should be in 24 hour format like "18:00"') return value + + +def https_only_check(url: HttpUrl) -> str: + if url.scheme != 'https': + raise ValueError('URL scheme must be https') + return str(url) diff --git a/src/middlewared/middlewared/api/v25_04_0/tn_connect.py b/src/middlewared/middlewared/api/v25_04_0/tn_connect.py index 03d578a93bc34..e84c4ab7bcb94 100644 --- a/src/middlewared/middlewared/api/v25_04_0/tn_connect.py +++ b/src/middlewared/middlewared/api/v25_04_0/tn_connect.py @@ -1,6 +1,9 @@ -from pydantic import IPvAnyAddress, model_validator +from typing import Annotated + +from pydantic import HttpUrl, IPvAnyAddress, AfterValidator from middlewared.api.base import BaseModel, ForUpdateMetaclass, NonEmptyString, single_argument_args +from middlewared.api.base.validators import https_only_check from middlewared.utils.lang import undefined @@ -10,6 +13,9 @@ ] +HttpsURL = Annotated[HttpUrl, AfterValidator(https_only_check)] + + class TNCEntry(BaseModel): id: int enabled: bool @@ -18,28 +24,18 @@ class TNCEntry(BaseModel): status: NonEmptyString status_reason: NonEmptyString certificate: int | None - account_service_base_url: NonEmptyString - leca_service_base_url: NonEmptyString - tnc_base_url: NonEmptyString + account_service_base_url: HttpsURL + leca_service_base_url: HttpsURL + tnc_base_url: HttpsURL @single_argument_args('tn_connect_update') class TNCUpdateArgs(BaseModel, metaclass=ForUpdateMetaclass): enabled: bool ips: list[IPvAnyAddress] - account_service_base_url: NonEmptyString - leca_service_base_url: NonEmptyString - tnc_base_url: NonEmptyString - - @model_validator(mode='after') - def validate_attrs(self): - for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url'): - value = getattr(self, k) - if value != undefined and not value.startswith('https://'): - raise ValueError(f'{k} must start with https://') - if value != undefined and not value.endswith('/'): - setattr(self, k, value + '/') - return self + account_service_base_url: HttpsURL + leca_service_base_url: HttpsURL + tnc_base_url: HttpsURL class TNCUpdateResult(BaseModel): diff --git a/src/middlewared/middlewared/plugins/truenas_connect/update.py b/src/middlewared/middlewared/plugins/truenas_connect/update.py index bb025d15d2e17..e0cb82efc9091 100644 --- a/src/middlewared/middlewared/plugins/truenas_connect/update.py +++ b/src/middlewared/middlewared/plugins/truenas_connect/update.py @@ -101,7 +101,7 @@ async def do_update(self, data): db_payload = { 'enabled': data['enabled'], 'ips': data['ips'], - } | {k: data[k] for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url') if data.get(k)} + } | {k: data[k] for k in ('account_service_base_url', 'leca_service_base_url', 'tnc_base_url')} if config['enabled'] is False and data['enabled'] is True: # Finalization registration is triggered when claim token is generated # We make sure there is no pending claim token