Merge pull request #59 from ctr26/dev
Fixing mask_embed for latest changes?
ctr26 authored Aug 14, 2024
2 parents 74d4599 + d856741 commit 66a174f
Showing 7 changed files with 172 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -9,4 +9,4 @@ download.data:
 	kaggle competitions download -c data-science-bowl-2018
 
 test:
-	poetry run pytest -v
+	poetry run pytest -v --tb=no
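
Note: the --tb=no flag suppresses pytest's traceback output on failures, so the test target prints only the concise pass/fail summary.
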
34 changes: 24 additions & 10 deletions bioimage_embed/lightning/torch.py
@@ -59,14 +59,24 @@ def __init__(self, model, args=SimpleNamespace()):
         # self.example_input_array = torch.randn(1, *self.model.input_dim)
         # self.model.train()
 
-    def forward(self, x):
-        batch = ModelOutput(data=x.float())
-        return self.model(batch)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0) -> ModelOutput:
-        return self.batch_to_tensor(batch, batch_idx)
-
-    def batch_to_tensor(self, batch, batch_idx) -> ModelOutput:
+    def forward(self, x: torch.Tensor) -> ModelOutput:
+        """
+        Forward pass of the model.
+        Pythae models take in and return ModelOutput objects, so multiple
+        tensors can be passed in and returned in a single call.
+        """
+        return self.model(ModelOutput(data=x.float()))
+
+    def predict_step(
+        self, batch: tuple, batch_idx: int, dataloader_idx=0
+    ) -> ModelOutput:
+        return self.batch_to_tensor(batch)
+
+    def batch_to_tensor(self, batch) -> ModelOutput:
+        """
+        Takes in a batch and returns a ModelOutput object.
+        Lightning batches are (x, y) pairs of tensors, but the model only
+        needs the x tensor, which is fed into self.forward.
+        """
         x, y = self.batch_to_xy(batch)
         model_output = self.forward(x)
         model_output.data = x
@@ -76,7 +86,7 @@ def batch_to_tensor(self, batch, batch_idx) -> ModelOutput:
     def embedding(self, model_output: ModelOutput) -> torch.Tensor:
         return model_output.z.view(model_output.z.shape[0], -1)
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
         self.model.train()
         loss, model_output = self.eval_step(batch, batch_idx)
         self.log_dict(
@@ -93,7 +103,9 @@ def training_step(self, batch, batch_idx):
             self.log_tensorboard(model_output, model_output.data)
         return loss
 
-    def loss_function(self, model_output, batch_idx, *args, **kwargs):
+    def loss_function(
+        self, model_output: ModelOutput, batch_idx: int, *args, **kwargs
+    ) -> dict:
         return {
             "loss": model_output.loss,
             "recon_loss": model_output.recon_loss,
@@ -260,6 +272,8 @@ def loss_function(self, model_output, batch_idx):
         # Scale is used because the other loss terms are sums rather than means, so the contrastive loss may need scaling up to match them.
 
         scale = torch.prod(torch.tensor(model_output.z.shape[1:]))
+        if model_output.target.unique().size(0) == 1:
+            return loss
         pairs = create_label_based_pairs(model_output.z.squeeze(), model_output.target)
         contrastive_loss = self.criteron(*pairs)
         loss["contrastive_loss"] = scale * contrastive_loss
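
The docstrings above spell out the ModelOutput convention, and the new two-line guard handles single-class batches. A minimal sketch of how the two fit together; model, criterion, and create_label_based_pairs are stand-ins here, not the repository's exact objects:

```python
import torch
from transformers.utils import ModelOutput  # the ModelOutput used elsewhere in this repo


def forward(model, x: torch.Tensor) -> ModelOutput:
    # Pythae-style models consume and produce ModelOutput bundles, so
    # several tensors can travel through a single call together.
    return model(ModelOutput(data=x.float()))


def contrastive_term(z, target, criterion, create_label_based_pairs):
    # The other loss terms are sums rather than means, so this term is
    # scaled by the per-sample latent size to keep magnitudes comparable.
    scale = torch.prod(torch.tensor(z.shape[1:]))
    # A batch with one unique label has no positive/negative pairs;
    # this is the degenerate case the new guard skips.
    if target.unique().size(0) == 1:
        return torch.tensor(0.0)
    pairs = create_label_based_pairs(z.squeeze(), target)
    return scale * criterion(*pairs)
```
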
1 change: 1 addition & 0 deletions bioimage_embed/models/bolts/vae.py
@@ -124,6 +124,7 @@ def __init__(
         self.encoder = self.model.encoder
         self.decoder = self.model.decoder
         self.input_dim = self.model_config.input_dim
+        self.latent_dim = self.model_config.latent_dim
 
     def forward(self, x, epoch=None):
         # return ModelOutput(x=x, recon_x=x, z=x, loss=1)
8 changes: 6 additions & 2 deletions bioimage_embed/models/pythae/legacy/vq_vae.py
@@ -28,8 +28,6 @@ def __init__(
         num_residual_layers,
     ):
         super(Encoder, self).__init__()
-        embedding_dim = model_config.latent_dim
-        input_dim = model_config.input_dim[1:]
 
         self.model = ResnetEncoder(
             in_channels=model_config.input_dim[0],
@@ -104,6 +102,9 @@ def __init__(
         self.encoder = self.model._encoder
         self.decoder = self.model._decoder
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.latent_dim = model_config.latent_dim
+        self.input_dim = model_config.input_dim
+
         # This isn't completely necessary for training, I don't think
         # self._set_quantizer(model_config)
 
@@ -188,6 +189,9 @@ def __init__(
         )
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.fc = nn.Linear(num_hiddens, model_config.latent_dim * 2)
+        self.latent_dim = model_config.latent_dim
+        self.input_dim = model_config.input_dim
+
         # shape is (batch_size, model_config.num_hiddens, 1, 1)
 
     def reparameterize(self, mu, log_var):
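
Exposing latent_dim and input_dim on every model variant gives callers one uniform way to introspect any backbone; the new test_model_properties test further down relies on exactly this. A short usage sketch, assuming create_model behaves as in this PR's test fixtures:

```python
from bioimage_embed import create_model

# Mirrors the fixtures in this PR; the container type of input_dim
# (list vs tuple) is not guaranteed, so only presence is checked here.
model = create_model(
    model="resnet18_vae",
    input_dim=[3, 224, 224],
    latent_dim=64,
    pretrained=True,
)
assert model.latent_dim is not None
assert model.input_dim is not None
```
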
47 changes: 28 additions & 19 deletions bioimage_embed/shapes/lightning.py
@@ -2,20 +2,16 @@
 import torchvision
 
 from torch import nn
-from ..lightning import AutoEncoderUnsupervised
+from ..lightning import AutoEncoderUnsupervised, AutoEncoderSupervised
 from . import loss_functions as lf
 from transformers.utils import ModelOutput
 from types import SimpleNamespace
 
 
 def frobenius_norm_2D_torch(tensor: torch.Tensor) -> torch.Tensor:
     return torch.norm(tensor, p="fro", dim=(-2, -1), keepdim=True)
 
 
-class MaskEmbed(AutoEncoderUnsupervised):
-    def __init__(self, model, args=SimpleNamespace()):
-        super().__init__(model, args)
-
+class MaskEmbedMixin:
     def batch_to_tensor(self, batch):
         """
         Converts a batch of data to a tensor
@@ -25,11 +21,15 @@ def batch_to_tensor(self, batch):
         # x = batch[0].float()
         output = super().batch_to_tensor(batch)
         normalised_data = output.data
-        if self.args.frobenius_norm:
-            scalings = frobenius_norm_2D_torch(output.data)
-        else:
-            scalings = torch.ones_like(output.data)
-        return ModelOutput(data=normalised_data / scalings, scalings=scalings)
+        scalings = torch.ones_like(output.data)
+        if hasattr(self.args, "frobenius_norm"):
+            if self.args.frobenius_norm:
+                scalings = frobenius_norm_2D_torch(output.data)
+
+        output.data = normalised_data / scalings
+        output.scalings = scalings
+
+        return output
 
     def loss_function(self, model_output, *args, **kwargs):
         loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=False)
@@ -53,15 +53,24 @@ def loss_function(self, model_output, *args, **kwargs):
         # loss += lf.triangle_inequality_loss(model_output.recon_x)
         # loss += lf.non_negative_loss(model_output.recon_x)
 
-        # variational_loss = model_output.loss - model_output.recon_loss
+        variational_loss = model_output.loss - model_output.recon_loss
 
-        # loss_dict = {
-        #     "loss": loss,
-        #     "shape_loss": shape_loss,
-        #     "reconstruction_loss": model_output.recon_x,
-        #     "variational_loss": variational_loss,
-        # }
-        return loss
+        return {
+            "loss": loss,
+            "shape_loss": shape_loss,
+            "reconstruction_loss": model_output.recon_loss,
+            "variational_loss": variational_loss,
+        }
+
+
+class MaskEmbed(MaskEmbedMixin, AutoEncoderUnsupervised):
+    def __init__(self, model, args=SimpleNamespace()):
+        super().__init__(model, args)
+
+
+class MaskEmbedSupervised(MaskEmbedMixin, AutoEncoderSupervised):
+    def __init__(self, model, args=SimpleNamespace()):
+        super().__init__(model, args)
 
 
 class FixedOutput(nn.Module):
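
MaskEmbedMixin factors the mask-normalisation and shape-loss logic out of a single class so it can stack onto either the unsupervised or the supervised base. The Frobenius scaling itself is straightforward; a small self-contained sketch of what frobenius_norm_2D_torch does to a batch of matrices:

```python
import torch

def frobenius_norm_2D_torch(tensor: torch.Tensor) -> torch.Tensor:
    # Per-sample Frobenius norm over the trailing two (matrix) dimensions;
    # keepdim=True keeps the result broadcastable against the input.
    return torch.norm(tensor, p="fro", dim=(-2, -1), keepdim=True)

x = torch.rand(4, 1, 8, 8)                 # batch of 4 single-channel matrices
normalised = x / frobenius_norm_2D_torch(x)
# Every matrix now has unit Frobenius norm:
print(torch.norm(normalised, p="fro", dim=(-2, -1)).squeeze())  # ~[1., 1., 1., 1.]
```
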
81 changes: 81 additions & 0 deletions bioimage_embed/shapes/tests/test_lightning.py
@@ -0,0 +1,81 @@
+from bioimage_embed.shapes.lightning import MaskEmbed, MaskEmbedSupervised
+import pytest
+from bioimage_embed import create_model
+from torchvision.datasets import FakeData
+import pytorch_lightning as pl
+from bioimage_embed.lightning.dataloader import DataModule
+from torchvision.transforms import transforms
+from types import SimpleNamespace
+
+
+@pytest.fixture
+def transform():
+    return transforms.Compose(
+        [
+            transforms.ToTensor(),
+        ]
+    )
+
+
+@pytest.fixture(params=[1, 2, 16])
+def classes(request):
+    return request.param
+
+
+@pytest.fixture
+def dataset(transform, classes):
+    return FakeData(
+        size=64,
+        image_size=(3, 224, 224),
+        num_classes=classes,
+        transform=transform,
+    )
+
+
+@pytest.fixture
+def model():
+    return create_model(
+        model="resnet18_vae",
+        input_dim=[3, 224, 224],
+        latent_dim=64,
+        pretrained=True,
+    )
+
+
+# TODO: Add tests for MaskEmbedSupervised
+@pytest.fixture(params=[MaskEmbed, MaskEmbedSupervised])
+def wrapper(request):
+    return request.param
+
+
+@pytest.fixture
+def lit_model(model, wrapper):
+    args = SimpleNamespace(frobenius_norm=False)
+    return wrapper(model, args)
+
+
+@pytest.fixture
+def trainer():
+    return pl.Trainer(
+        max_epochs=1,
+        max_steps=1,
+        # gpus=1,
+        fast_dev_run=True,
+    )
+
+
+@pytest.fixture
+def dataloader(dataset):
+    return DataModule(
+        dataset,
+        batch_size=16,
+        num_workers=0,
+    )
+
+
+def test_model(trainer, lit_model, dataloader):
+    return trainer.test(lit_model, dataloader)
+
+
+def test_model_fit(trainer, lit_model, dataloader):
+    return trainer.fit(lit_model, dataloader)
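
Note: with fast_dev_run=True the Trainer pushes only a single batch through fit and test, so these are smoke tests of the MaskEmbed wrappers rather than full training runs.
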
38 changes: 31 additions & 7 deletions bioimage_embed/tests/test_lightning.py
@@ -8,12 +8,17 @@
     AEUnsupervised,
 )
 from bioimage_embed.models import create_model
-from torch.utils.data import TensorDataset
 
+from torchvision.datasets import FakeData
+from torchvision import transforms
+
 torch.manual_seed(42)
 
 
+@pytest.fixture(params=[1, 2, 16])
+def classes(request):
+    return request.param
+
+
 @pytest.fixture(params=__all_models__)
 def model_name(request):
     return request.param
@@ -93,10 +98,13 @@ def data(input_dim):
 
 
 @pytest.fixture()
-def dataset(samples, input_dim, classes=2):
-    x = torch.rand(samples, *input_dim)
-    y = torch.torch.randint(classes - 1, (samples,))
-    return TensorDataset(x, y)
+def dataset(samples, input_dim, classes):
+    return FakeData(
+        size=samples,
+        image_size=input_dim,
+        num_classes=classes,
+        transform=transforms.ToTensor(),
+    )
 
 
 @pytest.fixture(params=[AESupervised, AEUnsupervised])
@@ -161,8 +169,24 @@ def test_dataset_trainer(trainer, lit_model, dataset):
     return trainer.test(lit_model, dataset)
 
 
+def test_model_properties(model):
+    assert model.encoder is not None
+    assert model.decoder is not None
+    assert model.latent_dim is not None
+    assert model.input_dim is not None
+    assert model.model_name is not None
+    assert model.model_config is not None
+
+
 def test_trainer_predict(trainer, lit_model, datamodule):
-    return trainer.predict(lit_model, datamodule)
+    batch_size = datamodule.predict_dataloader().batch_size
+    latent_dim = lit_model.model.latent_dim
+    predictions = trainer.predict(lit_model, datamodule)
+    assert predictions is not None
+    assert len(predictions[0].z.flatten()) == batch_size * latent_dim
+    # TODO: prefer
+    # assert list(predictions[0].z.shape) == [batch_size, latent_dim]
+    # assert len(list(predictions[0].z.shape)) == 2
 
 
 # Has to be a list, not a tuple
