Bie fix #60

Merged: 22 commits, Aug 15, 2024
5 changes: 3 additions & 2 deletions bioimage_embed/augmentations.py
@@ -32,6 +32,7 @@
),
# Adjust image intensity with a specified range for individual channels
A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
A.ToFloat(),
ToTensorV2(),
]

@@ -49,8 +50,8 @@ def __call__(self, image):
img = np.array(image)
transformed = self.transform(image=img)
return transformed["image"]
except:
return None,None
except Exception:
return None, None


class VisionWrapperSupervised:
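
Editor's note: the new `A.ToFloat()` step converts uint8 images to float32 before `ToTensorV2`, so the wrapped transform yields float tensors rather than byte tensors. A minimal sketch of the fixed pipeline, with the transform list abbreviated and a synthetic image standing in for real data:

```python
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Abbreviated version of the augmentation stack after this PR:
# ToFloat() rescales uint8 pixels to float32 in [0, 1] before tensor conversion.
transform = A.Compose(
    [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.ToFloat(),
        ToTensorV2(),
    ]
)

image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
tensor = transform(image=image)["image"]  # torch.float32, shape (3, 64, 64)
```
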
1 change: 1 addition & 0 deletions bioimage_embed/bie.py
@@ -162,3 +162,4 @@ def export(self):
def check(self):
self.model_check()
self.trainer_check()
self.trainer_check_fit()
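
Editor's note: `check()` now also exercises fitting via `trainer_check_fit()`. Its body is not part of this diff, so the following is a purely hypothetical sketch of the usual pattern for such a smoke test (a single-batch run with Lightning's `fast_dev_run`):

```python
import pytorch_lightning as pl

# Hypothetical sketch only: the real trainer_check_fit in bie.py is not shown here.
def trainer_check_fit(model, datamodule):
    trainer = pl.Trainer(fast_dev_run=True)  # runs one train/val batch, disables checkpointing
    trainer.fit(model, datamodule=datamodule)
```
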
19 changes: 9 additions & 10 deletions bioimage_embed/config.py
@@ -19,7 +19,7 @@
from pydantic.dataclasses import dataclass
from typing import List, Optional, Dict, Any

from pydantic import Field, root_validator
from pydantic import Field
from omegaconf import II
from . import utils

@@ -86,6 +86,12 @@ class Dataset:
_target_: str = "torch.utils.data.Dataset"
transform: Any = Field(default_factory=Transform)

# TODO add validation for transform to be floats
# @model_validator(mode="after")
# def validate(self):
# dataset = instantiate(self)
# return self


@dataclass
class FakeDataset(Dataset):
@@ -119,7 +125,6 @@ class DataLoader:
dataset: Any = Field(default_factory=FakeDataset)
num_workers: int = 1
batch_size: int = II("recipe.batch_size")
collate_fn: Any = None


@dataclass
@@ -202,15 +207,9 @@ class Paths:
tensorboard: str = "tensorboard"
wandb: str = "wandb"

@root_validator(
pre=False, skip_on_failure=True
) # Ensures this runs after all other validations
@classmethod
def create_dirs(cls, values):
# The `values` dict contains all the validated field values
for path in values.values():
def __post_init__(self):
for path in self.__dict__.values():
os.makedirs(path, exist_ok=True)
return values


@dataclass
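
Editor's note: `Paths` now creates its directories in a plain `__post_init__` instead of a pydantic `root_validator` (hence the removed import). A minimal sketch of the new behaviour, using a standard dataclass and illustrative field names (the repo itself uses `pydantic.dataclasses.dataclass`):

```python
import os
from dataclasses import dataclass

@dataclass
class Paths:
    model: str = "models"
    logs: str = "logs"

    def __post_init__(self):
        # Create every configured directory as soon as the config object is built.
        for path in self.__dict__.values():
            os.makedirs(path, exist_ok=True)

Paths()  # creates ./models and ./logs if they do not already exist
```
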
38 changes: 22 additions & 16 deletions bioimage_embed/lightning/dataloader.py
@@ -5,6 +5,26 @@
from functools import partial


# https://stackoverflow.com/questions/74931838/cant-pickle-local-object-evaluationloop-advance-locals-batch-to-device-pyto
class Collator:
def collate_filter_for_none(self, batch):
"""
Collate function that filters out None values from the batch.

Args:
batch: The batch to be filtered.

Returns:
The filtered batch.
"""
batch = list(filter(lambda x: x is not None, batch))
return torch.utils.data.dataloader.default_collate(batch)

def __call__(self, incoming):
# do stuff with incoming
return self.collate_filter_for_none(incoming)


class DataModule(pl.LightningDataModule):
"""
A PyTorch Lightning DataModule for handling dataset loading and splitting.
@@ -25,7 +45,6 @@ def __init__(
num_workers: int = 4,
pin_memory: bool = False,
drop_last: bool = False,
collate_fn=None,
):
"""
Initializes the DataModule with the given dataset and parameters.
@@ -40,34 +59,21 @@
"""
super().__init__()
self.dataset = dataset
collate_fn = collate_fn if collate_fn else self.collate_filter_for_none
self.collator = Collator()
self.dataloader = partial(
DataLoader,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
drop_last=drop_last,
collate_fn=collate_fn,
collate_fn=self.collator,
)

self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
self.setup()

def collate_filter_for_none(self, batch):
"""
Collate function that filters out None values from the batch.

Args:
batch: The batch to be filtered.

Returns:
The filtered batch.
"""
batch = list(filter(lambda x: x is not None, batch))
return torch.utils.data.dataloader.default_collate(batch)

def setup(self, stage=None):
"""
Sets up the datasets by splitting the main dataset into train, validation, and test sets.
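
Editor's note: the collate function moves from a bound `DataModule` method into a small module-level `Collator` class (see the Stack Overflow link in the diff): a top-level callable can be pickled, so it works with `num_workers > 0` on spawn-based platforms, and it drops the `None` entries produced when `VisionWrapper`'s transform fails. A usage sketch with a toy dataset; the dataset and names here are illustrative only:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class Collator:
    """Drop None samples, then fall back to the default collate."""
    def __call__(self, batch):
        batch = [item for item in batch if item is not None]
        return torch.utils.data.dataloader.default_collate(batch)

class FlakyDataset(Dataset):
    """Toy dataset returning None for every third item, mimicking a failed transform."""
    def __len__(self):
        return 9
    def __getitem__(self, idx):
        return None if idx % 3 == 0 else torch.randn(3, 8, 8)

if __name__ == "__main__":
    # Because Collator is defined at module level, it pickles cleanly for worker processes.
    loader = DataLoader(FlakyDataset(), batch_size=4, num_workers=2, collate_fn=Collator())
    shapes = [batch.shape for batch in loader]  # batches shrink where Nones were dropped
    print(shapes)
```
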
73 changes: 28 additions & 45 deletions bioimage_embed/lightning/torch.py
@@ -88,11 +88,13 @@ def embedding(self, model_output: ModelOutput) -> torch.Tensor:

def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
self.model.train()
loss, model_output = self.eval_step(batch, batch_idx)
model_output = self.eval_step(batch, batch_idx)
self.log_dict(
{
"loss/train": loss,
"loss/train": model_output.loss,
"mse/train": F.mse_loss(model_output.recon_x, model_output.data),
"recon_loss/train": model_output.recon_loss,
"variational_loss/train": model_output.loss - model_output.recon_loss,
},
# on_step=True,
on_epoch=True,
@@ -101,63 +103,45 @@ def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
)
if isinstance(self.logger, pl.loggers.TensorBoardLogger):
self.log_tensorboard(model_output, model_output.data)
return loss

def loss_function(
self, model_output: ModelOutput, batch_idx: int, *args, **kwargs
) -> dict:
return {
"loss": model_output.loss,
"recon_loss": model_output.recon_loss,
"variational_loss": model_output.loss - model_output.recon_loss,
}

# def logging_step(self, z, loss, x, model_output, batch_idx):
# self.logger.experiment.add_embedding(
# z,
# label_img=x["data"],
# global_step=self.current_epoch,
# tag="z",
# )

# self.logger.experiment.add_scalar("Loss/val", loss, batch_idx)
# self.logger.experiment.add_image(
# "val",
# torchvision.utils.make_grid(model_output.recon_x),
# batch_idx,
# )
return model_output.loss

def validation_step(self, batch, batch_idx):
# x, y = batch
loss, model_output = self.eval_step(batch, batch_idx)
model_output = self.eval_step(batch, batch_idx)
self.log_dict(
{
"loss/val": loss,
"loss/val": model_output.loss,
"mse/val": F.mse_loss(model_output.recon_x, model_output.data),
"recon_loss/val": model_output.recon_loss,
"variational_loss/val": model_output.loss - model_output.recon_loss,
}
)
return loss
return model_output.loss

def test_step(self, batch, batch_idx):
# x, y = batch
loss, model_output = self.eval_step(batch, batch_idx)
model_output = self.eval_step(batch, batch_idx)
self.log_dict(
{
"loss/test": loss,
"loss/test": model_output.loss,
"mse/test": F.mse_loss(model_output.recon_x, model_output.data),
"recon_loss/test": model_output.recon_loss,
"variational_loss/test": model_output.loss - model_output.recon_loss,
}
)
return loss
return model_output.loss

# Fangless function to be overloaded later
def batch_to_xy(self, batch):
"""
Fangless function to be overloaded later
"""
x, y = batch
return x, y

def eval_step(self, batch, batch_idx):
model_output = self.predict_step(batch, batch_idx)
loss = self.loss_function(model_output, batch_idx)
return loss, model_output
"""
This function should be overloaded in the child class to implement the evaluation logic.
"""
return self.predict_step(batch, batch_idx)

# def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
# # Implement your own logic for updating the lr scheduler
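
Editor's note: `eval_step` now returns the full `ModelOutput` instead of a `(loss, output)` tuple, and the training/validation/test steps read `model_output.loss` and its components directly. Subclasses add extra terms by mutating the `ModelOutput`, as the supervised and shape models in this PR do. A hypothetical subclass illustrating that contract — the class name, L1 penalty, and import path are assumptions, not part of the PR:

```python
# Assumes AutoEncoder is importable from the module shown in this diff.
from bioimage_embed.lightning.torch import AutoEncoder

class L1RegularisedAutoEncoder(AutoEncoder):
    """Hypothetical example of the new eval_step contract."""
    def eval_step(self, batch, batch_idx):
        model_output = self.predict_step(batch, batch_idx)
        l1_penalty = model_output.z.abs().mean()
        model_output.l1_loss = l1_penalty                  # exposed for logging/inspection
        model_output.loss = model_output.loss + l1_penalty  # folded into the total loss
        return model_output
```
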
@@ -265,20 +249,19 @@ def create_label_based_pairs(
class AutoEncoderSupervised(AutoEncoder):
criteron = losses.ContrastiveLoss()

def loss_function(self, model_output, batch_idx):
def eval_step(self, batch, batch_idx):
# x, y = batch
loss = super().loss_function(model_output, batch_idx)
# TODO check this
# Scale is used as the rest of the loss functions are sums rather than means, which may mean we need to scale up the contrastive loss

model_output = self.predict_step(batch, batch_idx)
scale = torch.prod(torch.tensor(model_output.z.shape[1:]))
if model_output.target.unique().size(0) == 1:
return loss
return model_output
pairs = create_label_based_pairs(model_output.z.squeeze(), model_output.target)
contrastive_loss = self.criteron(*pairs)
loss["contrastive_loss"] = scale * contrastive_loss
loss["loss"] += loss["contrastive_loss"]
return loss
model_output.contrastive_loss = scale * contrastive_loss
model_output.loss += model_output.contrastive_loss
return model_output


class AESupervised(AutoEncoderSupervised):
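
Editor's note: in `AutoEncoderSupervised.eval_step`, the contrastive term is scaled by the number of latent elements because the autoencoder losses are sums over elements while the contrastive criterion returns a mean. A small numeric illustration of that scaling (the values are made up):

```python
import torch

# Why the contrastive term is multiplied by prod(z.shape[1:]):
# sum-based losses grow with latent size, a mean-based loss does not.
z = torch.randn(16, 64, 1, 1)                     # batch of 16 latent codes
scale = torch.prod(torch.tensor(z.shape[1:]))     # 64 * 1 * 1 = 64
contrastive_loss = torch.tensor(0.25)             # stand-in for criteron(*pairs)
scaled = scale * contrastive_loss                 # contribution added to model_output.loss
print(scaled)                                     # tensor(16.)
```
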
17 changes: 6 additions & 11 deletions bioimage_embed/shapes/lightning.py
@@ -31,9 +31,10 @@ def batch_to_tensor(self, batch):

return output

def loss_function(self, model_output, *args, **kwargs):
def eval_step(self, batch, batch_idx):
# Needs to be super because eval_step is overwritten in Supervised
model_output = super().eval_step(batch, batch_idx)
loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=False)
loss = model_output.loss

shape_loss = torch.sum(
torch.stack(
@@ -46,21 +47,15 @@ def loss_function(self, model_output, *args, **kwargs):
]
)
)
loss += shape_loss
model_output.loss += shape_loss
model_output.shape_loss = shape_loss

# loss += lf.diagonal_loss(model_output.recon_x)
# loss += lf.symmetry_loss(model_output.recon_x)
# loss += lf.triangle_inequality_loss(model_output.recon_x)
# loss += lf.non_negative_loss(model_output.recon_x)

variational_loss = model_output.loss - model_output.recon_loss

return {
"loss": loss,
"shape_loss": shape_loss,
"reconstruction_loss": model_output.recon_loss,
"variational_loss": variational_loss,
}
return model_output


class MaskEmbed(MaskEmbedMixin, AutoEncoderUnsupervised):
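
Editor's note: the shape model now extends `eval_step` rather than `loss_function`, calling `super().eval_step` (so the supervised contrastive term is preserved), adding the distance-matrix penalty to `model_output.loss`, and recording it as `model_output.shape_loss`. The commented-out terms (diagonal, symmetry, triangle inequality, non-negativity) could follow the same pattern; a hedged sketch with a hypothetical helper name:

```python
# Hypothetical helper, not in the PR: extra geometric penalties would use the
# same accumulate-and-attach pattern as shape_loss above.
def add_penalty(model_output, name, value):
    model_output.loss = model_output.loss + value
    setattr(model_output, name, value)   # e.g. model_output.symmetry_loss
    return model_output

# e.g. inside eval_step, after shape_loss:
# add_penalty(model_output, "symmetry_loss", lf.symmetry_loss(model_output.recon_x))
```
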
6 changes: 6 additions & 0 deletions bioimage_embed/shapes/tests/test_lightning.py
@@ -79,3 +79,9 @@ def test_model(trainer, lit_model, dataloader):

def test_model_fit(trainer, lit_model, dataloader):
return trainer.fit(lit_model, dataloader)


def test_model_predict(trainer, lit_model, dataloader):
y = trainer.predict(lit_model, dataloader)
# TODO Add checks for shape_loss and potentially other losses (contrastive loss)
return y
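
Editor's note: one way the TODO above could be addressed is to assert on loss components attached to each predicted output. This sketch assumes `trainer.predict` returns the per-batch `ModelOutput` objects and that `.loss` is populated on them; the extra attributes are guarded since they only exist for the shape and supervised models:

```python
import torch

def test_model_predict_losses(trainer, lit_model, dataloader):
    outputs = trainer.predict(lit_model, dataloader)
    for model_output in outputs:
        assert torch.isfinite(model_output.loss)
        if hasattr(model_output, "shape_loss"):
            assert model_output.shape_loss >= 0
        if hasattr(model_output, "contrastive_loss"):
            assert model_output.contrastive_loss >= 0
```
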