From 5a9dd146b4a4e23ff758b304c3ab18e5f495c8b6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Fri, 18 Oct 2024 09:30:01 -0700 Subject: [PATCH] Vendor pad shape function (#189) * vendor the private function from iohub * add test for the added function * test full shape case --- tests/translation/test_predict_writer.py | 8 ++++++++ viscy/translation/predict_writer.py | 11 ++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 tests/translation/test_predict_writer.py diff --git a/tests/translation/test_predict_writer.py b/tests/translation/test_predict_writer.py new file mode 100644 index 00000000..c757774a --- /dev/null +++ b/tests/translation/test_predict_writer.py @@ -0,0 +1,8 @@ +from viscy.translation.predict_writer import _pad_shape + + +def test_pad_shape(): + assert _pad_shape((2, 3), 3) == (1, 2, 3) + assert _pad_shape((4, 5), 4) == (1, 1, 4, 5) + full_shape = tuple(range(1, 6)) + assert _pad_shape(full_shape, 5) == full_shape diff --git a/viscy/translation/predict_writer.py b/viscy/translation/predict_writer.py index 3e9317b1..2c6affa6 100644 --- a/viscy/translation/predict_writer.py +++ b/viscy/translation/predict_writer.py @@ -5,7 +5,7 @@ import numpy as np import torch -from iohub.ngff import ImageArray, Plate, Position, _pad_shape, open_ome_zarr +from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr from iohub.ngff_meta import TransformationMeta from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter @@ -17,6 +17,15 @@ _logger = logging.getLogger("lightning.pytorch") +def _pad_shape(shape: tuple[int, ...], target: int = 5) -> tuple[int, ...]: + """ + Pad shape tuple to a target length. + Vendored from ``iohub.ngff.nodes._pad_shape()``. + """ + pad = target - len(shape) + return (1,) * pad + shape + + def _resize_image(image: ImageArray, t_index: int, z_slice: slice) -> None: """Resize image array if incoming (1, C, Z, Y, X) stack is not within bounds.""" if image.shape[0] <= t_index or image.shape[2] < z_slice.stop: