-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #59 from ctr26/dev
Fixing mask_embed for latest changes?
- Loading branch information
Showing
7 changed files
with
172 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from bioimage_embed.shapes.lightning import MaskEmbed, MaskEmbedSupervised | ||
import pytest | ||
from bioimage_embed import create_model | ||
from torchvision.datasets import FakeData | ||
import pytorch_lightning as pl | ||
from bioimage_embed.lightning.dataloader import DataModule | ||
from torchvision.transforms import transforms | ||
from types import SimpleNamespace | ||
|
||
|
||
@pytest.fixture | ||
def transform(): | ||
return transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
] | ||
) | ||
|
||
|
||
@pytest.fixture(params=[1, 2, 16]) | ||
def classes(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def dataset(transform, classes): | ||
return FakeData( | ||
size=64, | ||
image_size=(3, 224, 224), | ||
num_classes=classes, | ||
transform=transform, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def model(): | ||
return create_model( | ||
model="resnet18_vae", | ||
input_dim=[3, 224, 224], | ||
latent_dim=64, | ||
pretrained=True, | ||
) | ||
|
||
|
||
# TODO Add tests for MaskEmbedSupervised | ||
@pytest.fixture(params=[MaskEmbed, MaskEmbedSupervised]) | ||
def wrapper(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def lit_model(model, wrapper): | ||
args = SimpleNamespace(frobenius_norm=False) | ||
return wrapper(model, args) | ||
|
||
|
||
@pytest.fixture | ||
def trainer(): | ||
return pl.Trainer( | ||
max_epochs=1, | ||
max_steps=1, | ||
# gpus=1, | ||
fast_dev_run=True, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def dataloader(dataset): | ||
return DataModule( | ||
dataset, | ||
batch_size=16, | ||
num_workers=0, | ||
) | ||
|
||
|
||
def test_model(trainer, lit_model, dataloader): | ||
return trainer.test(lit_model, dataloader) | ||
|
||
|
||
def test_model_fit(trainer, lit_model, dataloader): | ||
return trainer.fit(lit_model, dataloader) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters