From 25b3e3220f7dd2a9efe7cdd111508c4eacda0bf7 Mon Sep 17 00:00:00 2001
From: Craig Russell
Date: Tue, 17 Sep 2024 15:01:49 +0100
Subject: [PATCH] [fix] pydantic can allow extra fields

---
 bioimage_embed/config.py | 47 +++++++++++++++++++++-------------------
 1 file changed, 25 insertions(+), 22 deletions(-)

diff --git a/bioimage_embed/config.py b/bioimage_embed/config.py
index 7c505ed9..a610567c 100644
--- a/bioimage_embed/config.py
+++ b/bioimage_embed/config.py
@@ -12,12 +12,14 @@
 """
 
-import pytorch_lightning
+# TODO need a way to copy signatures from original classes for validation
 from omegaconf import OmegaConf
 from bioimage_embed import augmentations as augs
 import os
 from dataclasses import field
 from pydantic.dataclasses import dataclass
+
+# from dataclasses import dataclass
 from typing import List, Optional, Dict, Any
 
 from pydantic import Field
 
@@ -25,7 +27,7 @@
 from . import utils
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Recipe:
     _target_: str = "types.SimpleNamespace"
     model: str = "resnet18_vae"
@@ -59,7 +61,7 @@ class Recipe:
 
 # Use the ALbumentations .to_dict() method to get the dictionary
 # that pydantic can use
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class ATransform:
     _target_: str = "albumentations.from_dict"
     _convert_: str = "object"
@@ -72,7 +74,7 @@ class ATransform:
 
 
 # VisionWrapper is a helper class for applying albumentations pipelines for image augmentations in autoencoding
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Transform:
     _target_: str = "bioimage_embed.augmentations.VisionWrapper"
     _convert_: str = "object"
@@ -82,7 +84,7 @@ class Transform:
     )
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Dataset:
     _target_: str = "torch.utils.data.Dataset"
     transform: Any = Field(default_factory=Transform)
@@ -94,24 +96,24 @@ class Dataset:
     # return self
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class FakeDataset(Dataset):
     _target_: str = "torchvision.datasets.FakeData"
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class ImageFolderDataset(Dataset):
     _target_: str = "torchvision.datasets.ImageFolder"
     # transform: Transform = Field(default_factory=Transform)
     root: str = II("recipe.data")
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class NdDataset(ImageFolderDataset):
     transform: Transform = Field(default_factory=Transform)
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class TiffDataset(NdDataset):
     _target_: str = "bioimage_embed.datasets.TiffDataset"
 
@@ -120,7 +122,7 @@ class NgffDataset(NdDataset):
     _target_: str = "bioimage_embed.datasets.NgffDataset"
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class DataLoader:
     _target_: str = "bioimage_embed.lightning.dataloader.DataModule"
     dataset: Any = Field(default_factory=FakeDataset)
@@ -128,7 +130,7 @@ class DataLoader:
     batch_size: int = II("recipe.batch_size")
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Model:
     _target_: str = "bioimage_embed.models.create_model"
     model: str = II("recipe.model")
@@ -137,12 +139,12 @@ class Model:
     pretrained: bool = True
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Callback:
     pass
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class EarlyStopping(Callback):
     _target_: str = "pytorch_lightning.callbacks.EarlyStopping"
     monitor: str = "loss/val"
@@ -150,7 +152,7 @@ class EarlyStopping(Callback):
     patience: int = 3
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class ModelCheckpoint(Callback):
     _target_: str = "pytorch_lightning.callbacks.ModelCheckpoint"
     save_last = True
@@ -161,7 +163,7 @@ class ModelCheckpoint(Callback):
     dirpath: str = f"{II('paths.model')}/{II('uuid')}"
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class LightningModel:
     _target_: str = "bioimage_embed.lightning.torch.AEUnsupervised"
     # This should be pythae base autoencoder?
@@ -173,20 +175,20 @@ class LightningModelSupervised(LightningModel):
     _target_: str = "bioimage_embed.lightning.torch.AESupervised"
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Callbacks:
     # _target_: str = "collections.OrderedDict"
     model_checkpoint: ModelCheckpoint = Field(default_factory=ModelCheckpoint)
     early_stopping: EarlyStopping = Field(default_factory=EarlyStopping)
 
 
-@dataclass
-class Trainer(pytorch_lightning.Trainer):
+@dataclass(config=dict(extra="allow"))
+class Trainer:
     _target_: str = "pytorch_lightning.Trainer"
     # logger: Optional[any]
     gradient_clip_val: float = 0.5
     enable_checkpointing: bool = True
-    devices: str = "auto"
+    devices: Any = "auto"
     accelerator: str = "auto"
     accumulate_grad_batches: int = 16
     min_epochs: int = 1
@@ -196,12 +198,13 @@ class Trainer(pytorch_lightning.Trainer):
     callbacks: List[Any] = Field(
         default_factory=lambda: list(vars(Callbacks()).values()), frozen=True
     )
+
     # TODO idea here would be to use pydantic to validate omegaconf
 
 
 # TODO add argument caching for checkpointing
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Paths:
     model: str = "models"
     logs: str = "logs"
@@ -213,7 +216,7 @@ def __post_init__(self):
             os.makedirs(path, exist_ok=True)
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class Config:
     # This has to be dataclass.field instead of pydantic Field for somereason
     paths: Any = field(default_factory=Paths)
@@ -225,7 +228,7 @@ class Config:
     uuid: str = field(default_factory=lambda: utils.hashing_fn(Recipe()))
 
 
-@dataclass
+@dataclass(config=dict(extra="allow"))
 class SupervisedConfig(Config):
     lit_model: LightningModel = field(default_factory=LightningModel)
 
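
A minimal sketch of what the new decorator argument enables, assuming pydantic v2 is installed; the class and field names below are illustrative and not taken from bioimage_embed:

    from pydantic.dataclasses import dataclass


    # Hypothetical config class (not from bioimage_embed), using the same
    # decorator arguments as the patch: extra="allow" tells pydantic to accept
    # fields that are not declared on the dataclass instead of rejecting them.
    @dataclass(config=dict(extra="allow"))
    class ExampleConfig:
        model: str = "resnet18_vae"
        batch_size: int = 16


    # The undeclared keyword `latent_dim` is accepted rather than raising a
    # validation error; declared fields are still validated as usual.
    cfg = ExampleConfig(model="resnet50_vae", latent_dim=64)
    print(cfg.model, cfg.batch_size)

Without extra="allow", a pydantic dataclass rejects keyword arguments it does not declare, so configuration objects carrying additional keys would fail to instantiate; applying the setting to every config class here relaxes that check, which appears to be the motivation behind the commit title.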