-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
97 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
diff --git a/models/align_reconstructions.py b/models/align_reconstructions.py | ||
index d07d1ab..c52b40d 100644 | ||
--- a/models/align_reconstructions.py | ||
+++ b/models/align_reconstructions.py | ||
@@ -6,7 +6,7 @@ import torch | ||
import torchgeometry as tgm | ||
import torchvision.transforms.functional as T_f | ||
|
||
-from registration import registration | ||
+from ..registration import registration | ||
|
||
|
||
def loss_reconstruction_fourier_batch(x, y, recon_loss_type="bce", mask=None): | ||
diff --git a/models/decoders/cnn_decoder.py b/models/decoders/cnn_decoder.py | ||
index ba3a1cc..1740945 100644 | ||
--- a/models/decoders/cnn_decoder.py | ||
+++ b/models/decoders/cnn_decoder.py | ||
@@ -58,7 +58,7 @@ class CnnDecoder(nn.Module): | ||
|
||
self.dec_conv = nn.Sequential(*layers) | ||
|
||
- def forward(self, x): | ||
+ def forward(self, x, epoch = None): | ||
bs = x.size(0) | ||
x = self.fc(x) | ||
dim = x.size(1) | ||
diff --git a/models/encoders_o2/e2scnn.py b/models/encoders_o2/e2scnn.py | ||
index 9c4f47f..e292b1e 100644 | ||
--- a/models/encoders_o2/e2scnn.py | ||
+++ b/models/encoders_o2/e2scnn.py | ||
@@ -219,14 +219,20 @@ class E2SFCNN(torch.nn.Module): | ||
repr += f"\t{i: <3} - {name: <70} | {params: <8} |\n" | ||
return repr | ||
|
||
- def forward(self, input: torch.tensor): | ||
+ def forward(self, input: torch.tensor, epoch = None): | ||
+ #print(f"DEBUG: e2scnn forward: input.shape: {input.shape}") | ||
x = GeometricTensor(input, self.in_repr) | ||
+ #print(f"DEBUG: e2scnn forward: pre layers x.shape: {x.shape}") | ||
|
||
for layer in self.eq_layers: | ||
x = layer(x) | ||
|
||
+ #print(f"DEBUG: e2scnn forward: pre fully_net x.shape: {x.shape}") | ||
+ | ||
x = self.fully_net(x.tensor.reshape(x.tensor.shape[0], -1)) | ||
|
||
+ #print(f"DEBUG: e2scnn forward: pre final x.shape: {x.shape}") | ||
+ | ||
return x | ||
|
||
def build_layer_regular( | ||
diff --git a/models/vae.py b/models/vae.py | ||
index 3af262b..af1a2dc 100644 | ||
--- a/models/vae.py | ||
+++ b/models/vae.py | ||
@@ -3,8 +3,9 @@ import importlib | ||
import numpy as np | ||
import torch | ||
import torchvision | ||
+from pythae.models.base.base_utils import ModelOutput | ||
|
||
-from models import align_reconstructions | ||
+from . import align_reconstructions | ||
|
||
from . import model_utils as mut | ||
|
||
@@ -273,10 +274,11 @@ class VAE(torch.nn.Module): | ||
|
||
return y | ||
|
||
- def forward(self, x): | ||
+ def forward(self, x, epoch = None): | ||
+ x = x["data"] | ||
in_shape = x.shape | ||
bs = in_shape[0] | ||
- assert x.ndim == 4 | ||
+ assert len(in_shape) == 4 | ||
|
||
# inference and sample | ||
z = self.q_net(x) | ||
@@ -290,8 +292,12 @@ class VAE(torch.nn.Module): | ||
y = torch.sigmoid(y) | ||
# check the spatial dimensions are good (if doing multiclass prediction per pixel, the `c` dim may be different) | ||
assert in_shape[-2:] == y.shape[-2:], ( | ||
- "output image different dimension to " | ||
- "input image ... probably change the number of layers (cnn_dims) in the decoder" | ||
+ f"output image different dimension {y.shape[-2:]} to " | ||
+ f"input image {in_shape[-2:]} ... probably change the number of layers (cnn_dims) in the decoder" | ||
) | ||
|
||
- return x, y, mu, logvar | ||
+ # gather losses | ||
+ losses = self.loss(x, y, mu, logvar) | ||
+ | ||
+ return ModelOutput(recon_x=y, z=z_sample, loss=losses['loss'], recon_loss=losses['loss_recon']) | ||
+ #return ModelOutput(recon_x=y, z=z_sample) |