Skip to content

Commit

Permalink
forbid extra fields in BaseModel (apache#44306)
Browse files Browse the repository at this point in the history
* use extra=forbid

* fix backfill test

* use Response model to fix plugin test failures

* use BaseModel and StrictBaseModel

* fix

* update docstring
  • Loading branch information
rawwar authored Jan 30, 2025
1 parent 97ebffe commit f7bb606
Show file tree
Hide file tree
Showing 25 changed files with 142 additions and 68 deletions.
12 changes: 12 additions & 0 deletions airflow/api_fastapi/core_api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,15 @@ class BaseModel(PydanticBaseModel):
"""

model_config = ConfigDict(from_attributes=True, populate_by_name=True)


class StrictBaseModel(BaseModel):
"""
StrictBaseModel is a base Pydantic model for REST API that does not allow any extra fields.
Use this class for models that should not have any extra fields in the payload.
:meta private:
"""

model_config = ConfigDict(from_attributes=True, populate_by_name=True, extra="forbid")
10 changes: 5 additions & 5 deletions airflow/api_fastapi/core_api/datamodels/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@

from pydantic import Field, field_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.utils.log.secrets_masker import redact


class DagScheduleAssetReference(BaseModel):
class DagScheduleAssetReference(StrictBaseModel):
"""DAG schedule reference serializer for assets."""

dag_id: str
created_at: datetime
updated_at: datetime


class TaskOutletAssetReference(BaseModel):
class TaskOutletAssetReference(StrictBaseModel):
"""Task outlet reference serializer for assets."""

dag_id: str
Expand Down Expand Up @@ -84,7 +84,7 @@ class AssetAliasCollectionResponse(BaseModel):
total_entries: int


class DagRunAssetReference(BaseModel):
class DagRunAssetReference(StrictBaseModel):
"""DAGRun serializer for asset responses."""

run_id: str
Expand Down Expand Up @@ -141,7 +141,7 @@ class QueuedEventCollectionResponse(BaseModel):
total_entries: int


class CreateAssetEventsBody(BaseModel):
class CreateAssetEventsBody(StrictBaseModel):
"""Create asset events request."""

asset_id: int
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

from datetime import datetime

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.models.backfill import ReprocessBehavior


class BackfillPostBody(BaseModel):
class BackfillPostBody(StrictBaseModel):
"""Object used for create backfill request."""

dag_id: str
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_fastapi/core_api/datamodels/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from pydantic import Discriminator, Field, Tag

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel

# Common Bulk Data Models
T = TypeVar("T")
Expand Down Expand Up @@ -57,7 +57,7 @@ class BulkActionNotOnExistence(enum.Enum):
SKIP = "skip"


class BulkBaseAction(BaseModel, Generic[T]):
class BulkBaseAction(StrictBaseModel, Generic[T]):
"""Base class for bulk actions."""

action: BulkAction = Field(..., description="The action to be performed on the entities.")
Expand Down Expand Up @@ -88,7 +88,7 @@ def _action_discriminator(action: Any) -> str:
return BulkAction(action["action"]).value


class BulkBody(BaseModel, Generic[T]):
class BulkBody(StrictBaseModel, Generic[T]):
"""Serializer for bulk entity operations."""

actions: list[
Expand Down
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/datamodels/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
# under the License.
from __future__ import annotations

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import StrictBaseModel


class ConfigOption(BaseModel):
class ConfigOption(StrictBaseModel):
"""Config option."""

key: str
Expand All @@ -32,7 +32,7 @@ def text_format(self):
return f"{self.key} = {self.value}"


class ConfigSection(BaseModel):
class ConfigSection(StrictBaseModel):
"""Config Section Schema."""

name: str
Expand All @@ -53,7 +53,7 @@ def text_format(self):
return f"[{self.name}]\n" + "\n".join(option.text_format for option in self.options) + "\n"


class Config(BaseModel):
class Config(StrictBaseModel):
"""List of config sections with their options."""

sections: list[ConfigSection]
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pydantic import Field, field_validator
from pydantic_core.core_schema import ValidationInfo

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.utils.log.secrets_masker import redact


Expand Down Expand Up @@ -76,7 +76,7 @@ class ConnectionTestResponse(BaseModel):


# Request Models
class ConnectionBody(BaseModel):
class ConnectionBody(StrictBaseModel):
"""Connection Serializer for requests body."""

connection_id: str = Field(serialization_alias="conn_id", max_length=200, pattern=r"^[\w.-]+$")
Expand Down
10 changes: 5 additions & 5 deletions airflow/api_fastapi/core_api/datamodels/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from pydantic import AwareDatetime, Field, NonNegativeInt, computed_field, model_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.models import DagRun
from airflow.utils import timezone
from airflow.utils.state import DagRunState
Expand All @@ -37,14 +37,14 @@ class DAGRunPatchStates(str, Enum):
FAILED = DagRunState.FAILED


class DAGRunPatchBody(BaseModel):
class DAGRunPatchBody(StrictBaseModel):
"""DAG Run Serializer for PATCH requests."""

state: DAGRunPatchStates | None = None
note: str | None = Field(None, max_length=1000)


class DAGRunClearBody(BaseModel):
class DAGRunClearBody(StrictBaseModel):
"""DAG Run serializer for clear endpoint body."""

dry_run: bool = True
Expand Down Expand Up @@ -78,7 +78,7 @@ class DAGRunCollectionResponse(BaseModel):
total_entries: int


class TriggerDAGRunPostBody(BaseModel):
class TriggerDAGRunPostBody(StrictBaseModel):
"""Trigger DAG Run Serializer for POST body."""

dag_run_id: str | None = None
Expand Down Expand Up @@ -109,7 +109,7 @@ def logical_date(self) -> datetime:
return timezone.utcnow()


class DAGRunsBatchBody(BaseModel):
class DAGRunsBatchBody(StrictBaseModel):
"""List DAG Runs body for batch endpoint."""

order_by: str | None = None
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
field_validator,
)

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.api_fastapi.core_api.datamodels.dag_tags import DagTagResponse
from airflow.configuration import conf

Expand Down Expand Up @@ -92,7 +92,7 @@ def file_token(self) -> str:
return serializer.dumps(self.fileloc)


class DAGPatchBody(BaseModel):
class DAGPatchBody(StrictBaseModel):
"""Dag Serializer for updatable bodies."""

is_paused: bool
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_fastapi/core_api/datamodels/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pydantic import BeforeValidator, ConfigDict, Field

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel


def _call_function(function: Callable[[], int]) -> int:
Expand Down Expand Up @@ -60,7 +60,7 @@ class PoolCollectionResponse(BaseModel):
total_entries: int


class PoolPatchBody(BaseModel):
class PoolPatchBody(StrictBaseModel):
"""Pool serializer for patch bodies."""

model_config = ConfigDict(populate_by_name=True, from_attributes=True)
Expand All @@ -71,7 +71,7 @@ class PoolPatchBody(BaseModel):
include_deferred: bool | None = None


class PoolBody(BasePool):
class PoolBody(BasePool, StrictBaseModel):
"""Pool serializer for post bodies."""

pool: str = Field(alias="name", max_length=256)
Expand Down
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
model_validator,
)

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.api_fastapi.core_api.datamodels.job import JobResponse
from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse
from airflow.utils.state import TaskInstanceState
Expand Down Expand Up @@ -97,7 +97,7 @@ class TaskDependencyCollectionResponse(BaseModel):
dependencies: list[TaskDependencyResponse]


class TaskInstancesBatchBody(BaseModel):
class TaskInstancesBatchBody(StrictBaseModel):
"""Task Instance body for get batch."""

dag_ids: list[str] | None = None
Expand Down Expand Up @@ -159,7 +159,7 @@ class TaskInstanceHistoryCollectionResponse(BaseModel):
total_entries: int


class ClearTaskInstancesBody(BaseModel):
class ClearTaskInstancesBody(StrictBaseModel):
"""Request body for Clear Task Instances endpoint."""

dry_run: bool = True
Expand Down Expand Up @@ -195,7 +195,7 @@ def validate_model(cls, data: Any) -> Any:
return data


class PatchTaskInstanceBody(BaseModel):
class PatchTaskInstanceBody(StrictBaseModel):
"""Request body for Clear Task Instances endpoint."""

new_state: TaskInstanceState | None = None
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pydantic import ConfigDict, Field, model_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.models.base import ID_LEN
from airflow.typing_compat import Self
from airflow.utils.log.secrets_masker import redact
Expand Down Expand Up @@ -52,7 +52,7 @@ def redact_val(self) -> Self:
return self


class VariableBody(BaseModel):
class VariableBody(StrictBaseModel):
"""Variable serializer for bodies."""

key: str = Field(max_length=ID_LEN)
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def value_to_string(cls, v):
return str(v) if v is not None else None


class XComCollection(BaseModel):
"""List of XCom items."""
class XComCollectionResponse(BaseModel):
"""XCom Collection serializer for responses."""

xcom_entries: list[XComResponse]
total_entries: int
Loading

0 comments on commit f7bb606

Please sign in to comment.