From 5031039b5081f2603ec06ca2da5abbc0dfccf325 Mon Sep 17 00:00:00 2001 From: Francesco Ioli Date: Wed, 24 Jan 2024 10:31:55 +0100 Subject: [PATCH] add tests for tiling not convinced with the tests with overlap... --- src/deep_image_matching/utils/tiling.py | 3 + tests/test_tiling.py | 148 ++++++++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 tests/test_tiling.py diff --git a/src/deep_image_matching/utils/tiling.py b/src/deep_image_matching/utils/tiling.py index c4f6ba3e..c8fea665 100644 --- a/src/deep_image_matching/utils/tiling.py +++ b/src/deep_image_matching/utils/tiling.py @@ -89,6 +89,9 @@ def compute_tiles_by_size( if isinstance(overlap, int): overlap = (overlap, overlap) + elif isinstance(overlap, tuple) or isinstance(window_size, List): + # transpose to be (H, W) + overlap = (overlap[1], overlap[0]) elif not isinstance(overlap, tuple) or isinstance(window_size, List): raise TypeError("overlap must be an integer or a tuple of integers") overlap = overlap diff --git a/tests/test_tiling.py b/tests/test_tiling.py new file mode 100644 index 00000000..50f3f9dc --- /dev/null +++ b/tests/test_tiling.py @@ -0,0 +1,148 @@ +import numpy as np +import pytest +import torch + +from deep_image_matching.utils.tiling import Tiler + + +@pytest.fixture +def tiler(): + return Tiler() + + +def test_compute_tiles_by_size_no_overlap_no_padding(tiler): + # Create a numpy array with shape (100, 100, 3) + input_shape = (100, 100, 3) + input_image = np.random.randint(0, 255, input_shape, dtype=np.uint8) + window_size = 50 + overlap = 0 + + tiles, origins, padding = tiler.compute_tiles_by_size( + input_image, window_size, overlap + ) + + # Assert the output types and shapes + assert isinstance(tiles, dict) + assert isinstance(origins, dict) + assert isinstance(padding, tuple) + assert len(padding) == 4 + + # Assert the number of tiles and origins + assert len(tiles) == 4 + assert len(origins) == 4 + + # Assert the shape of the tiles + for tile in tiles.values(): + assert tile.shape == (window_size, window_size, 3) + + # Assert the padding values + assert padding == (0, 0, 0, 0) + + +def test_compute_tiles_by_size_no_overlap_padding(tiler): + # Create a numpy array with shape (100, 100, 3) + input_shape = (100, 100, 3) + input_image = np.random.randint(0, 255, input_shape, dtype=np.uint8) + window_size = 40 + overlap = 0 + + tiles, origins, padding = tiler.compute_tiles_by_size( + input_image, window_size, overlap + ) + + # Assert the output types and shapes + assert isinstance(tiles, dict) + assert isinstance(origins, dict) + assert isinstance(padding, tuple) + assert len(padding) == 4 + + # Assert the number of tiles and origins + assert len(tiles) == 9 + assert len(origins) == 9 + + # Assert the shape of the tiles + for tile in tiles.values(): + assert tile.shape == (window_size, window_size, 3) + + # Assert the padding values + assert padding == (10, 10, 10, 10) + + +def test_compute_tiles_by_size_overlap_no_padding(tiler): + # Create a numpy array with shape (100, 100, 3) + input_shape = (100, 100, 3) + input_image = np.random.randint(0, 255, input_shape, dtype=np.uint8) + window_size = 50 + overlap = 10 + + tiles, origins, padding = tiler.compute_tiles_by_size( + input_image, window_size, overlap + ) + + # Assert the output types and shapes + assert isinstance(tiles, dict) + assert isinstance(origins, dict) + assert isinstance(padding, tuple) + assert len(padding) == 4 + + # Assert the number of tiles and origins + assert len(tiles) == 4 + assert len(origins) == 4 + + # Assert the shape of the tiles + for tile in tiles.values(): + assert tile.shape == (window_size, window_size, 3) + + # Assert the padding values + assert padding == (0, 0, 0, 0) + + +def test_compute_tiles_by_size_with_torch_tensor(tiler): + # Create a torch tensor with shape (3, 100, 100) + channels = 3 + input_shape = (channels, 100, 100) + input_image = torch.randint(0, 255, input_shape, dtype=torch.uint8) + window_size = (50, 50) + overlap = (0, 0) + + tiles, origins, padding = tiler.compute_tiles_by_size( + input_image, window_size, overlap + ) + + # Assert the output types and shapes + assert isinstance(tiles, dict) + assert isinstance(origins, dict) + assert isinstance(padding, tuple) + assert len(padding) == 4 + + # Assert the number of tiles and origins + assert len(tiles) == 4 + assert len(origins) == 4 + + # Assert the shape of the tiles + for tile in tiles.values(): + assert tile.shape == (window_size[0], window_size[1], channels) + + # Assert the padding values + assert padding == (0, 0, 0, 0) + + +def test_compute_tiles_by_size_with_invalid_input(tiler): + # Create an invalid window_size (a string) + input_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + window_size = "32" + overlap = 8 + + with pytest.raises(TypeError): + tiler.compute_tiles_by_size(input_image, window_size, overlap) + + # Create an invalid overlap (a float) + window_size = 32 + overlap = 8.0 + + with pytest.raises(TypeError): + tiler.compute_tiles_by_size(input_image, window_size, overlap) + + +if __name__ == "__main__": + pytest.main([__file__])