Skip to content

Commit

Permalink
[ref] ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Sep 14, 2024
1 parent e5fb07e commit 0ffe6ed
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 23 deletions.
21 changes: 15 additions & 6 deletions bioimage_embed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import torch

torch.cuda.empty_cache()
# from . import models, lightning, cli, export, config
from .lightning import AESupervised, AEUnsupervised, AE, AutoEncoderSupervised, AutoEncoderUnsupervised, AutoEncoder
from .lightning import (
AESupervised,
AEUnsupervised,
AE,
AutoEncoderSupervised,
AutoEncoderUnsupervised,
AutoEncoder,
)

# TODO: Fix this import as it currently produces too many warnings
from .models import ModelFactory, create_model
Expand All @@ -13,14 +17,19 @@
# import logging
# logging.captureWarnings(True)

import torch

torch.cuda.empty_cache()
__all__ = [
"AESupervised",
"AutoEncoderUnsupervised",
"AEUnsupervised",
"AutoEncoderSupervised",
"AutoEncoder"
"AE"
"AutoEncoder",
"AE",
"BioImageEmbed",
"Config",
"augmentations",
"ModelFactory",
"create_model",
]
20 changes: 18 additions & 2 deletions bioimage_embed/lightning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
from .pyro import LitAutoEncoderPyro
from .torch import AESupervised, AEUnsupervised, AutoEncoder, AE, AutoEncoderSupervised, AutoEncoderUnsupervised
from .torch import (
AESupervised,
AEUnsupervised,
AutoEncoder,
AE,
AutoEncoderSupervised,
AutoEncoderUnsupervised,
)
from .dataloader import DataModule

__all__ = ["LitAutoEncoderPyro", "AESupervised", "AEUnsupervised", "DataModule", "AutoEncoder","AE","AutoEncoderUnsupervised","AutoEncoderSupervised"]
__all__ = [
"LitAutoEncoderPyro",
"AESupervised",
"AEUnsupervised",
"DataModule",
"AutoEncoder",
"AE",
"AutoEncoderUnsupervised",
"AutoEncoderSupervised",
]
1 change: 1 addition & 0 deletions bioimage_embed/lightning/tests/test_ndims.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch


# Fixture for batch sizes
@pytest.fixture(params=[1, 16])
def batch_size(request):
Expand Down
23 changes: 12 additions & 11 deletions bioimage_embed/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import torch
import torch.nn.functional as F
# Description: This file is the main entry point for the models module. It imports all the necessary modules and classes for the models module to function properly.

# Note - you must have torchvision installed for this example
from torch.utils.data import DataLoader

# from .ae import AutoEncoder

# from .vae_bio import Mask_VAE, Image_VAE
# from .utils import BaseVAE
# from .legacy.vae import VAE
# from .vq_vae import VQ_VAE

from .bolts import ResNet18VAEEncoder, ResNet18VAEDecoder

from . import bolts
from . import pythae
from .factory import ModelFactory, create_model, __all_models__
from .factory import ModelFactory, create_model, __all_models__

__all__ = [
"ModelFactory",
"create_model",
"__all_models__",
"ResNet18VAEEncoder",
"ResNet18VAEDecoder",
"bolts",
"pythae",
]
1 change: 0 additions & 1 deletion bioimage_embed/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def __call__(self, model):
]



def create_model(
model: str,
input_dim: Tuple[int, int, int],
Expand Down
4 changes: 3 additions & 1 deletion bioimage_embed/models/tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
latent_dim = [64, 16]
pretrained_options = [True, False]
progress_options = [True, False]
batch = [1,]
batch = [
1,
]


@pytest.mark.parametrize("model", __all_models__)
Expand Down
5 changes: 3 additions & 2 deletions bioimage_embed/tests/test_bioimage_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import torch
from ..bie import BioImageEmbed


@pytest.fixture()
def test_bioimage_embed():
bie = BioImageEmbed()
bie.train()
bie.infer()
bie.validate()
model_output = bie(torch.tensor([1, 2, 3, 4, 5]))
tensor = bie.model(torch.tensor([1, 2, 3, 4, 5]))
assert bie(torch.tensor([1, 2, 3, 4, 5]))
assert bie.model(torch.tensor([1, 2, 3, 4, 5]))

bie.model(torch.tensor([1, 2, 3, 4, 5]))

0 comments on commit 0ffe6ed

Please sign in to comment.