From d7a259bc531a7e2a344afa057b248d1031a604a4 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 14 Aug 2024 16:30:21 +0100 Subject: [PATCH 01/22] [bug] notebooks cli stuff is erroring, adding tests to confirm --- bioimage_embed/tests/test_cli.py | 41 ++++++++++++++++---------- bioimage_embed/tests/test_lightning.py | 23 ++++++++++----- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/bioimage_embed/tests/test_cli.py b/bioimage_embed/tests/test_cli.py index c5510157..1408b698 100644 --- a/bioimage_embed/tests/test_cli.py +++ b/bioimage_embed/tests/test_cli.py @@ -45,14 +45,14 @@ def test_write_default_config_file( assert config_path.is_file(), "Default config file was not created" -@pytest.fixture -def cfg(): - mock_dataset = config.ImageFolderDataset( - _target_="bioimage_embed.datasets.FakeImageFolder", - ) - cfg = cli.get_default_config() - cfg.recipe.data = mock_dataset - return cfg +# @pytest.fixture +# def cfg(): +# mock_dataset = config.ImageFolderDataset( +# _target_="bioimage_embed.datasets.FakeImageFolder", +# ) +# cfg = cli.get_default_config() +# cfg.recipe.data = mock_dataset +# return cfg def test_get_default_config(cfg): @@ -60,12 +60,12 @@ def test_get_default_config(cfg): # Further assertions can be added to check specific config properties -def test_main_with_default_config( - cfg, config_path, config_dir, config_file, config_directory_setup -): - test_get_default_config +# def test_main_with_default_config( +# cfg, config_path, config_dir, config_file, config_directory_setup +# ): +# test_get_default_config - # cli.main(config_dir=config_dir, config_file=config_file, job_name="test_app") +# # cli.main(config_dir=config_dir, config_file=config_file, job_name="test_app") # @pytest.mark.skip("Computationally heavy") @@ -122,15 +122,24 @@ def hydra_cfg(): return cfg +@pytest.fixture +def model(): + return "dummy_model" + + # TODO double check this is sensible @pytest.fixture -def cfg(): +def cfg(model): cfg = config.Config() - cfg.dataloader.dataset._target_ = "bioimage_embed.datasets.FakeImageFolder" + cfg.dataloader.num_workers = 0 # This avoids processes being forked + cfg.trainer.max_epochs = 1 + cfg.trainer.max_steps = 1 + cfg.trainer.fast_dev_run = True + cfg.recipe.model = model return cfg -@pytest.mark.skip("Computationally heavy") +# @pytest.mark.skip("Computationally heavy") def test_train(cfg): cli.train(cfg) diff --git a/bioimage_embed/tests/test_lightning.py b/bioimage_embed/tests/test_lightning.py index 768e692a..d3c430af 100644 --- a/bioimage_embed/tests/test_lightning.py +++ b/bioimage_embed/tests/test_lightning.py @@ -8,12 +8,18 @@ AEUnsupervised, ) from bioimage_embed.models import create_model +from bioimage_embed import config +from hydra.utils import instantiate from torchvision.datasets import FakeData -from torchvision import transforms torch.manual_seed(42) +@pytest.fixture() +def transform(): + return instantiate(config.Transform()) + + @pytest.fixture(params=[1, 2, 16]) def classes(request): return request.param @@ -26,7 +32,7 @@ def model_name(request): @pytest.fixture() def image_dim(): - return (256, 256) + return (224, 224) @pytest.fixture() @@ -98,12 +104,15 @@ def data(input_dim): @pytest.fixture() -def dataset(samples, input_dim, classes): +def dataset(samples, input_dim, transform, classes=2): + # x = torch.rand(samples, *input_dim) + # y = torch.torch.randint(classes - 1, (samples,)) + # return TensorDataset(x, y) return FakeData( size=samples, image_size=input_dim, num_classes=classes, - transform=transforms.ToTensor(), 
+ transform=transform, ) @@ -131,8 +140,8 @@ def datamodule(dataset, batch_size): @pytest.fixture() def trainer(): return pl.Trainer( - max_steps=1, - max_epochs=1, + # max_steps=1, + max_epochs=2, ) @@ -159,7 +168,7 @@ def test_trainer_dummy_model_fit(trainer, lit_dummy_model, datamodule): return trainer.fit(lit_dummy_model, datamodule) -@pytest.mark.skip(reason="Expensive") +# @pytest.mark.skip(reason="Expensive") def test_trainer_fit(trainer, lit_model, datamodule): return trainer.fit(lit_model, datamodule) From 1fe05a4c14db2ec46225e2a1c41ddb8415b1a9d4 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 14 Aug 2024 16:33:19 +0100 Subject: [PATCH 02/22] [bug] issue was the mess of dictionaries and model outputs, now consolidated --- bioimage_embed/lightning/torch.py | 63 +++++++++++++++++++------------ 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/bioimage_embed/lightning/torch.py b/bioimage_embed/lightning/torch.py index d2152072..b95317ed 100644 --- a/bioimage_embed/lightning/torch.py +++ b/bioimage_embed/lightning/torch.py @@ -88,11 +88,13 @@ def embedding(self, model_output: ModelOutput) -> torch.Tensor: def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor: self.model.train() - loss, model_output = self.eval_step(batch, batch_idx) + model_output = self.eval_step(batch, batch_idx) self.log_dict( { - "loss/train": loss, + "loss/train": model_output.loss, "mse/train": F.mse_loss(model_output.recon_x, model_output.data), + "recon_loss/train": model_output.recon_loss, + "variational_loss/train": model_output.variational_loss, }, # on_step=True, on_epoch=True, @@ -101,16 +103,22 @@ def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor: ) if isinstance(self.logger, pl.loggers.TensorBoardLogger): self.log_tensorboard(model_output, model_output.data) - return loss + return model_output.loss - def loss_function( - self, model_output: ModelOutput, batch_idx: int, *args, **kwargs - ) -> dict: - return { - "loss": model_output.loss, - "recon_loss": model_output.recon_loss, - "variational_loss": model_output.loss - model_output.recon_loss, - } + def loss_function(self, model_output, batch_idx, *args, **kwargs) -> ModelOutput: + """ + Function for overriding the loss function, should return a ModelOutput object. Preferably use super() to inherit the default loss function and then append the new loss. + """ + return model_output + + def _loss_function(self, model_output, batch_idx, *args, **kwargs) -> ModelOutput: + """ + Internal default loss function, should not be overridden. + This function calculates the variational loss. 
+ """ + + model_output.variational_loss = model_output.loss - model_output.recon_loss + return model_output # def logging_step(self, z, loss, x, model_output, batch_idx): # self.logger.experiment.add_embedding( @@ -129,25 +137,29 @@ def loss_function( def validation_step(self, batch, batch_idx): # x, y = batch - loss, model_output = self.eval_step(batch, batch_idx) + model_output = self.eval_step(batch, batch_idx) self.log_dict( { - "loss/val": loss, + "loss/val": model_output.loss, "mse/val": F.mse_loss(model_output.recon_x, model_output.data), + "recon_loss/val": model_output.recon_loss, + "variational_loss/val": model_output.variational_loss, } ) - return loss + return model_output.loss def test_step(self, batch, batch_idx): # x, y = batch - loss, model_output = self.eval_step(batch, batch_idx) + model_output = self.eval_step(batch, batch_idx) self.log_dict( { - "loss/test": loss, + "loss/test": model_output.loss, "mse/test": F.mse_loss(model_output.recon_x, model_output.data), + "recon_loss/test": model_output.recon_loss, + "variational_loss/test": model_output.variational_loss, } ) - return loss + return model_output.loss # Fangless function to be overloaded later def batch_to_xy(self, batch): @@ -156,8 +168,10 @@ def batch_to_xy(self, batch): def eval_step(self, batch, batch_idx): model_output = self.predict_step(batch, batch_idx) - loss = self.loss_function(model_output, batch_idx) - return loss, model_output + # loss = model_output + model_output = self._loss_function(model_output, batch_idx) + model_output = self.loss_function(model_output, batch_idx) + return model_output # def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): # # Implement your own logic for updating the lr scheduler @@ -267,18 +281,19 @@ class AutoEncoderSupervised(AutoEncoder): def loss_function(self, model_output, batch_idx): # x, y = batch - loss = super().loss_function(model_output, batch_idx) + # loss = super().loss_function(model_output, batch_idx) + # TODO check this # Scale is used as the rest of the loss functions are sums rather than means, which may mean we need to scale up the contrastive loss scale = torch.prod(torch.tensor(model_output.z.shape[1:])) if model_output.target.unique().size(0) == 1: - return loss + return model_output pairs = create_label_based_pairs(model_output.z.squeeze(), model_output.target) contrastive_loss = self.criteron(*pairs) - loss["contrastive_loss"] = scale * contrastive_loss - loss["loss"] += loss["contrastive_loss"] - return loss + model_output.contrastive_loss = scale * contrastive_loss + model_output.loss += model_output.contrastive_loss + return model_output class AESupervised(AutoEncoderSupervised): From d06194ea806da82429217f5077446ac804e82cb1 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 14 Aug 2024 16:33:46 +0100 Subject: [PATCH 03/22] [bug] updated shape_embed to match (new) style --- bioimage_embed/shapes/lightning.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/bioimage_embed/shapes/lightning.py b/bioimage_embed/shapes/lightning.py index d092e8ca..59bc5cbf 100644 --- a/bioimage_embed/shapes/lightning.py +++ b/bioimage_embed/shapes/lightning.py @@ -33,7 +33,6 @@ def batch_to_tensor(self, batch): def loss_function(self, model_output, *args, **kwargs): loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=False) - loss = model_output.loss shape_loss = torch.sum( torch.stack( @@ -46,21 +45,15 @@ def loss_function(self, model_output, *args, **kwargs): ] ) ) - loss += 
shape_loss + model_output.loss += shape_loss + model_output.shape_loss = shape_loss # loss += lf.diagonal_loss(model_output.recon_x) # loss += lf.symmetry_loss(model_output.recon_x) # loss += lf.triangle_inequality_loss(model_output.recon_x) # loss += lf.non_negative_loss(model_output.recon_x) - variational_loss = model_output.loss - model_output.recon_loss - - return { - "loss": loss, - "shape_loss": shape_loss, - "reconstruction_loss": model_output.recon_loss, - "variational_loss": variational_loss, - } + return model_output class MaskEmbed(MaskEmbedMixin, AutoEncoderUnsupervised): From 92ce6563590815169f79d79d19e4d1212ac28654 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 12:00:50 +0100 Subject: [PATCH 04/22] [init] notebook example that "should" work From 01631aa8bc18245c1c1def714cd159d9b2144577 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 12:03:15 +0100 Subject: [PATCH 05/22] [feat] adding jupytext --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0286e048..116fd368 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ transformers = "^4.42.4" torch = "^2.3.1" torchvision = "^0.18.1" wandb = "^0.17.4" +jupytext = "^1.16.4" From d90317c739bbd3eaa7769a9662751cde9b1e5866 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:20:40 +0100 Subject: [PATCH 06/22] [bug] collate_fn doesn't work with num_workers > 0 --- bioimage_embed/lightning/dataloader.py | 38 +++++++++++++++----------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/bioimage_embed/lightning/dataloader.py b/bioimage_embed/lightning/dataloader.py index dc800ecd..75d05c34 100644 --- a/bioimage_embed/lightning/dataloader.py +++ b/bioimage_embed/lightning/dataloader.py @@ -5,6 +5,26 @@ from functools import partial +# https://stackoverflow.com/questions/74931838/cant-pickle-local-object-evaluationloop-advance-locals-batch-to-device-pyto +class Collator: + def collate_filter_for_none(self, batch): + """ + Collate function that filters out None values from the batch. + + Args: + batch: The batch to be filtered. + + Returns: + The filtered batch. + """ + batch = list(filter(lambda x: x is not None, batch)) + return torch.utils.data.dataloader.default_collate(batch) + + def __call__(self, incoming): + # do stuff with incoming + return self.collate_filter_for_none(incoming) + + class DataModule(pl.LightningDataModule): """ A PyTorch Lightning DataModule for handling dataset loading and splitting. @@ -25,7 +45,6 @@ def __init__( num_workers: int = 4, pin_memory: bool = False, drop_last: bool = False, - collate_fn=None, ): """ Initializes the DataModule with the given dataset and parameters. @@ -40,14 +59,14 @@ def __init__( """ super().__init__() self.dataset = dataset - collate_fn = collate_fn if collate_fn else self.collate_filter_for_none + self.collator = Collator() self.dataloader = partial( DataLoader, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, - collate_fn=collate_fn, + collate_fn=self.collator, ) self.train_dataset = None @@ -55,19 +74,6 @@ def __init__( self.test_dataset = None self.setup() - def collate_filter_for_none(self, batch): - """ - Collate function that filters out None values from the batch. - - Args: - batch: The batch to be filtered. - - Returns: - The filtered batch. 
- """ - batch = list(filter(lambda x: x is not None, batch)) - return torch.utils.data.dataloader.default_collate(batch) - def setup(self, stage=None): """ Sets up the datasets by splitting the main dataset into train, validation, and test sets. From 72c91fd2b4ba7b4b2106528792597f161b78387f Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:21:19 +0100 Subject: [PATCH 07/22] [bug] removing collate, fixing path validation --- bioimage_embed/config.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/bioimage_embed/config.py b/bioimage_embed/config.py index 0059109b..349ce666 100644 --- a/bioimage_embed/config.py +++ b/bioimage_embed/config.py @@ -19,7 +19,7 @@ from pydantic.dataclasses import dataclass from typing import List, Optional, Dict, Any -from pydantic import Field, root_validator +from pydantic import Field from omegaconf import II from . import utils @@ -86,6 +86,12 @@ class Dataset: _target_: str = "torch.utils.data.Dataset" transform: Any = Field(default_factory=Transform) + # TODO add validation for transform to be floats + # @model_validator(mode="after") + # def validate(self): + # dataset = instantiate(self) + # return self + @dataclass class FakeDataset(Dataset): @@ -119,7 +125,6 @@ class DataLoader: dataset: Any = Field(default_factory=FakeDataset) num_workers: int = 1 batch_size: int = II("recipe.batch_size") - collate_fn: Any = None @dataclass @@ -202,15 +207,9 @@ class Paths: tensorboard: str = "tensorboard" wandb: str = "wandb" - @root_validator( - pre=False, skip_on_failure=True - ) # Ensures this runs after all other validations - @classmethod - def create_dirs(cls, values): - # The `values` dict contains all the validated field values - for path in values.values(): + def __post_init__(self): + for path in self.__dict__.values(): os.makedirs(path, exist_ok=True) - return values @dataclass From 9044247c160708ff6aa23a945362344ce3485847 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:22:32 +0100 Subject: [PATCH 08/22] [bug] default augs need a float tensor --- bioimage_embed/augmentations.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bioimage_embed/augmentations.py b/bioimage_embed/augmentations.py index 61baae50..90fc20b1 100644 --- a/bioimage_embed/augmentations.py +++ b/bioimage_embed/augmentations.py @@ -32,6 +32,7 @@ ), # Adjust image intensity with a specified range for individual channels A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), + A.ToFloat(), ToTensorV2(), ] @@ -49,8 +50,8 @@ def __call__(self, image): img = np.array(image) transformed = self.transform(image=img) return transformed["image"] - except: - return None,None + except Exception: + return None, None class VisionWrapperSupervised: From 0539c082848f11c6c475ee116adaaff84c7b8562 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:22:59 +0100 Subject: [PATCH 09/22] [feat] adding fit check to full check --- bioimage_embed/bie.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bioimage_embed/bie.py b/bioimage_embed/bie.py index 0af2c769..d39c36fc 100644 --- a/bioimage_embed/bie.py +++ b/bioimage_embed/bie.py @@ -162,3 +162,4 @@ def export(self): def check(self): self.model_check() self.trainer_check() + self.trainer_check_fit() From 193c1898c56872ca89776012b4bb8b6d091aa1e5 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:23:43 +0100 Subject: [PATCH 10/22] [feat] adding simple.py --- scripts/simple.py | 41 
+++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 scripts/simple.py diff --git a/scripts/simple.py b/scripts/simple.py new file mode 100644 index 00000000..4d97ed7f --- /dev/null +++ b/scripts/simple.py @@ -0,0 +1,41 @@ +# %% +import bioimage_embed +import bioimage_embed.config as config + +# %% +from torchvision.datasets import FakeData +from hydra.utils import instantiate + + +# %% +transform = instantiate(config.Transform()) + +# # %% +dataset = FakeData( + size=64, + image_size=(3, 224, 224), + num_classes=10, + transform=transform, +) +# dataset=CelebA(download=True, root="/tmp", split="train") + +# %% [markdown] + +# %% +cfg = config.Config(dataset=dataset) +cfg.recipe.model = "resnet18_vae" +cfg.recipe.max_epochs = 100 +bie = bioimage_embed.BioImageEmbed(cfg) + + +# %% +def process(): + bie.check() + bie.train() + bie.export() + + +# %% +# This is the entrypoint for the script and very import if cfg.trainer.num_workers > 0 +if __name__ == "__main__": + process() From 8a0c5f4f13f5aecb0b7a915668873b95751dedba Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:24:38 +0100 Subject: [PATCH 11/22] [feat] cleaning up cli tests --- bioimage_embed/tests/test_cli.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/bioimage_embed/tests/test_cli.py b/bioimage_embed/tests/test_cli.py index 1408b698..19fa8752 100644 --- a/bioimage_embed/tests/test_cli.py +++ b/bioimage_embed/tests/test_cli.py @@ -69,14 +69,14 @@ def test_get_default_config(cfg): # @pytest.mark.skip("Computationally heavy") -def test_hydra(): - # bie_train model.model="resnet50_vqvae" dataset._target_="bioimage_embed.datasets.FakeImageFolder" - input_dim = [3, 224, 224] - cfg = config.Config() - cfg.dataloader.dataset._target_ = "bioimage_embed.datasets.FakeImageFolder" - cfg.dataloader.dataset.image_size = input_dim - cfg.recipe.model = "resnet18_vae" - cfg.recipe.max_epochs = 1 +# def test_hydra(): +# # bie_train model.model="resnet50_vqvae" dataset._target_="bioimage_embed.datasets.FakeImageFolder" +# input_dim = [3, 224, 224] +# cfg = config.Config() +# # cfg.dataloader.dataset._target_ = "bioimage_embed.datasets.FakeImageFolder" +# # cfg.dataloader.dataset.image_size = input_dim +# cfg.recipe.model = "dummy_model" +# cfg.recipe.max_epochs = 1 # def test_cli(): @@ -114,11 +114,7 @@ def test_init_hydra_with_invalid_config_file(): @pytest.fixture def hydra_cfg(): with initialize(config_path="."): - # cfg = compose(config_name="config", overrides=[ - # 'dataloader.dataset._target_=bioimage_embed.datasets.FakeImageFolder' - # ]) cfg = compose(config_name="config") - cfg.dataloader.dataset._target_ = "bioimage_embed.datasets.FakeImageFolder" return cfg @@ -144,6 +140,5 @@ def test_train(cfg): cli.train(cfg) -@pytest.mark.skip("Computationally heavy") def test_check(cfg): cli.check(cfg) From f0bff4adbbf4bb0883bd56ad551e62060f15b2ef Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:25:33 +0100 Subject: [PATCH 12/22] [feat] adding jupytext --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 116fd368..384e79d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ torch = "^2.3.1" torchvision = "^0.18.1" wandb = "^0.17.4" jupytext = "^1.16.4" - +jupyter = "^1.0.0" [tool.poetry.group.dev.dependencies] From fdb29c78db2bca13dd97788ff2993a4fcac21ad9 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:45:58 +0100 
Subject: [PATCH 13/22] [feat] adding jupytext toml --- jupytext.toml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 jupytext.toml diff --git a/jupytext.toml b/jupytext.toml new file mode 100644 index 00000000..b64f6f25 --- /dev/null +++ b/jupytext.toml @@ -0,0 +1,3 @@ +[formats] +"notebooks/" = "ipynb" +"scripts/" = "py:percent" From bb72314ba9104a0411735da5d5ed1dc614b8186c Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Mon, 12 Aug 2024 14:46:17 +0100 Subject: [PATCH 14/22] [feat] adding autoencoder prose --- scripts/simple.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/scripts/simple.py b/scripts/simple.py index 4d97ed7f..af8b53b6 100644 --- a/scripts/simple.py +++ b/scripts/simple.py @@ -2,24 +2,45 @@ import bioimage_embed import bioimage_embed.config as config +# Import necessary modules from bioimage_embed and config. +# bioimage_embed is likely a library designed for embedding biological images, +# and config is used to handle configurations. + # %% from torchvision.datasets import FakeData from hydra.utils import instantiate +# Import FakeData from torchvision.datasets to create a fake dataset, +# and instantiate from hydra.utils to create instances based on configuration. # %% transform = instantiate(config.Transform()) -# # %% +# Instantiate a transformation using the configuration provided. +# This will likely include any data augmentation or preprocessing steps defined in the configuration. + +# %% dataset = FakeData( size=64, image_size=(3, 224, 224), num_classes=10, transform=transform, ) + +# Create a fake dataset with 64 images of size 224x224x3 (3 channels), and 10 classes. +# This dataset will be used to simulate data for testing purposes. The 'transform' argument applies the +# transformations defined earlier to the dataset. + +# NOTE: The 'dataset' must be a PyTorch Dataset object with X (data) and y (labels). +# If using an unsupervised encoder, set the labels (y) to None; the model will ignore them during training. + # dataset=CelebA(download=True, root="/tmp", split="train") +# The commented-out code suggests an alternative to use the CelebA dataset. +# It would download the CelebA dataset and use the training split, storing it in the '/tmp' directory. + # %% [markdown] +# # %% cfg = config.Config(dataset=dataset) @@ -27,6 +48,10 @@ cfg.recipe.max_epochs = 100 bie = bioimage_embed.BioImageEmbed(cfg) +# Create a configuration object 'cfg' using the config module, and assign the fake dataset to it. +# The model is set to "resnet18_vae" and the maximum number of epochs for training is set to 100. +# Instantiate the BioImageEmbed object 'bie' using the configuration. + # %% def process(): @@ -35,7 +60,17 @@ def process(): bie.export() +# Define a process function that performs three steps: +# 1. 'check()' to verify the setup or configuration. +# 2. 'train()' to start training the model. +# 3. 'export()' to export the trained model. + # %% -# This is the entrypoint for the script and very import if cfg.trainer.num_workers > 0 +# This is the entrypoint for the script and very important if cfg.trainer.num_workers > 0 if __name__ == "__main__": process() + +# This is the entry point for the script. The 'if __name__ == "__main__":' statement ensures that the 'process()' +# function is called only when the script is run directly, not when imported as a module. +# This is crucial if the 'num_workers' parameter is set in cfg.trainer, as it prevents potential issues +# with multiprocessing in PyTorch. 
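Aside on the prose added to scripts/simple.py in the patch above: it notes that the FakeData placeholder can be swapped for any PyTorch Dataset yielding (image, label) pairs. The following is a minimal sketch of that swap under two assumptions not present in the repository: the images are laid out for torchvision's standard ImageFolder class (root/<class>/<image>), and the "data/" path is purely illustrative.

# Hedged sketch: replacing the FakeData placeholder with a real image folder.
# The "data/" path and the folder layout are assumptions, not part of this repo.
from torchvision.datasets import ImageFolder
from hydra.utils import instantiate

import bioimage_embed
import bioimage_embed.config as config

transform = instantiate(config.Transform())  # same preprocessing the script instantiates
dataset = ImageFolder(root="data/", transform=transform)

cfg = config.Config(dataset=dataset)  # mirrors the pattern used in scripts/simple.py
cfg.recipe.model = "resnet18_vae"
bie = bioimage_embed.BioImageEmbed(cfg)

As in simple.py, training should still be launched behind an `if __name__ == "__main__":` guard when cfg.dataloader.num_workers > 0.
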
From f67b8150837478f3a53051a8ee3e7d3a1a325614 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 14 Aug 2024 18:07:21 +0100 Subject: [PATCH 15/22] [fix] removing all the funny loss,model_output logic --- bioimage_embed/lightning/torch.py | 49 +++++-------------------------- 1 file changed, 8 insertions(+), 41 deletions(-) diff --git a/bioimage_embed/lightning/torch.py b/bioimage_embed/lightning/torch.py index b95317ed..83e95845 100644 --- a/bioimage_embed/lightning/torch.py +++ b/bioimage_embed/lightning/torch.py @@ -94,7 +94,7 @@ def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor: "loss/train": model_output.loss, "mse/train": F.mse_loss(model_output.recon_x, model_output.data), "recon_loss/train": model_output.recon_loss, - "variational_loss/train": model_output.variational_loss, + "variational_loss/train": model_output.loss - model_output.recon_loss, }, # on_step=True, on_epoch=True, @@ -105,45 +105,14 @@ def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor: self.log_tensorboard(model_output, model_output.data) return model_output.loss - def loss_function(self, model_output, batch_idx, *args, **kwargs) -> ModelOutput: - """ - Function for overriding the loss function, should return a ModelOutput object. Preferably use super() to inherit the default loss function and then append the new loss. - """ - return model_output - - def _loss_function(self, model_output, batch_idx, *args, **kwargs) -> ModelOutput: - """ - Internal default loss function, should not be overridden. - This function calculates the variational loss. - """ - - model_output.variational_loss = model_output.loss - model_output.recon_loss - return model_output - - # def logging_step(self, z, loss, x, model_output, batch_idx): - # self.logger.experiment.add_embedding( - # z, - # label_img=x["data"], - # global_step=self.current_epoch, - # tag="z", - # ) - - # self.logger.experiment.add_scalar("Loss/val", loss, batch_idx) - # self.logger.experiment.add_image( - # "val", - # torchvision.utils.make_grid(model_output.recon_x), - # batch_idx, - # ) - def validation_step(self, batch, batch_idx): - # x, y = batch model_output = self.eval_step(batch, batch_idx) self.log_dict( { "loss/val": model_output.loss, "mse/val": F.mse_loss(model_output.recon_x, model_output.data), "recon_loss/val": model_output.recon_loss, - "variational_loss/val": model_output.variational_loss, + "variational_loss/val": model_output.loss - model_output.recon_loss, } ) return model_output.loss @@ -156,7 +125,7 @@ def test_step(self, batch, batch_idx): "loss/test": model_output.loss, "mse/test": F.mse_loss(model_output.recon_x, model_output.data), "recon_loss/test": model_output.recon_loss, - "variational_loss/test": model_output.variational_loss, + "variational_loss/test": model_output.loss - model_output.recon_loss, } ) return model_output.loss @@ -167,10 +136,10 @@ def batch_to_xy(self, batch): return x, y def eval_step(self, batch, batch_idx): + """ + This function should be overloaded in the child class to implement the evaluation logic. 
+ """ model_output = self.predict_step(batch, batch_idx) - # loss = model_output - model_output = self._loss_function(model_output, batch_idx) - model_output = self.loss_function(model_output, batch_idx) return model_output # def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): @@ -279,13 +248,11 @@ def create_label_based_pairs( class AutoEncoderSupervised(AutoEncoder): criteron = losses.ContrastiveLoss() - def loss_function(self, model_output, batch_idx): + def eval_step(self, batch, batch_idx): # x, y = batch - # loss = super().loss_function(model_output, batch_idx) - # TODO check this # Scale is used as the rest of the loss functions are sums rather than means, which may mean we need to scale up the contrastive loss - + model_output = self.predict_step(batch, batch_idx) scale = torch.prod(torch.tensor(model_output.z.shape[1:])) if model_output.target.unique().size(0) == 1: return model_output From fec2ee331c302d74d9e9977af1bae96fbba58735 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Wed, 14 Aug 2024 18:10:53 +0100 Subject: [PATCH 16/22] [fix] removing the loss_function stuff and replacing entirely with eval_step --- bioimage_embed/shapes/lightning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bioimage_embed/shapes/lightning.py b/bioimage_embed/shapes/lightning.py index 59bc5cbf..4e787a1a 100644 --- a/bioimage_embed/shapes/lightning.py +++ b/bioimage_embed/shapes/lightning.py @@ -31,7 +31,9 @@ def batch_to_tensor(self, batch): return output - def loss_function(self, model_output, *args, **kwargs): + def eval_step(self, batch, batch_idx): + # model_output = super().eval_step(batch, batch_idx) + model_output = self.predict_step(batch, batch_idx) loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=False) shape_loss = torch.sum( From 35820cf00d4eefd8f61178520a8c8b977969941f Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Thu, 15 Aug 2024 07:30:16 +0100 Subject: [PATCH 17/22] [bug] needs to be super() for eval --- bioimage_embed/shapes/lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bioimage_embed/shapes/lightning.py b/bioimage_embed/shapes/lightning.py index 4e787a1a..1237ea6f 100644 --- a/bioimage_embed/shapes/lightning.py +++ b/bioimage_embed/shapes/lightning.py @@ -32,8 +32,8 @@ def batch_to_tensor(self, batch): return output def eval_step(self, batch, batch_idx): - # model_output = super().eval_step(batch, batch_idx) - model_output = self.predict_step(batch, batch_idx) + # Needs to be super because eval_step is overwritten in Supervised + model_output = super().eval_step(batch, batch_idx) loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=False) shape_loss = torch.sum( From f1218a5beddcdd5f7c35984ebe5de6e03b6bb294 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Thu, 15 Aug 2024 07:30:38 +0100 Subject: [PATCH 18/22] [fix] improved cli tests --- bioimage_embed/tests/test_cli.py | 34 +++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/bioimage_embed/tests/test_cli.py b/bioimage_embed/tests/test_cli.py index 19fa8752..37e0f5c9 100644 --- a/bioimage_embed/tests/test_cli.py +++ b/bioimage_embed/tests/test_cli.py @@ -123,15 +123,35 @@ def model(): return "dummy_model" +@pytest.fixture +def cfg_recipe(model): + return config.Recipe(model=model) + + +@pytest.fixture +def cfg_trainer(): + return config.Trainer(max_epochs=1, max_steps=1, fast_dev_run=True) + + +@pytest.fixture +def cfg_dataloader(): + return 
config.DataLoader(num_workers=0) + + # TODO double check this is sensible @pytest.fixture -def cfg(model): - cfg = config.Config() - cfg.dataloader.num_workers = 0 # This avoids processes being forked - cfg.trainer.max_epochs = 1 - cfg.trainer.max_steps = 1 - cfg.trainer.fast_dev_run = True - cfg.recipe.model = model +def cfg(cfg_recipe, cfg_trainer, cfg_dataloader): + cfg = config.Config( + recipe=cfg_recipe, trainer=cfg_trainer, dataloader=cfg_dataloader + ) + return cfg + # This is an alternative way to create a config object but it is less flexible and if the config object is changed in the future, this will break, i.e validation is not guaranteed + + # cfg.dataloader.num_workers = 0 # This avoids processes being forked + # cfg.trainer.max_epochs = 1 + # cfg.trainer.max_steps = 1 + # cfg.trainer.fast_dev_run = True + # cfg.recipe.model = model return cfg From ff4a106be5846b865ad55b0a9bcad6298c30ccdb Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Thu, 15 Aug 2024 07:31:44 +0100 Subject: [PATCH 19/22] [fix] variational_loss is handled per *_step now, might rethink later --- bioimage_embed/lightning/torch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bioimage_embed/lightning/torch.py b/bioimage_embed/lightning/torch.py index 83e95845..52dbee37 100644 --- a/bioimage_embed/lightning/torch.py +++ b/bioimage_embed/lightning/torch.py @@ -130,8 +130,10 @@ def test_step(self, batch, batch_idx): ) return model_output.loss - # Fangless function to be overloaded later def batch_to_xy(self, batch): + """ + Fangless function to be overloaded later + """ x, y = batch return x, y @@ -139,8 +141,7 @@ def eval_step(self, batch, batch_idx): """ This function should be overloaded in the child class to implement the evaluation logic. 
""" - model_output = self.predict_step(batch, batch_idx) - return model_output + return self.predict_step(batch, batch_idx) # def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): # # Implement your own logic for updating the lr scheduler From 12764b09d55d108e9bbbce9c4e808360bf4b2851 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Thu, 15 Aug 2024 07:40:12 +0100 Subject: [PATCH 20/22] [fix] extra cfg --- bioimage_embed/tests/test_cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bioimage_embed/tests/test_cli.py b/bioimage_embed/tests/test_cli.py index 37e0f5c9..a2dabc78 100644 --- a/bioimage_embed/tests/test_cli.py +++ b/bioimage_embed/tests/test_cli.py @@ -152,7 +152,6 @@ def cfg(cfg_recipe, cfg_trainer, cfg_dataloader): # cfg.trainer.max_steps = 1 # cfg.trainer.fast_dev_run = True # cfg.recipe.model = model - return cfg # @pytest.mark.skip("Computationally heavy") From b92991e7c1caaa0ac55c6e4cc2c4cc19f6992a83 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Thu, 15 Aug 2024 07:45:55 +0100 Subject: [PATCH 21/22] [feat] adding model_predict_test --- bioimage_embed/shapes/tests/test_lightning.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bioimage_embed/shapes/tests/test_lightning.py b/bioimage_embed/shapes/tests/test_lightning.py index 2e9f58b0..36b49681 100644 --- a/bioimage_embed/shapes/tests/test_lightning.py +++ b/bioimage_embed/shapes/tests/test_lightning.py @@ -79,3 +79,9 @@ def test_model(trainer, lit_model, dataloader): def test_model_fit(trainer, lit_model, dataloader): return trainer.fit(lit_model, dataloader) + + +def test_model_predict(trainer, lit_model, dataloader): + y = trainer.predict(lit_model, dataloader) + # TODO Add checks for shape_loss and potentially other losses (contrastive loss) + return y From 5f4f4493687422d8662720c58a318cf7fa809bb3 Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Thu, 15 Aug 2024 08:21:38 +0100 Subject: [PATCH 22/22] [fix] removing expensive function --- bioimage_embed/tests/test_lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioimage_embed/tests/test_lightning.py b/bioimage_embed/tests/test_lightning.py index d3c430af..37ccae5c 100644 --- a/bioimage_embed/tests/test_lightning.py +++ b/bioimage_embed/tests/test_lightning.py @@ -168,7 +168,7 @@ def test_trainer_dummy_model_fit(trainer, lit_dummy_model, datamodule): return trainer.fit(lit_dummy_model, datamodule) -# @pytest.mark.skip(reason="Expensive") +@pytest.mark.skip(reason="Expensive") def test_trainer_fit(trainer, lit_model, datamodule): return trainer.fit(lit_model, datamodule)