Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unit tests of TiledDataset Class #25

Merged
merged 4 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def client(context):
client = from_context(context)
recons_container = client.create_container("reconstructions")
recons_container.write_array(np.zeros((2, 3, 3), dtype=np.int8), key="recon1")
masks_container = client.create_container("uid0001", metadata={"mask_idx": ["0"]})
masks_container.write_array(np.zeros((1, 3, 3), dtype=np.int8), key="mask")
masks_container = client.create_container("uid0001", metadata={"mask_idx": ["1"]})
masks_container.write_array(np.ones((1, 3, 3), dtype=np.int8), key="mask")
yield client
41 changes: 41 additions & 0 deletions src/_tests/example_tunet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Example for parameters to excecute

# I/O
io_parameters:
data_tiled_uri:
data_tiled_api_key:
mask_tiled_uri:
mask_tiled_api_key:
seg_tiled_uri:
uid_save:
uid_retrieve:
models_dir: .

model_parameters:
network: "TUNet"
num_classes: 3
num_epochs: 3
optimizer: "Adam"
criterion: "CrossEntropyLoss"
weights: "[1.0, 2.0, 0.5]"
learning_rate: 0.1
activation: "ReLU"
normalization: "BatchNorm2d"
convolution: "Conv2d"

qlty_window: 64
qlty_step: 32
qlty_border: 8

shuffle_train: True
batch_size_train: 1

batch_size_val: 1

batch_size_inference: 2
val_pct: 0.2

depth: 4
base_channels: 8
growth_rate: 2
hidden_rate: 1
Empty file added src/_tests/test_inference.py
Empty file.
44 changes: 40 additions & 4 deletions src/_tests/test_tiled_dataset.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,52 @@
import numpy as np

from ..tiled_dataset import TiledDataset


def test_tiled_dataset(client):
def test_with_mask_training(client):
tiled_dataset = TiledDataset(
client["reconstructions"]["recon1"],
data_tiled_client=client["reconstructions"]["recon1"],
mask_tiled_client=client["uid0001"],
is_training=True,
)
assert tiled_dataset
assert tiled_dataset.mask_idx == [1]
assert len(tiled_dataset) == 1
assert len(tiled_dataset[0]) == 2
# Check data
assert tiled_dataset[0][0].shape == (3, 3)
assert not np.all(tiled_dataset[0][0]) # should be all 0s
# Check mask
assert tiled_dataset[0][1].shape == (3, 3)
assert np.all(tiled_dataset[0][1]) # should be all 1s


def test_with_mask_inference(client):
tiled_dataset = TiledDataset(
data_tiled_client=client["reconstructions"]["recon1"],
mask_tiled_client=client["uid0001"],
is_training=False,
)
assert tiled_dataset
assert tiled_dataset.mask_idx == [1]
assert len(tiled_dataset) == 1
# Check data
assert tiled_dataset[0].shape == (3, 3)
assert not np.all(tiled_dataset[0]) # should be all 0s


def test_tiled_dataset_with_masks(client):
def test_no_mask_inference(client):
tiled_dataset = TiledDataset(
client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"]
data_tiled_client=client["reconstructions"]["recon1"],
is_training=False,
)
assert tiled_dataset
assert len(tiled_dataset) == 2
# Check data
assert tiled_dataset[0].shape == (3, 3)
assert not np.all(tiled_dataset[0]) # should be all 0s


# TODO: Test qlty cropping within tiled_dataset.
# Since this part has been moved to the training script and performed outside,
# this is not on higher priority.
21 changes: 21 additions & 0 deletions src/_tests/test_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# TODO: test general yaml loading, check param type

# TODO: test model params loading, check pydantic class type / format

# TODO: check dir creation? How to handle file system change during pytest?

# TODO: load TiledDataset from fixture client, test already done.

# TODO: test data and mask array dim and shape

# TODO: test train_loader and val_loader from crop_split_load func, check length

# TODO: test build_network. How to deal with lengthy func? test all network options?

# TODO: test weights and criterion?

# TODO: test dvc?

# TODO: test trainer building

# TODO: test 1 epoch, check param saving
Loading