Skip to content

Commit

Permalink
[ref] cleaning up commented code
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Sep 30, 2024
1 parent e8a10ab commit 95b22ea
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 94 deletions.
43 changes: 1 addition & 42 deletions bioimage_embed/models/bolts/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from pythae.models.nn import BaseDecoder, BaseEncoder


from pythae import models
from pythae.models import VAEConfig
from . import resnets


def count_params(model):
    """Return the total number of trainable parameters in *model*.

    Iterates over ``model.parameters()`` and sums the element counts of
    every tensor whose ``requires_grad`` flag is set; frozen parameters
    are excluded from the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

Expand Down Expand Up @@ -93,44 +93,3 @@ def forward(self, x):
x = self.embedding(x)
x = self.decoder(x)
return ModelOutput(reconstruction=x)


# class VAEPythaeWrapper(models.VAE):
# def __init__(
# self,
# model_config,
# input_height,
# enc_type="resnet18",
# enc_out_dim=512,
# first_conv=False,
# maxpool1=False,
# kl_coeff=0.1,
# encoder=None,
# decoder=None,
# ):
# super(models.BaseAE, self).__init__()
# self.model_name = "VAE"
# self.model_config = model_config
# self.model = ae.VAE(
# input_height=input_height,
# enc_type=enc_type,
# enc_out_dim=enc_out_dim,
# first_conv=first_conv,
# maxpool1=maxpool1,
# kl_coeff=kl_coeff,
# latent_dim=model_config.latent_dim,
# )
# self.encoder = self.model.encoder
# self.decoder = self.model.decoder

# def forward(self, x, epoch=None):
# # return ModelOutput(x=x,recon_x=x,z=x,loss=1)
# # # Forward pass logic
# x = x["data"]
# x_recon = self.model(x)
# z, recon_x, p, q = self.model._run_step(x)
# loss, logs = self.model.step((x, x), batch_idx=epoch)
# # recon_loss = self.model.reconstruction_loss(x, recon_x)
# return ModelOutput(recon_x=recon_x, z=z, logs=logs, loss=loss, recon_loss=loss)


52 changes: 0 additions & 52 deletions bioimage_embed/models/factory.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
# import torch

# import torch.nn.functional as F

# Note - you must have torchvision installed for this example
# from torch.utils.data import DataLoader

# from bioimage_embed.transforms import DistogramToMaskPipeline


# from .vae_bio import Mask_VAE, Image_VAE

# from .bolts import ResNet18VAEEncoder, ResNet18VAEDecoder

from typing import Tuple
Expand Down Expand Up @@ -53,45 +50,6 @@ def dummy_model(self):
lambda x: None,
)

# def resnet_vae_bolt(
# self,
# enc_type,
# enc_out_dim,
# first_conv=False,
# maxpool1=False,
# kl_coeff=1.0,
# ):
# return self.create_model(
# pythae.models.VAEConfig,
# partial(
# bolts.vae.VAEPythaeWrapper,
# input_height=self.input_dim[1],
# enc_type=enc_type,
# enc_out_dim=enc_out_dim,
# first_conv=first_conv,
# maxpool1=maxpool1,
# kl_coeff=kl_coeff,
# ),
# encoder_class=lambda x: None,
# decoder_class=lambda x: None,
# )

# bolts.vae.VAEPythaeWrapper(
# input_height=self.input_dim[1],
# enc_type=enc_type,
# enc_out_dim=enc_out_dim,
# first_conv=first_conv,
# maxpool1=maxpool1,
# kl_coeff=kl_coeff,
# latent_dim=self.latent_dim,
# )

# def resnet18_vae_bolt(self, **kwargs):
# return self.resnet_vae_bolt(enc_type="resnet18", enc_out_dim=512, **kwargs)

# def resnet50_vae_bolt(self, **kwargs):
# return self.resnet_vae_bolt(enc_type="resnet50", enc_out_dim=2048, **kwargs)

def resnet18_vae(self):
return self.create_model(
partial(
Expand Down Expand Up @@ -185,19 +143,10 @@ def resnet152_vqvae_legacy(self):
def __call__(self, model):
    """Dispatch to the factory method named *model* and return its result.

    Looks up the attribute whose name matches the *model* string (e.g.
    ``"resnet18_vae"``) on this factory instance and invokes it with no
    arguments. Raises ``AttributeError`` if no such factory method exists.
    """
    factory_method = getattr(self, model)
    return factory_method()

# return getattr(self
# (
# self.input_dim, self.latent_dim, self.pretrained, self.progress),
# ),
# model,
# )


__all_models__ = [
"resnet18_vae",
"resnet50_vae",
# "resnet18_vae_bolt",
# "resnet50_vae_bolt",
"resnet18_vqvae",
"resnet50_vqvae",
"resnet18_vqvae_legacy",
Expand All @@ -217,7 +166,6 @@ def __call__(self, model):
]



def create_model(
model: str,
input_dim: Tuple[int, int, int],
Expand Down

0 comments on commit 95b22ea

Please sign in to comment.