
Merge pull request #60 from ctr26/bie_fix
Bie fix
ctr26 authored Aug 15, 2024
2 parents 66a174f + 5f4f449 commit 6974ac2
Showing 12 changed files with 224 additions and 121 deletions.
5 changes: 3 additions & 2 deletions bioimage_embed/augmentations.py
@@ -32,6 +32,7 @@
         ),
         # Adjust image intensity with a specified range for individual channels
         A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
+        A.ToFloat(),
         ToTensorV2(),
     ]

@@ -49,8 +50,8 @@ def __call__(self, image):
            img = np.array(image)
            transformed = self.transform(image=img)
            return transformed["image"]
-        except:
-            return None,None
+        except Exception:
+            return None, None


 class VisionWrapperSupervised:
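A note on the A.ToFloat() addition: ToTensorV2 converts the array to a tensor but does not rescale its dtype, so without ToFloat() a uint8 image reaches the model as 0-255 floats. A minimal stand-alone sketch of the fixed tail of the pipeline (synthetic image; albumentations with the pytorch extra assumed):

    import albumentations as A
    import numpy as np
    from albumentations.pytorch import ToTensorV2

    transform = A.Compose(
        [
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.ToFloat(),   # uint8 [0, 255] -> float32 [0.0, 1.0]
            ToTensorV2(),  # HWC numpy array -> CHW torch.Tensor
        ]
    )
    img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    out = transform(image=img)["image"]  # float32 tensor, shape (3, 64, 64)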
1 change: 1 addition & 0 deletions bioimage_embed/bie.py
@@ -162,3 +162,4 @@ def export(self):
     def check(self):
         self.model_check()
         self.trainer_check()
+        self.trainer_check_fit()
19 changes: 9 additions & 10 deletions bioimage_embed/config.py
@@ -19,7 +19,7 @@
 from pydantic.dataclasses import dataclass
 from typing import List, Optional, Dict, Any

-from pydantic import Field, root_validator
+from pydantic import Field
 from omegaconf import II
 from . import utils

@@ -86,6 +86,12 @@ class Dataset:
     _target_: str = "torch.utils.data.Dataset"
     transform: Any = Field(default_factory=Transform)

+    # TODO add validation for transform to be floats
+    # @model_validator(mode="after")
+    # def validate(self):
+    #     dataset = instantiate(self)
+    #     return self
+

 @dataclass
 class FakeDataset(Dataset):
@@ -119,7 +125,6 @@ class DataLoader:
     dataset: Any = Field(default_factory=FakeDataset)
     num_workers: int = 1
     batch_size: int = II("recipe.batch_size")
-    collate_fn: Any = None


 @dataclass
@@ -202,15 +207,9 @@ class Paths:
     tensorboard: str = "tensorboard"
     wandb: str = "wandb"

-    @root_validator(
-        pre=False, skip_on_failure=True
-    )  # Ensures this runs after all other validations
-    @classmethod
-    def create_dirs(cls, values):
-        # The `values` dict contains all the validated field values
-        for path in values.values():
+    def __post_init__(self):
+        for path in self.__dict__.values():
             os.makedirs(path, exist_ok=True)
-        return values


 @dataclass
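The removed root_validator ran over pydantic's validated values dict; the replacement leans on dataclass __post_init__ semantics instead, so the directories are created as a side effect of constructing Paths. A rough sketch of the behaviour with a plain stdlib dataclass (field set abbreviated to the two shown in the hunk):

    import os
    from dataclasses import dataclass

    @dataclass
    class Paths:
        tensorboard: str = "tensorboard"
        wandb: str = "wandb"

        def __post_init__(self):
            # Every field is treated as a directory path.
            for path in self.__dict__.values():
                os.makedirs(path, exist_ok=True)

    Paths()  # creates ./tensorboard and ./wandb if they do not exist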
38 changes: 22 additions & 16 deletions bioimage_embed/lightning/dataloader.py
@@ -5,6 +5,26 @@
 from functools import partial


+# https://stackoverflow.com/questions/74931838/cant-pickle-local-object-evaluationloop-advance-locals-batch-to-device-pyto
+class Collator:
+    def collate_filter_for_none(self, batch):
+        """
+        Collate function that filters out None values from the batch.
+        Args:
+            batch: The batch to be filtered.
+        Returns:
+            The filtered batch.
+        """
+        batch = list(filter(lambda x: x is not None, batch))
+        return torch.utils.data.dataloader.default_collate(batch)
+
+    def __call__(self, incoming):
+        # do stuff with incoming
+        return self.collate_filter_for_none(incoming)
+
+
 class DataModule(pl.LightningDataModule):
     """
     A PyTorch Lightning DataModule for handling dataset loading and splitting.
@@ -25,7 +45,6 @@ def __init__(
         num_workers: int = 4,
         pin_memory: bool = False,
         drop_last: bool = False,
-        collate_fn=None,
     ):
         """
         Initializes the DataModule with the given dataset and parameters.
@@ -40,34 +59,21 @@
         """
         super().__init__()
         self.dataset = dataset
-        collate_fn = collate_fn if collate_fn else self.collate_filter_for_none
+        self.collator = Collator()
         self.dataloader = partial(
             DataLoader,
             batch_size=batch_size,
             num_workers=num_workers,
             pin_memory=pin_memory,
             drop_last=drop_last,
-            collate_fn=collate_fn,
+            collate_fn=self.collator,
         )

         self.train_dataset = None
         self.val_dataset = None
         self.test_dataset = None
         self.setup()

-    def collate_filter_for_none(self, batch):
-        """
-        Collate function that filters out None values from the batch.
-        Args:
-            batch: The batch to be filtered.
-        Returns:
-            The filtered batch.
-        """
-        batch = list(filter(lambda x: x is not None, batch))
-        return torch.utils.data.dataloader.default_collate(batch)
-
     def setup(self, stage=None):
         """
         Sets up the datasets by splitting the main dataset into train, validation, and test sets.
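The point of hoisting the collate function into a module-level Collator class (per the Stack Overflow link in the hunk above) is picklability: with num_workers > 0 the DataLoader pickles its collate_fn for the worker processes, and a lambda or a bound method of the DataModule can fail with "Can't pickle local object". A small sketch of the property being relied on:

    import pickle

    import torch
    from torch.utils.data.dataloader import default_collate

    class Collator:
        """Picklable because the class is defined at module scope."""

        def __call__(self, batch):
            # Drop samples the dataset wrapper returned as None, then
            # defer to PyTorch's default collation.
            batch = [item for item in batch if item is not None]
            return default_collate(batch)

    collate = pickle.loads(pickle.dumps(Collator()))  # round-trips cleanly
    print(collate([torch.zeros(2), None, torch.ones(2)]).shape)  # torch.Size([2, 2])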
73 changes: 28 additions & 45 deletions bioimage_embed/lightning/torch.py
@@ -88,11 +88,13 @@ def embedding(self, model_output: ModelOutput) -> torch.Tensor:

     def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
         self.model.train()
-        loss, model_output = self.eval_step(batch, batch_idx)
+        model_output = self.eval_step(batch, batch_idx)
         self.log_dict(
             {
-                "loss/train": loss,
+                "loss/train": model_output.loss,
                 "mse/train": F.mse_loss(model_output.recon_x, model_output.data),
+                "recon_loss/train": model_output.recon_loss,
+                "variational_loss/train": model_output.loss - model_output.recon_loss,
             },
             # on_step=True,
             on_epoch=True,
@@ -101,63 +103,45 @@ def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
         )
         if isinstance(self.logger, pl.loggers.TensorBoardLogger):
             self.log_tensorboard(model_output, model_output.data)
-        return loss
-
-    def loss_function(
-        self, model_output: ModelOutput, batch_idx: int, *args, **kwargs
-    ) -> dict:
-        return {
-            "loss": model_output.loss,
-            "recon_loss": model_output.recon_loss,
-            "variational_loss": model_output.loss - model_output.recon_loss,
-        }
-
-    # def logging_step(self, z, loss, x, model_output, batch_idx):
-    #     self.logger.experiment.add_embedding(
-    #         z,
-    #         label_img=x["data"],
-    #         global_step=self.current_epoch,
-    #         tag="z",
-    #     )
-
-    #     self.logger.experiment.add_scalar("Loss/val", loss, batch_idx)
-    #     self.logger.experiment.add_image(
-    #         "val",
-    #         torchvision.utils.make_grid(model_output.recon_x),
-    #         batch_idx,
-    #     )
+        return model_output.loss

     def validation_step(self, batch, batch_idx):
         # x, y = batch
-        loss, model_output = self.eval_step(batch, batch_idx)
+        model_output = self.eval_step(batch, batch_idx)
         self.log_dict(
             {
-                "loss/val": loss,
+                "loss/val": model_output.loss,
                 "mse/val": F.mse_loss(model_output.recon_x, model_output.data),
+                "recon_loss/val": model_output.recon_loss,
+                "variational_loss/val": model_output.loss - model_output.recon_loss,
             }
         )
-        return loss
+        return model_output.loss

     def test_step(self, batch, batch_idx):
         # x, y = batch
-        loss, model_output = self.eval_step(batch, batch_idx)
+        model_output = self.eval_step(batch, batch_idx)
         self.log_dict(
             {
-                "loss/test": loss,
+                "loss/test": model_output.loss,
                 "mse/test": F.mse_loss(model_output.recon_x, model_output.data),
+                "recon_loss/test": model_output.recon_loss,
+                "variational_loss/test": model_output.loss - model_output.recon_loss,
             }
         )
-        return loss
+        return model_output.loss

-    # Fangless function to be overloaded later
     def batch_to_xy(self, batch):
+        """
+        Fangless function to be overloaded later
+        """
         x, y = batch
         return x, y

     def eval_step(self, batch, batch_idx):
-        model_output = self.predict_step(batch, batch_idx)
-        loss = self.loss_function(model_output, batch_idx)
-        return loss, model_output
+        """
+        This function should be overloaded in the child class to implement the evaluation logic.
+        """
+        return self.predict_step(batch, batch_idx)

     # def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
     #     # Implement your own logic for updating the lr scheduler
@@ -265,20 +249,19 @@ def create_label_based_pairs(
 class AutoEncoderSupervised(AutoEncoder):
     criteron = losses.ContrastiveLoss()

-    def loss_function(self, model_output, batch_idx):
-        loss = super().loss_function(model_output, batch_idx)
+    def eval_step(self, batch, batch_idx):
+        # x, y = batch
         # TODO check this
         # Scale is used as the rest of the loss functions are sums rather than means, which may mean we need to scale up the contrastive loss
-
+        model_output = self.predict_step(batch, batch_idx)
         scale = torch.prod(torch.tensor(model_output.z.shape[1:]))
         if model_output.target.unique().size(0) == 1:
-            return loss
+            return model_output
         pairs = create_label_based_pairs(model_output.z.squeeze(), model_output.target)
         contrastive_loss = self.criteron(*pairs)
-        loss["contrastive_loss"] = scale * contrastive_loss
-        loss["loss"] += loss["contrastive_loss"]
-        return loss
+        model_output.contrastive_loss = scale * contrastive_loss
+        model_output.loss += model_output.contrastive_loss
+        return model_output


 class AESupervised(AutoEncoderSupervised):
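The net effect of the refactor in this file: eval_step now returns the ModelOutput itself instead of a (loss_dict, model_output) pair, and every loss term lives as an attribute on that object. A subclass that wants an extra penalty follows one pattern, sketched here with a hypothetical regulariser (RegularisedAE and the L2 penalty are illustrative, not part of this commit):

    class RegularisedAE(AutoEncoder):
        def eval_step(self, batch, batch_idx):
            # The base class runs the forward pass and returns a ModelOutput
            # with .loss, .recon_loss, .z and .recon_x populated.
            model_output = super().eval_step(batch, batch_idx)
            penalty = model_output.z.pow(2).mean()  # assumed L2 penalty on latents
            model_output.reg_loss = penalty
            model_output.loss = model_output.loss + penalty
            return model_output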
17 changes: 6 additions & 11 deletions bioimage_embed/shapes/lightning.py
@@ -31,9 +31,10 @@ def batch_to_tensor(self, batch):

        return output

-    def loss_function(self, model_output, *args, **kwargs):
+    def eval_step(self, batch, batch_idx):
+        # Needs to be super because eval_step is overwritten in Supervised
+        model_output = super().eval_step(batch, batch_idx)
         loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=False)
-        loss = model_output.loss

         shape_loss = torch.sum(
             torch.stack(
@@ -46,21 +47,15 @@
                ]
            )
        )
-        loss += shape_loss
+        model_output.loss += shape_loss
+        model_output.shape_loss = shape_loss

         # loss += lf.diagonal_loss(model_output.recon_x)
         # loss += lf.symmetry_loss(model_output.recon_x)
         # loss += lf.triangle_inequality_loss(model_output.recon_x)
         # loss += lf.non_negative_loss(model_output.recon_x)

-        variational_loss = model_output.loss - model_output.recon_loss
-
-        return {
-            "loss": loss,
-            "shape_loss": shape_loss,
-            "reconstruction_loss": model_output.recon_loss,
-            "variational_loss": variational_loss,
-        }
+        return model_output


 class MaskEmbed(MaskEmbedMixin, AutoEncoderUnsupervised):
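Same pattern again for the shape penalties: the stacked terms (collapsed behind the hunk header above) are summed into one shape_loss, which is both folded into model_output.loss and kept as its own attribute so it stays inspectable downstream. Schematically, with placeholder tensors standing in for the real loss_ops terms:

    import torch
    from types import SimpleNamespace

    model_output = SimpleNamespace(loss=torch.tensor(1.0))  # stand-in ModelOutput
    penalties = [torch.tensor(0.1), torch.tensor(0.2)]      # placeholder terms

    shape_loss = torch.sum(torch.stack(penalties))
    model_output.loss = model_output.loss + shape_loss  # drives optimisation
    model_output.shape_loss = shape_loss                # separately inspectable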
6 changes: 6 additions & 0 deletions bioimage_embed/shapes/tests/test_lightning.py
@@ -79,3 +79,9 @@ def test_model(trainer, lit_model, dataloader):

 def test_model_fit(trainer, lit_model, dataloader):
     return trainer.fit(lit_model, dataloader)
+
+
+def test_model_predict(trainer, lit_model, dataloader):
+    y = trainer.predict(lit_model, dataloader)
+    # TODO Add checks for shape_loss and potentially other losses (contrastive loss)
+    return y
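For reference, Trainer.predict returns a list with one entry per batch, whatever predict_step produced. The TODO might therefore be resolved along these lines (a sketch using the same fixtures; the attribute checks are assumptions about what predict_step returns):

    def test_model_predict_outputs(trainer, lit_model, dataloader):
        predictions = trainer.predict(lit_model, dataloader)
        assert predictions, "predict should yield at least one batch output"
        # Possible per-loss checks, depending on the returned ModelOutput:
        # assert all(hasattr(p, "recon_x") for p in predictions)
        # assert all(torch.isfinite(p.loss) for p in predictions)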