diff --git a/src/_tests/conftest.py b/src/_tests/conftest.py index d5ee296..bddcf65 100644 --- a/src/_tests/conftest.py +++ b/src/_tests/conftest.py @@ -33,6 +33,6 @@ def client(context): client = from_context(context) recons_container = client.create_container("reconstructions") recons_container.write_array(np.zeros((2, 3, 3), dtype=np.int8), key="recon1") - masks_container = client.create_container("uid0001", metadata={"mask_idx": ["0"]}) - masks_container.write_array(np.zeros((1, 3, 3), dtype=np.int8), key="mask") + masks_container = client.create_container("uid0001", metadata={"mask_idx": ["1"]}) + masks_container.write_array(np.ones((1, 3, 3), dtype=np.int8), key="mask") yield client diff --git a/src/_tests/example_tunet.yaml b/src/_tests/example_tunet.yaml new file mode 100644 index 0000000..8cd1eeb --- /dev/null +++ b/src/_tests/example_tunet.yaml @@ -0,0 +1,41 @@ +# Example for parameters to execute 
+ +model_parameters: + network: "TUNet" + num_classes: 3 + num_epochs: 3 + optimizer: "Adam" + criterion: "CrossEntropyLoss" + weights: "[1.0, 2.0, 0.5]" + learning_rate: 0.1 + activation: "ReLU" + normalization: "BatchNorm2d" + convolution: "Conv2d" + + qlty_window: 64 + qlty_step: 32 + qlty_border: 8 + + shuffle_train: True + batch_size_train: 1 + + batch_size_val: 1 + + batch_size_inference: 2 + val_pct: 0.2 + + depth: 4 + base_channels: 8 + growth_rate: 2 + hidden_rate: 1 diff --git a/src/_tests/test_inference.py b/src/_tests/test_inference.py new file mode 100644 index 0000000..e69de29 diff --git a/src/_tests/test_tiled_dataset.py b/src/_tests/test_tiled_dataset.py index a271d4b..b7f10b5 100644 --- a/src/_tests/test_tiled_dataset.py +++ b/src/_tests/test_tiled_dataset.py @@ -1,16 +1,52 @@ +import numpy as np + from ..tiled_dataset import TiledDataset -def test_tiled_dataset(client): +def test_with_mask_training(client): tiled_dataset = TiledDataset( - client["reconstructions"]["recon1"], + data_tiled_client=client["reconstructions"]["recon1"], + mask_tiled_client=client["uid0001"], + is_training=True, ) assert tiled_dataset + assert tiled_dataset.mask_idx == [1] + assert len(tiled_dataset) == 1 + assert len(tiled_dataset[0]) == 2 + # Check data + assert tiled_dataset[0][0].shape == (3, 3) + assert not np.all(tiled_dataset[0][0]) # should be all 0s + # Check mask + assert tiled_dataset[0][1].shape == (3, 3) + assert np.all(tiled_dataset[0][1]) # should be all 1s + + +def test_with_mask_inference(client): + tiled_dataset = TiledDataset( + data_tiled_client=client["reconstructions"]["recon1"], + mask_tiled_client=client["uid0001"], + is_training=False, + ) + assert tiled_dataset + assert tiled_dataset.mask_idx == [1] + assert len(tiled_dataset) == 1 + # Check data assert tiled_dataset[0].shape == (3, 3) + assert not np.all(tiled_dataset[0]) # should be all 0s -def test_tiled_dataset_with_masks(client): +def test_no_mask_inference(client): tiled_dataset = 
TiledDataset( - client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"] + data_tiled_client=client["reconstructions"]["recon1"], + is_training=False, ) + assert tiled_dataset + assert len(tiled_dataset) == 2 + # Check data assert tiled_dataset[0].shape == (3, 3) + assert not np.all(tiled_dataset[0]) # should be all 0s + + +# TODO: Test qlty cropping within tiled_dataset. +# Since this part has been moved to the training script and performed outside, +# this is not on higher priority. diff --git a/src/_tests/test_training.py b/src/_tests/test_training.py new file mode 100644 index 0000000..4f8811e --- /dev/null +++ b/src/_tests/test_training.py @@ -0,0 +1,21 @@ +# TODO: test general yaml loading, check param type + +# TODO: test model params loading, check pydantic class type / format + +# TODO: check dir creation? How to handle file system change during pytest? + +# TODO: load TiledDataset from fixture client, test already done. + +# TODO: test data and mask array dim and shape + +# TODO: test train_loader and val_loader from crop_split_load func, check length + +# TODO: test build_network. How to deal with lengthy func? test all network options? + +# TODO: test weights and criterion? + +# TODO: test dvc? + +# TODO: test trainer building + +# TODO: test 1 epoch, check param saving