From 96ee04f796b0ebe5c51531707a335a49d70f0f62 Mon Sep 17 00:00:00 2001
From: Anna Foix
Date: Mon, 30 Sep 2024 17:21:56 +0100
Subject: [PATCH] added o2vae repo patch

---
 .../models/o2vae_shapeembed_integration.diff  | 97 +++++++++++++++++++
 1 file changed, 97 insertions(+)
 create mode 100644 bioimage_embed/models/o2vae_shapeembed_integration.diff

diff --git a/bioimage_embed/models/o2vae_shapeembed_integration.diff b/bioimage_embed/models/o2vae_shapeembed_integration.diff
new file mode 100644
index 00000000..309d7206
--- /dev/null
+++ b/bioimage_embed/models/o2vae_shapeembed_integration.diff
@@ -0,0 +1,97 @@
+diff --git a/models/align_reconstructions.py b/models/align_reconstructions.py
+index d07d1ab..c52b40d 100644
+--- a/models/align_reconstructions.py
++++ b/models/align_reconstructions.py
+@@ -6,7 +6,7 @@ import torch
+ import torchgeometry as tgm
+ import torchvision.transforms.functional as T_f
+ 
+-from registration import registration
++from ..registration import registration
+ 
+ 
+ def loss_reconstruction_fourier_batch(x, y, recon_loss_type="bce", mask=None):
+diff --git a/models/decoders/cnn_decoder.py b/models/decoders/cnn_decoder.py
+index ba3a1cc..1740945 100644
+--- a/models/decoders/cnn_decoder.py
++++ b/models/decoders/cnn_decoder.py
+@@ -58,7 +58,7 @@ class CnnDecoder(nn.Module):
+ 
+         self.dec_conv = nn.Sequential(*layers)
+ 
+-    def forward(self, x):
++    def forward(self, x, epoch = None):
+         bs = x.size(0)
+         x = self.fc(x)
+         dim = x.size(1)
+diff --git a/models/encoders_o2/e2scnn.py b/models/encoders_o2/e2scnn.py
+index 9c4f47f..e292b1e 100644
+--- a/models/encoders_o2/e2scnn.py
++++ b/models/encoders_o2/e2scnn.py
+@@ -219,14 +219,20 @@ class E2SFCNN(torch.nn.Module):
+             repr += f"\t{i: <3} - {name: <70} | {params: <8} |\n"
+         return repr
+ 
+-    def forward(self, input: torch.tensor):
++    def forward(self, input: torch.tensor, epoch = None):
++        #print(f"DEBUG: e2scnn forward: input.shape: {input.shape}")
+         x = GeometricTensor(input, self.in_repr)
++        #print(f"DEBUG: e2scnn forward: pre layers x.shape: {x.shape}")
+ 
+         for layer in self.eq_layers:
+             x = layer(x)
+ 
++        #print(f"DEBUG: e2scnn forward: pre fully_net x.shape: {x.shape}")
++
+         x = self.fully_net(x.tensor.reshape(x.tensor.shape[0], -1))
+ 
++        #print(f"DEBUG: e2scnn forward: pre final x.shape: {x.shape}")
++
+         return x
+ 
+     def build_layer_regular(
+diff --git a/models/vae.py b/models/vae.py
+index 3af262b..af1a2dc 100644
+--- a/models/vae.py
++++ b/models/vae.py
+@@ -3,8 +3,9 @@ import importlib
+ 
+ import numpy as np
+ import torch
+ import torchvision
++from pythae.models.base.base_utils import ModelOutput
+ 
+-from models import align_reconstructions
++from . import align_reconstructions
+ 
+ from . import model_utils as mut
+@@ -273,10 +274,11 @@ class VAE(torch.nn.Module):
+ 
+         return y
+ 
+-    def forward(self, x):
++    def forward(self, x, epoch = None):
++        x = x["data"]
+         in_shape = x.shape
+         bs = in_shape[0]
+-        assert x.ndim == 4
++        assert len(in_shape) == 4
+ 
+         # inference and sample
+         z = self.q_net(x)
+@@ -290,8 +292,12 @@
+             y = torch.sigmoid(y)
+         # check the spatial dimensions are good (if doing multiclass prediction per pixel, the `c` dim may be different)
+         assert in_shape[-2:] == y.shape[-2:], (
+-            "output image different dimension to "
+-            "input image ... probably change the number of layers (cnn_dims) in the decoder"
++            f"output image different dimension {y.shape[-2:]} to "
++            f"input image {in_shape[-2:]} ... probably change the number of layers (cnn_dims) in the decoder"
+         )
+ 
+-        return x, y, mu, logvar
++        # gather losses
++        losses = self.loss(x, y, mu, logvar)
++
++        return ModelOutput(recon_x=y, z=z_sample, loss=losses['loss'], recon_loss=losses['loss_recon'])
++        #return ModelOutput(recon_x=y, z=z_sample)
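-- 
Notes: the file added above is itself a git diff, meant to be applied from
the root of a checkout of the upstream o2vae repository, e.g. with
`git apply path/to/bioimage_embed/models/o2vae_shapeembed_integration.diff`
(path illustrative). Taken together, the hunks adapt o2vae to a
pythae-style interface: imports become package-relative, each forward()
gains an optional `epoch` keyword, and VAE.forward() now unwraps a
`{"data": tensor}` batch and returns a pythae ModelOutput carrying the
reconstruction and losses instead of a plain tuple. A minimal sketch of
that calling convention, with `forward_convention` a hypothetical stand-in
for the patched VAE.forward() (only torch and pythae required):

    import torch
    from pythae.models.base.base_utils import ModelOutput

    def forward_convention(x, epoch=None):
        x = x["data"]             # batches arrive as {"data": tensor}
        assert len(x.shape) == 4  # (bs, c, h, w), as the patched assert checks
        y = torch.sigmoid(torch.randn_like(x))  # stand-in reconstruction
        return ModelOutput(
            recon_x=y,
            z=torch.zeros(x.shape[0], 8),  # stand-in latent sample
            loss=torch.tensor(0.0),        # total loss
            recon_loss=torch.tensor(0.0),  # reconstruction term
        )

    out = forward_convention({"data": torch.rand(2, 1, 64, 64)})
    print(out.recon_x.shape, out.loss)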