Skip to content

Commit

Permalink
[bug] a shape issue would sometimes arise where `scalings` wasn't declared
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Aug 14, 2024
1 parent 584b135 commit f994f31
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions bioimage_embed/shapes/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ def batch_to_tensor(self, batch):
# x = batch[0].float()
output = super().batch_to_tensor(batch)
normalised_data = output.data

if hasattr(self.args, "frobenius_norm") and self.args.frobenius_norm:
scalings = frobenius_norm_2D_torch(output.data)
else:
scalings = torch.ones_like(output.data)
scalings = torch.ones_like(output.data)
if hasattr(self.args, "frobenius_norm"):
if self.args.frobenius_norm:
scalings = frobenius_norm_2D_torch(output.data)

output.data = normalised_data / scalings
output.scalings = scalings
Expand Down Expand Up @@ -59,17 +58,17 @@ def loss_function(self, model_output, *args, **kwargs):
return {
"loss": loss,
"shape_loss": shape_loss,
"reconstruction_loss": model_output.recon_x,
"reconstruction_loss": model_output.recon_loss,
"variational_loss": variational_loss,
}


class MaskEmbed(AutoEncoderUnsupervised, MaskEmbedMixin):
class MaskEmbed(MaskEmbedMixin, AutoEncoderUnsupervised):
def __init__(self, model, args=SimpleNamespace()):
super().__init__(model, args)


class MaskEmbedSupervised(AutoEncoderSupervised, MaskEmbedMixin):
class MaskEmbedSupervised(MaskEmbedMixin, AutoEncoderSupervised):
def __init__(self, model, args=SimpleNamespace()):
super().__init__(model, args)

Expand Down

0 comments on commit f994f31

Please sign in to comment.