Skip to content

Commit

Permalink
add tests for tiling
Browse files Browse the repository at this point in the history
not convinced with the tests with overlap...
  • Loading branch information
franioli committed Jan 24, 2024
1 parent f40e3d9 commit 5031039
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/deep_image_matching/utils/tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def compute_tiles_by_size(

if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, tuple) or isinstance(window_size, List):
# transpose to be (H, W)
overlap = (overlap[1], overlap[0])
elif not isinstance(overlap, tuple) or isinstance(window_size, List):
raise TypeError("overlap must be an integer or a tuple of integers")
overlap = overlap
Expand Down
148 changes: 148 additions & 0 deletions tests/test_tiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np
import pytest
import torch

from deep_image_matching.utils.tiling import Tiler


@pytest.fixture
def tiler():
return Tiler()


def test_compute_tiles_by_size_no_overlap_no_padding(tiler):
# Create a numpy array with shape (100, 100, 3)
input_shape = (100, 100, 3)
input_image = np.random.randint(0, 255, input_shape, dtype=np.uint8)
window_size = 50
overlap = 0

tiles, origins, padding = tiler.compute_tiles_by_size(
input_image, window_size, overlap
)

# Assert the output types and shapes
assert isinstance(tiles, dict)
assert isinstance(origins, dict)
assert isinstance(padding, tuple)
assert len(padding) == 4

# Assert the number of tiles and origins
assert len(tiles) == 4
assert len(origins) == 4

# Assert the shape of the tiles
for tile in tiles.values():
assert tile.shape == (window_size, window_size, 3)

# Assert the padding values
assert padding == (0, 0, 0, 0)


def test_compute_tiles_by_size_no_overlap_padding(tiler):
# Create a numpy array with shape (100, 100, 3)
input_shape = (100, 100, 3)
input_image = np.random.randint(0, 255, input_shape, dtype=np.uint8)
window_size = 40
overlap = 0

tiles, origins, padding = tiler.compute_tiles_by_size(
input_image, window_size, overlap
)

# Assert the output types and shapes
assert isinstance(tiles, dict)
assert isinstance(origins, dict)
assert isinstance(padding, tuple)
assert len(padding) == 4

# Assert the number of tiles and origins
assert len(tiles) == 9
assert len(origins) == 9

# Assert the shape of the tiles
for tile in tiles.values():
assert tile.shape == (window_size, window_size, 3)

# Assert the padding values
assert padding == (10, 10, 10, 10)


def test_compute_tiles_by_size_overlap_no_padding(tiler):
# Create a numpy array with shape (100, 100, 3)
input_shape = (100, 100, 3)
input_image = np.random.randint(0, 255, input_shape, dtype=np.uint8)
window_size = 50
overlap = 10

tiles, origins, padding = tiler.compute_tiles_by_size(
input_image, window_size, overlap
)

# Assert the output types and shapes
assert isinstance(tiles, dict)
assert isinstance(origins, dict)
assert isinstance(padding, tuple)
assert len(padding) == 4

# Assert the number of tiles and origins
assert len(tiles) == 4
assert len(origins) == 4

# Assert the shape of the tiles
for tile in tiles.values():
assert tile.shape == (window_size, window_size, 3)

# Assert the padding values
assert padding == (0, 0, 0, 0)


def test_compute_tiles_by_size_with_torch_tensor(tiler):
# Create a torch tensor with shape (3, 100, 100)
channels = 3
input_shape = (channels, 100, 100)
input_image = torch.randint(0, 255, input_shape, dtype=torch.uint8)
window_size = (50, 50)
overlap = (0, 0)

tiles, origins, padding = tiler.compute_tiles_by_size(
input_image, window_size, overlap
)

# Assert the output types and shapes
assert isinstance(tiles, dict)
assert isinstance(origins, dict)
assert isinstance(padding, tuple)
assert len(padding) == 4

# Assert the number of tiles and origins
assert len(tiles) == 4
assert len(origins) == 4

# Assert the shape of the tiles
for tile in tiles.values():
assert tile.shape == (window_size[0], window_size[1], channels)

# Assert the padding values
assert padding == (0, 0, 0, 0)


def test_compute_tiles_by_size_with_invalid_input(tiler):
# Create an invalid window_size (a string)
input_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
window_size = "32"
overlap = 8

with pytest.raises(TypeError):
tiler.compute_tiles_by_size(input_image, window_size, overlap)

# Create an invalid overlap (a float)
window_size = 32
overlap = 8.0

with pytest.raises(TypeError):
tiler.compute_tiles_by_size(input_image, window_size, overlap)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 5031039

Please sign in to comment.