From e3c3e0454ca2ef23f83af86d13a96a2111a2223d Mon Sep 17 00:00:00 2001 From: TibbersHao Date: Fri, 14 Jun 2024 16:45:32 -0700 Subject: [PATCH 1/4] modified mask array --- src/_tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/_tests/conftest.py b/src/_tests/conftest.py index d5ee296..bddcf65 100644 --- a/src/_tests/conftest.py +++ b/src/_tests/conftest.py @@ -33,6 +33,6 @@ def client(context): client = from_context(context) recons_container = client.create_container("reconstructions") recons_container.write_array(np.zeros((2, 3, 3), dtype=np.int8), key="recon1") - masks_container = client.create_container("uid0001", metadata={"mask_idx": ["0"]}) - masks_container.write_array(np.zeros((1, 3, 3), dtype=np.int8), key="mask") + masks_container = client.create_container("uid0001", metadata={"mask_idx": ["1"]}) + masks_container.write_array(np.ones((1, 3, 3), dtype=np.int8), key="mask") yield client From f7ff3b612636bfdaf17ac2abfca59e77381e2649 Mon Sep 17 00:00:00 2001 From: TibbersHao Date: Fri, 14 Jun 2024 16:46:26 -0700 Subject: [PATCH 2/4] rearranged and added tests for different branches --- src/_tests/test_tiled_dataset.py | 40 ++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/_tests/test_tiled_dataset.py b/src/_tests/test_tiled_dataset.py index a271d4b..6ff0eca 100644 --- a/src/_tests/test_tiled_dataset.py +++ b/src/_tests/test_tiled_dataset.py @@ -1,16 +1,46 @@ +import numpy as np from ..tiled_dataset import TiledDataset -def test_tiled_dataset(client): +def test_with_mask_training(client): tiled_dataset = TiledDataset( - client["reconstructions"]["recon1"], + data_tiled_client=client["reconstructions"]["recon1"], + mask_tiled_client=client["uid0001"], + is_training=True, ) assert tiled_dataset - assert tiled_dataset[0].shape == (3, 3) + assert tiled_dataset.mask_idx == [1] + assert len(tiled_dataset) == 1 + assert len(tiled_dataset[0]) == 2 + # Check data + assert tiled_dataset[0][0].shape == (3, 3) + assert not np.all(tiled_dataset[0][0]) # should be all 0s + # Check mask + assert tiled_dataset[0][1].shape == (3, 3) + assert np.all(tiled_dataset[0][1])# should be all 1s +def test_with_mask_inference(client): + tiled_dataset = TiledDataset( + data_tiled_client=client["reconstructions"]["recon1"], + mask_tiled_client=client["uid0001"], + is_training=False, + ) + assert tiled_dataset + assert tiled_dataset.mask_idx == [1] + assert len(tiled_dataset) == 1 + # Check data + assert tiled_dataset[0].shape == (3, 3) + assert not np.all(tiled_dataset[0]) # should be all 0s -def test_tiled_dataset_with_masks(client): +def test_no_mask_inference(client): tiled_dataset = TiledDataset( - client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"] + data_tiled_client=client["reconstructions"]["recon1"], + is_training=False, ) + assert tiled_dataset + assert len(tiled_dataset) == 2 + # Check data assert tiled_dataset[0].shape == (3, 3) + assert not np.all(tiled_dataset[0]) # should be all 0s + +# TODO: Test qlty cropping within tiled_dataset. Since this part has been moved to the training script and performed outside, this is not on higher priority. From 352d2268c9f7fa15b42bd6ef1f8f90a52ea012a9 Mon Sep 17 00:00:00 2001 From: TibbersHao Date: Fri, 14 Jun 2024 16:46:44 -0700 Subject: [PATCH 3/4] added placeholders --- src/_tests/example_tunet.yaml | 41 +++++++++++++++++++++++++++++++++++ src/_tests/test_inference.py | 2 ++ src/_tests/test_training.py | 27 +++++++++++++++++++++++ 3 files changed, 70 insertions(+) create mode 100644 src/_tests/example_tunet.yaml create mode 100644 src/_tests/test_inference.py create mode 100644 src/_tests/test_training.py diff --git a/src/_tests/example_tunet.yaml b/src/_tests/example_tunet.yaml new file mode 100644 index 0000000..0e27ab7 --- /dev/null +++ b/src/_tests/example_tunet.yaml @@ -0,0 +1,41 @@ +# Example for parameters to excecute + +# I/O +io_parameters: + data_tiled_uri: + data_tiled_api_key: + mask_tiled_uri: + mask_tiled_api_key: + seg_tiled_uri: + uid_save: + uid_retrieve: + models_dir: . + +model_parameters: + network: "TUNet" + num_classes: 3 + num_epochs: 3 + optimizer: "Adam" + criterion: "CrossEntropyLoss" + weights: "[1.0, 2.0, 0.5]" + learning_rate: 0.1 + activation: "ReLU" + normalization: "BatchNorm2d" + convolution: "Conv2d" + + qlty_window: 64 + qlty_step: 32 + qlty_border: 8 + + shuffle_train: True + batch_size_train: 1 + + batch_size_val: 1 + + batch_size_inference: 2 + val_pct: 0.2 + + depth: 4 + base_channels: 8 + growth_rate: 2 + hidden_rate: 1 diff --git a/src/_tests/test_inference.py b/src/_tests/test_inference.py new file mode 100644 index 0000000..05d0d3e --- /dev/null +++ b/src/_tests/test_inference.py @@ -0,0 +1,2 @@ +import numpy as np +from ..train import train diff --git a/src/_tests/test_training.py b/src/_tests/test_training.py new file mode 100644 index 0000000..0b0c348 --- /dev/null +++ b/src/_tests/test_training.py @@ -0,0 +1,27 @@ +import numpy as np +from ..train import train + + +# TODO: test general yaml loading, check param type + +# TODO: test model params loading, check pydantic class type / format + +# TODO: check dir creation? How to handle file system change during pytest? + +# TODO: load TiledDataset from fixture client, test already done. + +# TODO: test data and mask array dim and shape + +# TODO: test train_loader and val_loader from crop_split_load func, check length + +# TODO: test build_network. How to deal with lengthy func? test all network options? + +# TODO: test weights and criterion? + +# TODO: test dvc? + +# TODO: test trainer building + +# TODO: test 1 epoch, check param saving + + From 1a092c7c09f12f7a87e8b8eae65a881e83797e42 Mon Sep 17 00:00:00 2001 From: Tibbers Hao Date: Mon, 17 Jun 2024 14:22:21 -0700 Subject: [PATCH 4/4] fixed format for pre-commit --- src/_tests/example_tunet.yaml | 6 +++--- src/_tests/test_inference.py | 2 -- src/_tests/test_tiled_dataset.py | 22 ++++++++++++++-------- src/_tests/test_training.py | 8 +------- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/_tests/example_tunet.yaml b/src/_tests/example_tunet.yaml index 0e27ab7..8cd1eeb 100644 --- a/src/_tests/example_tunet.yaml +++ b/src/_tests/example_tunet.yaml @@ -2,11 +2,11 @@ # I/O io_parameters: - data_tiled_uri: + data_tiled_uri: data_tiled_api_key: - mask_tiled_uri: + mask_tiled_uri: mask_tiled_api_key: - seg_tiled_uri: + seg_tiled_uri: uid_save: uid_retrieve: models_dir: . diff --git a/src/_tests/test_inference.py b/src/_tests/test_inference.py index 05d0d3e..e69de29 100644 --- a/src/_tests/test_inference.py +++ b/src/_tests/test_inference.py @@ -1,2 +0,0 @@ -import numpy as np -from ..train import train diff --git a/src/_tests/test_tiled_dataset.py b/src/_tests/test_tiled_dataset.py index 6ff0eca..b7f10b5 100644 --- a/src/_tests/test_tiled_dataset.py +++ b/src/_tests/test_tiled_dataset.py @@ -1,10 +1,11 @@ import numpy as np + from ..tiled_dataset import TiledDataset def test_with_mask_training(client): tiled_dataset = TiledDataset( - data_tiled_client=client["reconstructions"]["recon1"], + data_tiled_client=client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"], is_training=True, ) @@ -14,14 +15,15 @@ def test_with_mask_training(client): assert len(tiled_dataset[0]) == 2 # Check data assert tiled_dataset[0][0].shape == (3, 3) - assert not np.all(tiled_dataset[0][0]) # should be all 0s + assert not np.all(tiled_dataset[0][0]) # should be all 0s # Check mask assert tiled_dataset[0][1].shape == (3, 3) - assert np.all(tiled_dataset[0][1])# should be all 1s + assert np.all(tiled_dataset[0][1]) # should be all 1s + def test_with_mask_inference(client): tiled_dataset = TiledDataset( - data_tiled_client=client["reconstructions"]["recon1"], + data_tiled_client=client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"], is_training=False, ) @@ -30,17 +32,21 @@ def test_with_mask_inference(client): assert len(tiled_dataset) == 1 # Check data assert tiled_dataset[0].shape == (3, 3) - assert not np.all(tiled_dataset[0]) # should be all 0s + assert not np.all(tiled_dataset[0]) # should be all 0s + def test_no_mask_inference(client): tiled_dataset = TiledDataset( - data_tiled_client=client["reconstructions"]["recon1"], + data_tiled_client=client["reconstructions"]["recon1"], is_training=False, ) assert tiled_dataset assert len(tiled_dataset) == 2 # Check data assert tiled_dataset[0].shape == (3, 3) - assert not np.all(tiled_dataset[0]) # should be all 0s + assert not np.all(tiled_dataset[0]) # should be all 0s + -# TODO: Test qlty cropping within tiled_dataset. Since this part has been moved to the training script and performed outside, this is not on higher priority. +# TODO: Test qlty cropping within tiled_dataset. +# Since this part has been moved to the training script and performed outside, +# this is not on higher priority. diff --git a/src/_tests/test_training.py b/src/_tests/test_training.py index 0b0c348..4f8811e 100644 --- a/src/_tests/test_training.py +++ b/src/_tests/test_training.py @@ -1,14 +1,10 @@ -import numpy as np -from ..train import train - - # TODO: test general yaml loading, check param type # TODO: test model params loading, check pydantic class type / format # TODO: check dir creation? How to handle file system change during pytest? -# TODO: load TiledDataset from fixture client, test already done. +# TODO: load TiledDataset from fixture client, test already done. # TODO: test data and mask array dim and shape @@ -23,5 +19,3 @@ # TODO: test trainer building # TODO: test 1 epoch, check param saving - -