Skip to content

Commit

Permalink
added o2vae repo patch
Browse files Browse the repository at this point in the history
  • Loading branch information
afoix authored and ctr26 committed Oct 1, 2024
1 parent 1e8446e commit 96ee04f
Showing 1 changed file with 97 additions and 0 deletions.
97 changes: 97 additions & 0 deletions bioimage_embed/models/o2vae_shapeembed_integration.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
diff --git a/models/align_reconstructions.py b/models/align_reconstructions.py
index d07d1ab..c52b40d 100644
--- a/models/align_reconstructions.py
+++ b/models/align_reconstructions.py
@@ -6,7 +6,7 @@ import torch
import torchgeometry as tgm
import torchvision.transforms.functional as T_f

-from registration import registration
+from ..registration import registration


def loss_reconstruction_fourier_batch(x, y, recon_loss_type="bce", mask=None):
diff --git a/models/decoders/cnn_decoder.py b/models/decoders/cnn_decoder.py
index ba3a1cc..1740945 100644
--- a/models/decoders/cnn_decoder.py
+++ b/models/decoders/cnn_decoder.py
@@ -58,7 +58,7 @@ class CnnDecoder(nn.Module):

self.dec_conv = nn.Sequential(*layers)

- def forward(self, x):
+ def forward(self, x, epoch = None):
bs = x.size(0)
x = self.fc(x)
dim = x.size(1)
diff --git a/models/encoders_o2/e2scnn.py b/models/encoders_o2/e2scnn.py
index 9c4f47f..e292b1e 100644
--- a/models/encoders_o2/e2scnn.py
+++ b/models/encoders_o2/e2scnn.py
@@ -219,14 +219,20 @@ class E2SFCNN(torch.nn.Module):
repr += f"\t{i: <3} - {name: <70} | {params: <8} |\n"
return repr

- def forward(self, input: torch.tensor):
+ def forward(self, input: torch.tensor, epoch = None):
+ #print(f"DEBUG: e2scnn forward: input.shape: {input.shape}")
x = GeometricTensor(input, self.in_repr)
+ #print(f"DEBUG: e2scnn forward: pre layers x.shape: {x.shape}")

for layer in self.eq_layers:
x = layer(x)

+ #print(f"DEBUG: e2scnn forward: pre fully_net x.shape: {x.shape}")
+
x = self.fully_net(x.tensor.reshape(x.tensor.shape[0], -1))

+ #print(f"DEBUG: e2scnn forward: pre final x.shape: {x.shape}")
+
return x

def build_layer_regular(
diff --git a/models/vae.py b/models/vae.py
index 3af262b..af1a2dc 100644
--- a/models/vae.py
+++ b/models/vae.py
@@ -3,8 +3,9 @@ import importlib
import numpy as np
import torch
import torchvision
+from pythae.models.base.base_utils import ModelOutput

-from models import align_reconstructions
+from . import align_reconstructions

from . import model_utils as mut

@@ -273,10 +274,11 @@ class VAE(torch.nn.Module):

return y

- def forward(self, x):
+ def forward(self, x, epoch = None):
+ x = x["data"]
in_shape = x.shape
bs = in_shape[0]
- assert x.ndim == 4
+ assert len(in_shape) == 4

# inference and sample
z = self.q_net(x)
@@ -290,8 +292,12 @@ class VAE(torch.nn.Module):
y = torch.sigmoid(y)
# check the spatial dimensions are good (if doing multiclass prediction per pixel, the `c` dim may be different)
assert in_shape[-2:] == y.shape[-2:], (
- "output image different dimension to "
- "input image ... probably change the number of layers (cnn_dims) in the decoder"
+ f"output image different dimension {y.shape[-2:]} to "
+ f"input image {in_shape[-2:]} ... probably change the number of layers (cnn_dims) in the decoder"
)

- return x, y, mu, logvar
+ # gather losses
+ losses = self.loss(x, y, mu, logvar)
+
+ return ModelOutput(recon_x=y, z=z_sample, loss=losses['loss'], recon_loss=losses['loss_recon'])
+ #return ModelOutput(recon_x=y, z=z_sample)

0 comments on commit 96ee04f

Please sign in to comment.