
Commit 951b2bc

[rebase] combining debolt and master
ctr26 committed Sep 30, 2024
1 parent 2b97491 commit 951b2bc
Showing 9 changed files with 321 additions and 76 deletions.
1 change: 1 addition & 0 deletions bioimage_embed/config.py
@@ -194,6 +194,7 @@ class Trainer:
     min_epochs: int = 1
     max_epochs: int = II("recipe.max_epochs")
     log_every_n_steps: int = 1
+    ckpt_path: str = "last"
     # This is not a clean implementation but I am not sure how to do it better
     callbacks: List[Any] = Field(
         default_factory=lambda: list(vars(Callbacks()).values()), frozen=True
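
The new ckpt_path field threads a checkpoint-resumption choice through the Trainer config. A minimal sketch of how such a field might be consumed, assuming PyTorch Lightning's Trainer.fit(ckpt_path=...) semantics, where "last" resumes from the most recent checkpoint; the train() helper and the cfg wiring are illustrative, not code from this commit:

import pytorch_lightning as pl

def train(model, datamodule, cfg):
    # cfg mirrors the Trainer dataclass above (hypothetical wiring).
    trainer = pl.Trainer(
        min_epochs=cfg.min_epochs,
        max_epochs=cfg.max_epochs,
        log_every_n_steps=cfg.log_every_n_steps,
    )
    # ckpt_path="last" asks Lightning to resume from the latest checkpoint
    # written by its checkpoint callback, rather than training from scratch.
    trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
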
89 changes: 44 additions & 45 deletions bioimage_embed/models/bolts/vae.py
@@ -5,8 +5,7 @@
 
 from pythae import models
 from pythae.models import VAEConfig
-from pl_bolts.models import autoencoders as ae
-
+from . import resnets
 
 def count_params(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -23,7 +22,7 @@ def __init__(
         # input_height = model_config.input_dim[-2]
         latent_dim = model_config.latent_dim
 
-        self.encoder = ae.resnet50_encoder(first_conv, maxpool1)
+        self.encoder = resnets.resnet50_encoder(first_conv, maxpool1)
         self.embedding = nn.Linear(self.enc_out_dim, latent_dim)
         self.log_var = nn.Linear(self.enc_out_dim, latent_dim)
         # self.fc1 = nn.Linear(512, latent_dim)
@@ -45,7 +44,7 @@ def __init__(
         latent_dim = model_config.latent_dim
         input_height = model_config.input_dim[-2]
         self.embedding = nn.Linear(latent_dim, self.enc_out_dim)
-        self.decoder = ae.resnet50_decoder(
+        self.decoder = resnets.resnet50_decoder(
             self.enc_out_dim, input_height, first_conv, maxpool1
         )
 
@@ -66,7 +65,7 @@ def __init__(
         # input_height = model_config.input_dim[-2]
         latent_dim = model_config.latent_dim
 
-        self.encoder = ae.resnet18_encoder(first_conv, maxpool1)
+        self.encoder = resnets.resnet18_encoder(first_conv, maxpool1)
         self.embedding = nn.Linear(self.enc_out_dim, latent_dim)
         self.log_var = nn.Linear(self.enc_out_dim, latent_dim)
 
@@ -85,7 +84,7 @@ def __init__(
         super(ResNet18VAEDecoder, self).__init__()
         latent_dim = model_config.latent_dim
         input_height = model_config.input_dim[-2]
-        self.decoder = ae.resnet18_decoder(
+        self.decoder = resnets.resnet18_decoder(
             self.enc_out_dim, input_height, first_conv, maxpool1
         )
         self.embedding = nn.Linear(latent_dim, self.enc_out_dim)
@@ -96,42 +95,42 @@ def forward(self, x):
         return ModelOutput(reconstruction=x)
 
 
-class VAEPythaeWrapper(models.VAE):
-    def __init__(
-        self,
-        model_config,
-        input_height,
-        enc_type="resnet18",
-        enc_out_dim=512,
-        first_conv=False,
-        maxpool1=False,
-        kl_coeff=0.1,
-        encoder=None,
-        decoder=None,
-    ):
-        super(models.BaseAE, self).__init__()
-        self.model_name = "VAE_bolt"
-        self.model_config = model_config
-        self.model = ae.VAE(
-            input_height=input_height,
-            enc_type=enc_type,
-            enc_out_dim=enc_out_dim,
-            first_conv=first_conv,
-            maxpool1=maxpool1,
-            kl_coeff=kl_coeff,
-            latent_dim=model_config.latent_dim,
-        )
-        self.encoder = self.model.encoder
-        self.decoder = self.model.decoder
-        self.input_dim = self.model_config.input_dim
-        self.latent_dim = self.model_config.latent_dim
-
-    def forward(self, x, epoch=None):
-        # return ModelOutput(x=x,recon_x=x,z=x,loss=1)
-        # # Forward pass logic
-        x = x["data"]
-        # x_recon = self.model(x)
-        z, recon_x, p, q = self.model._run_step(x)
-        loss, logs = self.model.step((x, x), batch_idx=epoch)
-        # recon_loss = self.model.reconstruction_loss(x, recon_x)
-        return ModelOutput(recon_x=recon_x, z=z, logs=logs, loss=loss, recon_loss=loss)
+# class VAEPythaeWrapper(models.VAE):
+#     def __init__(
+#         self,
+#         model_config,
+#         input_height,
+#         enc_type="resnet18",
+#         enc_out_dim=512,
+#         first_conv=False,
+#         maxpool1=False,
+#         kl_coeff=0.1,
+#         encoder=None,
+#         decoder=None,
+#     ):
+#         super(models.BaseAE, self).__init__()
+#         self.model_name = "VAE"
+#         self.model_config = model_config
+#         self.model = ae.VAE(
+#             input_height=input_height,
+#             enc_type=enc_type,
+#             enc_out_dim=enc_out_dim,
+#             first_conv=first_conv,
+#             maxpool1=maxpool1,
+#             kl_coeff=kl_coeff,
+#             latent_dim=model_config.latent_dim,
+#         )
+#         self.encoder = self.model.encoder
+#         self.decoder = self.model.decoder
+
+#     def forward(self, x, epoch=None):
+#         # return ModelOutput(x=x,recon_x=x,z=x,loss=1)
+#         # # Forward pass logic
+#         x = x["data"]
+#         x_recon = self.model(x)
+#         z, recon_x, p, q = self.model._run_step(x)
+#         loss, logs = self.model.step((x, x), batch_idx=epoch)
+#         # recon_loss = self.model.reconstruction_loss(x, recon_x)
+#         return ModelOutput(recon_x=recon_x, z=z, logs=logs, loss=loss, recon_loss=loss)
+
+


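This file is the core of the "debolt": the pl_bolts.models.autoencoders import is replaced by a vendored resnets module, and the bolt-backed VAEPythaeWrapper is commented out rather than ported (the commented block still references the removed ae.VAE, so it cannot simply be re-enabled). A short sketch of how the encoder/decoder helpers are used by the classes above, assuming the vendored functions keep the pl_bolts signatures, which the unchanged call-site arguments suggest; the batch size, image size, and latent width are illustrative:

import torch
from torch import nn
from bioimage_embed.models.bolts import resnets

encoder = resnets.resnet18_encoder(False, False)   # (first_conv, maxpool1)
x = torch.randn(4, 3, 64, 64)                      # batch of RGB images
feats = encoder(x)                                 # -> (4, 512): enc_out_dim
mu = nn.Linear(512, 16)(feats)                     # embedding head (latent mean)
log_var = nn.Linear(512, 16)(feats)                # log-variance head

# Decoding mirrors ResNet18VAEDecoder: embed the latent back to enc_out_dim,
# then decode to an image of the configured input_height.
decoder = resnets.resnet18_decoder(512, 64, False, False)
recon = decoder(nn.Linear(16, 512)(mu))            # -> (4, 3, 64, 64)
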
2 changes: 0 additions & 2 deletions bioimage_embed/models/bolts/vqvae.py
@@ -3,8 +3,6 @@
 from pythae.models.nn import BaseDecoder, BaseEncoder
 from pythae.models import VAEConfig
 
-from pl_bolts.models import autoencoders as ae
-
 
 class BaseResNetVQVAEEncoder(BaseEncoder):
     def __init__(
56 changes: 28 additions & 28 deletions bioimage_embed/models/factory.py
@@ -53,28 +53,28 @@ def dummy_model(self):
             lambda x: None,
         )
 
-    def resnet_vae_bolt(
-        self,
-        enc_type,
-        enc_out_dim,
-        first_conv=False,
-        maxpool1=False,
-        kl_coeff=1.0,
-    ):
-        return self.create_model(
-            pythae.models.VAEConfig,
-            partial(
-                bolts.vae.VAEPythaeWrapper,
-                input_height=self.input_dim[1],
-                enc_type=enc_type,
-                enc_out_dim=enc_out_dim,
-                first_conv=first_conv,
-                maxpool1=maxpool1,
-                kl_coeff=kl_coeff,
-            ),
-            encoder_class=lambda x: None,
-            decoder_class=lambda x: None,
-        )
+    # def resnet_vae_bolt(
+    #     self,
+    #     enc_type,
+    #     enc_out_dim,
+    #     first_conv=False,
+    #     maxpool1=False,
+    #     kl_coeff=1.0,
+    # ):
+    #     return self.create_model(
+    #         pythae.models.VAEConfig,
+    #         partial(
+    #             bolts.vae.VAEPythaeWrapper,
+    #             input_height=self.input_dim[1],
+    #             enc_type=enc_type,
+    #             enc_out_dim=enc_out_dim,
+    #             first_conv=first_conv,
+    #             maxpool1=maxpool1,
+    #             kl_coeff=kl_coeff,
+    #         ),
+    #         encoder_class=lambda x: None,
+    #         decoder_class=lambda x: None,
+    #     )
 
     # bolts.vae.VAEPythaeWrapper(
     #     input_height=self.input_dim[1],
@@ -86,11 +86,11 @@
     #     latent_dim=self.latent_dim,
     # )
 
-    def resnet18_vae_bolt(self, **kwargs):
-        return self.resnet_vae_bolt(enc_type="resnet18", enc_out_dim=512, **kwargs)
+    # def resnet18_vae_bolt(self, **kwargs):
+    #     return self.resnet_vae_bolt(enc_type="resnet18", enc_out_dim=512, **kwargs)
 
-    def resnet50_vae_bolt(self, **kwargs):
-        return self.resnet_vae_bolt(enc_type="resnet50", enc_out_dim=2048, **kwargs)
+    # def resnet50_vae_bolt(self, **kwargs):
+    #     return self.resnet_vae_bolt(enc_type="resnet50", enc_out_dim=2048, **kwargs)
 
     def resnet18_vae(self):
         return self.create_model(
@@ -196,8 +196,8 @@ def __call__(self, model):
 __all_models__ = [
     "resnet18_vae",
     "resnet50_vae",
-    "resnet18_vae_bolt",
-    "resnet50_vae_bolt",
+    # "resnet18_vae_bolt",
+    # "resnet50_vae_bolt",
     "resnet18_vqvae",
     "resnet50_vqvae",
     "resnet18_vqvae_legacy",
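
With the bolt-backed constructors commented out and dropped from __all_models__, model creation goes through the remaining factory methods. A hypothetical usage sketch; the ModelFactory name and its constructor arguments are inferred from the attributes visible above (create_model, input_dim, latent_dim), not confirmed by this diff:

from bioimage_embed.models.factory import ModelFactory  # assumed import path

factory = ModelFactory(input_dim=(3, 224, 224), latent_dim=64)  # assumed args
model = factory.resnet18_vae()       # still listed in __all_models__
# factory.resnet18_vae_bolt()        # removed: would now raise AttributeError
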
Empty file removed: models/convert_model.sh
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -43,7 +43,8 @@ tikzplotlib = "^0.10.1"
 # torchmetrics = "^1.1.1"
 torchmetrics = "0.11.4"
 # pytorch-lightning-bolts = "^0.3.2.post1"
-lightning-bolts = "^0.7.0"
+# lightning-bolts = "^0.7.0"
+
 Pillow = "9.5.0"
 onnx = "^1.15.0"
 typer = "^0.9.0"
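
Commenting out the lightning-bolts pin finishes the debolt at the packaging level; the vendored resnets module now covers what pl_bolts provided. A hypothetical smoke test, to be run in an environment where lightning-bolts is not installed:

# Both modules dropped their pl_bolts imports in this commit, so these
# should now succeed without lightning-bolts on the path.
import bioimage_embed.models.bolts.vae
import bioimage_embed.models.bolts.vqvae
print("debolt ok: no pl_bolts required")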