Skip to content

Commit

Permalink
Merge pull request #72 from ctr26/bvae
Browse files Browse the repository at this point in the history
Bvae
  • Loading branch information
ctr26 authored Oct 1, 2024
2 parents b4dcce4 + 5dddb33 commit 8eda614
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion bioimage_embed/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,19 @@ def resnet18_vqvae(self):
bolts.ResNet18VQVAEDecoder,
)

def resnet18_beta_vae(self):
return self.create_model(
partial(
pythae.models.BetaVAEConfig,
use_default_encoder=False,
use_default_decoder=False,
**self.kwargs,
),
pythae.models.BetaVAE,
bolts.ResNet18VAEEncoder,
bolts.ResNet18VAEDecoder,
)

def resnet50_vqvae(self):
return self.create_model(
partial(
Expand All @@ -144,6 +157,19 @@ def resnet50_vqvae(self):
bolts.ResNet50VQVAEDecoder,
)

def resnet50_beta_vae(self):
return self.create_model(
partial(
pythae.models.BetaVAEConfig,
use_default_encoder=False,
use_default_decoder=False,
**self.kwargs,
),
pythae.models.BetaVAE,
bolts.ResNet50VAEEncoder,
bolts.ResNet50VAEDecoder,
)

def resnet_vae_legacy(self, depth):
return self.create_model(
pythae.models.VAEConfig,
Expand Down Expand Up @@ -195,7 +221,9 @@ def __call__(self, model):

__all_models__ = [
"resnet18_vae",
"resnet18_beta_vae",
"resnet50_vae",
"resnet50_beta_vae",
"resnet18_vae_bolt",
"resnet50_vae_bolt",
"resnet18_vqvae",
Expand All @@ -217,7 +245,6 @@ def __call__(self, model):
]



def create_model(
model: str,
input_dim: Tuple[int, int, int],
Expand Down

0 comments on commit 8eda614

Please sign in to comment.