Skip to content

Commit

Permalink
I think this is the correct way for vq (copying how pythae does it)
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Jan 16, 2024
1 parent 4bae42f commit 1c822c1
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions bioimage_embed/models/pythae/legacy/vq_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
depth,
encoder=None,
decoder=None,
strict_latent_size=False,
strict_latent_size=True,
):
super(models.BaseAE, self).__init__()
# super(nn.Module)
Expand Down Expand Up @@ -106,16 +106,18 @@ def forward(self, x, epoch=None):
# loss, x_recon, perplexity = self.model.forward(x["data"])
z = self.model.encoder(x["data"])
z = self.model._pre_vq_conv(z)

proper_shape = z.shape

if self.strict_latent_size:
z = self.avgpool(z)
# Features need to be in the right order for the quantizer
z = z.permute(0, 2, 3, 1)

loss, quantized, perplexity, encodings = self.model._vq_vae(z)

z = quantized.flatten(1)
if self.strict_latent_size:
quantized = quantized.expand(-1, -1, *proper_shape[-2:])
quantized = quantized.permute(0, 3, 1, 2)
quantized = quantized.expand(-1, *proper_shape[-3:])

x_recon = self.model._decoder(quantized)
# return loss, x_recon, perplexity
Expand All @@ -137,12 +139,12 @@ def forward(self, x, epoch=None):
variational_loss = loss-mse_loss

pythae_loss_dict = {
"recon_loss": recon_loss,
"recon_loss": mse_loss,
"vq_loss": variational_loss,
# TODO check this proppperppply
"loss": recon_loss*torch.exp(variational_loss),
"loss": recon_loss + variational_loss,
"recon_x": x_recon,
"z": quantized,
"z": z,
"quantized_indices": indices[0],
"indices": indices,
}
Expand Down

0 comments on commit 1c822c1

Please sign in to comment.