[fix] pydantic can allow extra fields
ctr26 committed Sep 17, 2024
1 parent 1488f82 commit 25b3e32
Showing 1 changed file with 25 additions and 22 deletions.
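The diff switches every config dataclass from a bare pydantic @dataclass to @dataclass(config=dict(extra="allow")). A minimal sketch of what that setting enables, assuming pydantic v2 semantics; the Recipe fields mirror the diff, and the extra seed key is hypothetical:

from pydantic.dataclasses import dataclass


@dataclass(config=dict(extra="allow"))
class Recipe:
    _target_: str = "types.SimpleNamespace"
    model: str = "resnet18_vae"


# With extra="allow", an undeclared key such as `seed` should be accepted and
# stored on the instance rather than failing validation.
recipe = Recipe(model="resnet50_vae", seed=42)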
bioimage_embed/config.py (47 changes: 25 additions & 22 deletions)
@@ -12,20 +12,22 @@
"""

import pytorch_lightning
# TODO need a way to copy signatures from original classes for validation
from omegaconf import OmegaConf
from bioimage_embed import augmentations as augs
import os
from dataclasses import field
from pydantic.dataclasses import dataclass

# from dataclasses import dataclass
from typing import List, Optional, Dict, Any

from pydantic import Field
from omegaconf import II
from . import utils


@dataclass
@dataclass(config=dict(extra="allow"))
class Recipe:
_target_: str = "types.SimpleNamespace"
model: str = "resnet18_vae"
@@ -59,7 +61,7 @@ class Recipe:

# Use the Albumentations .to_dict() method to get the dictionary
# that pydantic can use
@dataclass
@dataclass(config=dict(extra="allow"))
class ATransform:
_target_: str = "albumentations.from_dict"
_convert_: str = "object"
@@ -72,7 +74,7 @@ class ATransform:
# VisionWrapper is a helper class for applying albumentations pipelines for image augmentations in autoencoding


@dataclass
@dataclass(config=dict(extra="allow"))
class Transform:
_target_: str = "bioimage_embed.augmentations.VisionWrapper"
_convert_: str = "object"
@@ -82,7 +84,7 @@ class Transform:
)


@dataclass
@dataclass(config=dict(extra="allow"))
class Dataset:
_target_: str = "torch.utils.data.Dataset"
transform: Any = Field(default_factory=Transform)
@@ -94,24 +96,24 @@ class Dataset:
# return self


@dataclass
@dataclass(config=dict(extra="allow"))
class FakeDataset(Dataset):
_target_: str = "torchvision.datasets.FakeData"


@dataclass
@dataclass(config=dict(extra="allow"))
class ImageFolderDataset(Dataset):
_target_: str = "torchvision.datasets.ImageFolder"
# transform: Transform = Field(default_factory=Transform)
root: str = II("recipe.data")


@dataclass
@dataclass(config=dict(extra="allow"))
class NdDataset(ImageFolderDataset):
transform: Transform = Field(default_factory=Transform)


@dataclass
@dataclass(config=dict(extra="allow"))
class TiffDataset(NdDataset):
_target_: str = "bioimage_embed.datasets.TiffDataset"

@@ -120,15 +122,15 @@ class NgffDataset(NdDataset):
_target_: str = "bioimage_embed.datasets.NgffDataset"


@dataclass
@dataclass(config=dict(extra="allow"))
class DataLoader:
_target_: str = "bioimage_embed.lightning.dataloader.DataModule"
dataset: Any = Field(default_factory=FakeDataset)
num_workers: int = 1
batch_size: int = II("recipe.batch_size")


@dataclass
@dataclass(config=dict(extra="allow"))
class Model:
_target_: str = "bioimage_embed.models.create_model"
model: str = II("recipe.model")
@@ -137,20 +139,20 @@ class Model:
pretrained: bool = True


@dataclass
@dataclass(config=dict(extra="allow"))
class Callback:
pass


@dataclass
@dataclass(config=dict(extra="allow"))
class EarlyStopping(Callback):
_target_: str = "pytorch_lightning.callbacks.EarlyStopping"
monitor: str = "loss/val"
mode: str = "min"
patience: int = 3


@dataclass
@dataclass(config=dict(extra="allow"))
class ModelCheckpoint(Callback):
_target_: str = "pytorch_lightning.callbacks.ModelCheckpoint"
save_last = True
@@ -161,7 +163,7 @@ class ModelCheckpoint(Callback):
dirpath: str = f"{II('paths.model')}/{II('uuid')}"


@dataclass
@dataclass(config=dict(extra="allow"))
class LightningModel:
_target_: str = "bioimage_embed.lightning.torch.AEUnsupervised"
# This should be pythae base autoencoder?
@@ -173,20 +175,20 @@ class LightningModelSupervised(LightningModel):
_target_: str = "bioimage_embed.lightning.torch.AESupervised"


@dataclass
@dataclass(config=dict(extra="allow"))
class Callbacks:
# _target_: str = "collections.OrderedDict"
model_checkpoint: ModelCheckpoint = Field(default_factory=ModelCheckpoint)
early_stopping: EarlyStopping = Field(default_factory=EarlyStopping)


@dataclass
class Trainer(pytorch_lightning.Trainer):
@dataclass(config=dict(extra="allow"))
class Trainer:
_target_: str = "pytorch_lightning.Trainer"
# logger: Optional[any]
gradient_clip_val: float = 0.5
enable_checkpointing: bool = True
devices: str = "auto"
devices: Any = "auto"
accelerator: str = "auto"
accumulate_grad_batches: int = 16
min_epochs: int = 1
@@ -196,12 +198,13 @@ class Trainer(pytorch_lightning.Trainer):
callbacks: List[Any] = Field(
default_factory=lambda: list(vars(Callbacks()).values()), frozen=True
)
# TODO idea here would be to use pydantic to validate omegaconf


# TODO add argument caching for checkpointing


@dataclass
@dataclass(config=dict(extra="allow"))
class Paths:
model: str = "models"
logs: str = "logs"
@@ -213,7 +216,7 @@ def __post_init__(self):
os.makedirs(path, exist_ok=True)


@dataclass
@dataclass(config=dict(extra="allow"))
class Config:
# This has to be dataclasses.field instead of pydantic Field for some reason
paths: Any = field(default_factory=Paths)
@@ -225,7 +228,7 @@ class Config:
uuid: str = field(default_factory=lambda: utils.hashing_fn(Recipe()))


@dataclass
@dataclass(config=dict(extra="allow"))
class SupervisedConfig(Config):
lit_model: LightningModel = field(default_factory=LightningModel)

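Several defaults in the file use OmegaConf's II() helper, for example batch_size: int = II("recipe.batch_size"), which is shorthand for the interpolation string "${recipe.batch_size}". A standalone sketch with plain dataclasses (which OmegaConf.structured accepts directly); the *Demo names are illustrative:

from dataclasses import dataclass, field
from omegaconf import II, OmegaConf


@dataclass
class RecipeDemo:
    batch_size: int = 16


@dataclass
class DataLoaderDemo:
    # Resolves against the parent config at access time.
    batch_size: int = II("recipe.batch_size")


@dataclass
class ConfigDemo:
    recipe: RecipeDemo = field(default_factory=RecipeDemo)
    dataloader: DataLoaderDemo = field(default_factory=DataLoaderDemo)


cfg = OmegaConf.structured(ConfigDemo)
print(cfg.dataloader.batch_size)  # 16, pulled from recipe.batch_size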

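The _target_ and _convert_ keys throughout follow the Hydra instantiate convention. A small sketch of how such a node is typically consumed, assuming hydra-core is available; whether this repo calls hydra.utils.instantiate directly is not shown in this diff, and the size argument is illustrative:

from hydra.utils import instantiate

# A node shaped like the dataclasses above: _target_ names the object to build.
node = {"_target_": "torchvision.datasets.FakeData", "size": 8}
dataset = instantiate(node)  # builds torchvision.datasets.FakeData(size=8)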