Skip to content

Commit

Permalink
[fix] resnet50_vae duplicated
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Oct 1, 2024
1 parent 335ea77 commit 5dddb33
Showing 1 changed file with 3 additions and 16 deletions.
19 changes: 3 additions & 16 deletions bioimage_embed/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import bolts
from functools import partial


class ModelFactory:
def __init__(
self, input_dim, latent_dim, pretrained=False, progress=True, **kwargs
Expand Down Expand Up @@ -136,26 +137,13 @@ def resnet18_beta_vae(self):
pythae.models.BetaVAEConfig,
use_default_encoder=False,
use_default_decoder=False,
**self.kwargs
**self.kwargs,
),
pythae.models.BetaVAE,
bolts.ResNet18VAEEncoder,
bolts.ResNet18VAEDecoder,
)

def resnet50_vae(self):
return self.create_model(
partial(
pythae.models.VAEConfig,
use_default_encoder=False,
use_default_decoder=False,
**self.kwargs
),
pythae.models.VAE,
bolts.ResNet50VAEEncoder,
bolts.ResNet50VAEDecoder,
)

def resnet50_vqvae(self):
return self.create_model(
partial(
Expand All @@ -175,7 +163,7 @@ def resnet50_beta_vae(self):
pythae.models.BetaVAEConfig,
use_default_encoder=False,
use_default_decoder=False,
**self.kwargs
**self.kwargs,
),
pythae.models.BetaVAE,
bolts.ResNet50VAEEncoder,
Expand Down Expand Up @@ -257,7 +245,6 @@ def __call__(self, model):
]



def create_model(
model: str,
input_dim: Tuple[int, int, int],
Expand Down

0 comments on commit 5dddb33

Please sign in to comment.