diff --git a/bioimage_embed/shapes/lightning.py b/bioimage_embed/shapes/lightning.py
index d7ca31c9..d092e8ca 100644
--- a/bioimage_embed/shapes/lightning.py
+++ b/bioimage_embed/shapes/lightning.py
@@ -21,11 +21,10 @@ def batch_to_tensor(self, batch):
         # x = batch[0].float()
         output = super().batch_to_tensor(batch)
         normalised_data = output.data
-
-        if hasattr(self.args, "frobenius_norm") and self.args.frobenius_norm:
-            scalings = frobenius_norm_2D_torch(output.data)
-        else:
-            scalings = torch.ones_like(output.data)
+        scalings = torch.ones_like(output.data)
+        if hasattr(self.args, "frobenius_norm"):
+            if self.args.frobenius_norm:
+                scalings = frobenius_norm_2D_torch(output.data)
         output.data = normalised_data / scalings
         output.scalings = scalings
@@ -59,17 +58,17 @@ def loss_function(self, model_output, *args, **kwargs):
         return {
             "loss": loss,
             "shape_loss": shape_loss,
-            "reconstruction_loss": model_output.recon_x,
+            "reconstruction_loss": model_output.recon_loss,
             "variational_loss": variational_loss,
         }
 
 
-class MaskEmbed(AutoEncoderUnsupervised, MaskEmbedMixin):
+class MaskEmbed(MaskEmbedMixin, AutoEncoderUnsupervised):
     def __init__(self, model, args=SimpleNamespace()):
         super().__init__(model, args)
 
 
-class MaskEmbedSupervised(AutoEncoderSupervised, MaskEmbedMixin):
+class MaskEmbedSupervised(MaskEmbedMixin, AutoEncoderSupervised):
     def __init__(self, model, args=SimpleNamespace()):
         super().__init__(model, args)
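
Note on the base-class reordering in the last two hunks: the sketch below is a minimal, self-contained illustration (the class names Base and Mixin are placeholders, not from the repository) of why the mixin should be listed before the autoencoder base class. With the mixin first in the MRO, its batch_to_tensor override is the one that gets called, and super() inside the mixin still reaches the base implementation; with the base class listed first, the mixin's override is silently skipped.

class Base:
    def batch_to_tensor(self, batch):
        return f"Base({batch})"

class Mixin:
    def batch_to_tensor(self, batch):
        # Cooperative override: do the mixin's work, then delegate to the
        # next class in the MRO via super().
        return f"Mixin({super().batch_to_tensor(batch)})"

class BaseFirst(Base, Mixin):   # base first: Mixin.batch_to_tensor never runs
    pass

class MixinFirst(Mixin, Base):  # mixin first: override runs, then Base via super()
    pass

print(BaseFirst().batch_to_tensor("x"))   # -> Base(x)
print(MixinFirst().batch_to_tensor("x"))  # -> Mixin(Base(x))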