From 1d10fd0d044b0699132d917608e62b98d7fce290 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 30 Oct 2024 10:48:43 -0400 Subject: [PATCH 01/17] =?UTF-8?q?feat:=20=F0=9F=9A=A7=20ImageWriter=20in?= =?UTF-8?q?=20construction.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cellmap_data/__init__.py | 2 +- src/cellmap_data/image.py | 342 ++++++++++++++++++++++++++++++++--- 2 files changed, 319 insertions(+), 25 deletions(-) diff --git a/src/cellmap_data/__init__.py b/src/cellmap_data/__init__.py index e1d0f0c..abf1ee5 100644 --- a/src/cellmap_data/__init__.py +++ b/src/cellmap_data/__init__.py @@ -19,7 +19,7 @@ from .dataloader import CellMapDataLoader from .datasplit import CellMapDataSplit from .dataset import CellMapDataset -from .image import CellMapImage +from .image import CellMapImage, EmptyImage, ImageWriter from .subdataset import CellMapSubset from . import transforms from . import utils diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 233d7d3..15e863b 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -1,10 +1,19 @@ +from functools import singledispatch import os +from types import SimpleNamespace from typing import Any, Callable, Mapping, Optional, Sequence import torch +from upath import UPath from xarray_ome_ngff.v04.multiscale import coords_from_transforms from pydantic_ome_ngff.v04.multiscale import GroupAttrs, MultiscaleMetadata -from pydantic_ome_ngff.v04.transform import Scale, Translation +from pydantic_ome_ngff.v04.axis import Axis +from pydantic_ome_ngff.v04.transform import ( + Scale, + Translation, + VectorScale, + VectorTranslation, +) import xarray import tensorstore @@ -106,33 +115,46 @@ def __init__( self.context = context self._current_spatial_transforms = None self._current_coords = None - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self._current_center = None + if torch.cuda.is_available(): + self.device = "cuda" + elif torch.backends.mps.is_available(): + self.device = "mps" + else: + self.device = "cpu" def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: """Returns image data centered around the given point, based on the scale and shape of the target output image.""" - # Find vectors of coordinates in world space to pull data from - coords = {} - for c in self.axes: - if center[c] - self.output_size[c] / 2 < self.bounding_box[c][0]: - # raise ValueError( - UserWarning( - f"Center {center[c]} is out of bounds for axis {c} in image {self.path}. {center[c] - self.output_size[c] / 2} would be less than {self.bounding_box[c][0]}" + if isinstance(list(center.values())[0], int | float): + self._current_center = center + + # Find vectors of coordinates in world space to pull data from + coords = {} + for c in self.axes: + if center[c] - self.output_size[c] / 2 < self.bounding_box[c][0]: + # raise ValueError( + UserWarning( + f"Center {center[c]} is out of bounds for axis {c} in image {self.path}. {center[c] - self.output_size[c] / 2} would be less than {self.bounding_box[c][0]}" + ) + # center[c] = self.bounding_box[c][0] + self.output_size[c] / 2 + if center[c] + self.output_size[c] / 2 > self.bounding_box[c][1]: + # raise ValueError( + UserWarning( + f"Center {center[c]} is out of bounds for axis {c} in image {self.path}. 
{center[c] + self.output_size[c] / 2} would be greater than {self.bounding_box[c][1]}" + ) + # center[c] = self.bounding_box[c][1] - self.output_size[c] / 2 + coords[c] = np.linspace( + center[c] - self.output_size[c] / 2, + center[c] + self.output_size[c] / 2, + self.output_shape[c], ) - # center[c] = self.bounding_box[c][0] + self.output_size[c] / 2 - if center[c] + self.output_size[c] / 2 > self.bounding_box[c][1]: - # raise ValueError( - UserWarning( - f"Center {center[c]} is out of bounds for axis {c} in image {self.path}. {center[c] + self.output_size[c] / 2} would be greater than {self.bounding_box[c][1]}" - ) - # center[c] = self.bounding_box[c][1] - self.output_size[c] / 2 - coords[c] = np.linspace( - center[c] - self.output_size[c] / 2, - center[c] + self.output_size[c] / 2, - self.output_shape[c], - ) - # Apply any spatial transformations to the coordinates and return the image data as a PyTorch tensor - data = self.apply_spatial_transforms(coords) + # Apply any spatial transformations to the coordinates and return the image data as a PyTorch tensor + data = self.apply_spatial_transforms(coords) + else: + self._current_center = None + self._current_coords = center + data = torch.tensor(self.return_data(self._current_coords).values) # type: ignore # Apply any value transformations to the data if self.value_transform is not None: @@ -475,7 +497,7 @@ def return_data( ), ) -> xarray.DataArray: """Pulls data from the image based on the given coordinates, applying interpolation if necessary, and returns the data as an xarray DataArray.""" - if not isinstance(coords[list(coords.keys())[0]][0], float | int): + if not isinstance(list(coords.values())[0][0], float | int): data = self.array.interp( coords=coords, method=self.interpolation, # type: ignore @@ -605,3 +627,275 @@ def to(self, device: str) -> None: def set_spatial_transforms(self, transforms: Mapping[str, Any] | None) -> None: """Imitates the method in CellMapImage, but does nothing for an EmptyImage object.""" pass + + +class ImageWriter: + """ + This class is used to write image data to a single-resolution zarr. + + Attributes: + path (str): The path to the image file. + label_class (str): The label class of the image. + scale (Mapping[str, float]): The scale of the image in physical space. + write_voxel_shape (Mapping[str, int]): The shape of data written to the image in voxels. + axes (str): The order of the axes in the image. + context (Optional[tensorstore.Context]): The context for the TensorStore. + + Methods: + __setitem__(center: Mapping[str, float], data: torch.Tensor): Writes the given data to the image at the given center. + __repr__() -> str: Returns a string representation of the ImageWriter object. + + Properties: + + shape (Mapping[str, int]): Returns the shape of the image in voxels. + center (Mapping[str, float]): Returns the center of the image in world units. + full_coords: Returns the full coordinates of the image in world units. + array_path (str): Returns the path to the single-scale image array. + translation (Mapping[str, float]): Returns the translation of the image in world units. + bounding_box (Mapping[str, list[float]]): Returns the bounding box of the dataset in world units. + sampling_box (Mapping[str, list[float]]): Returns the sampling box of the dataset in world units. 
+ """ + + def __init__( + self, + path: str | UPath, + label_class: str, + scale: Mapping[str, float] | Sequence[float], + bounding_box: Mapping[str, list[float]], + write_voxel_shape: Mapping[str, int] | Sequence[int], + axis_order: str = "zyx", + context: Optional[tensorstore.Context] = None, + overwrite: bool = False, + dtype: np.dtype = np.float32, + fill_value: float | int = 0, + ) -> None: + """Initializes an ImageWriter object. + + Args: + path (str): The path to the image file. + label_class (str): The label class of the image. + scale (Mapping[str, float]): The scale of the image in physical space. + bounding_box (Mapping[str, list[float]]): The total region of interest for the image in world units. Example: {"x": [12.0, 102.0], "y": [12.0, 102.0], "z": [12.0, 102.0]}. + write_voxel_shape (Mapping[str, int]): The shape of data written to the image in voxels. + axis_order (str, optional): The order of the axes in the image. Defaults to "zyx". + context (Optional[tensorstore.Context], optional): The context for the TensorStore. Defaults to None. + overwrite (bool, optional): Whether to overwrite the image if it already exists. Defaults to False. + dtype (np.dtype, optional): The data type of the image. Defaults to np.float32. + fill_value (float | int, optional): The value to fill the empty image with before values are written. Defaults to 0. + """ + self.path = path + self.label_class = label_class + if isinstance(scale, Sequence): + if len(axis_order) > len(scale): + scale = [scale[0]] * (len(axis_order) - len(scale)) + list(scale) + scale = {c: s for c, s in zip(axis_order, scale)} + if isinstance(write_voxel_shape, Sequence): + if len(axis_order) > len(write_voxel_shape): + write_voxel_shape = [1] * ( + len(axis_order) - len(write_voxel_shape) + ) + list(write_voxel_shape) + write_voxel_shape = {c: t for c, t in zip(axis_order, write_voxel_shape)} + self.scale = scale + self.bounding_box = bounding_box + self.write_voxel_shape = write_voxel_shape + self.write_world_shape = { + c: write_voxel_shape[c] * scale[c] for c in axis_order + } + self.axes = axis_order[: len(write_voxel_shape)] + self.context = context + self.overwrite = overwrite + self.dtype = dtype + self.fill_value = fill_value + + # Create the new zarr's metadata + dims = [c for c in axis_order] + self.metadata = { + "offset": list(self.offset.values()), + "axes": dims, + "voxel_size": list(self.scale.values()), + "shape": list(self.shape.values()), + "units": ["nm"] * len(write_voxel_shape), + "chunk_shape": list(write_voxel_shape.values()), + } + # array = xarray.DataArray( + # self.fill_value, + # dims=self.metadata["axes"], + # coords=self.full_coords, + # attrs=self.metadata, + # name=self.label_class, + # ) + # dataset = array.to_dataset() + # dataset[self.label_class].encoding = {"chunks": self.chunk_shape} + # if overwrite or not UPath(path).exists(): + # dataset.to_zarr(path, mode="w", write_empty_chunks=False, compute=False) + # else: + # dataset.to_zarr(path, mode="a", write_empty_chunks=False, compute=False) + + @property + def array(self) -> xarray.DataArray: + """Returns the image data as an xarray DataArray.""" + try: + return self._array + except AttributeError: + # Construct an xarray with Tensorstore backend + # spec = xt._zarr_spec_from_path(self.path) + spec = { + "driver": "zarr", + "kvstore": {"driver": "file", "path": self.path}, + # "transform": { + # "input_labels": self.metadata["axes"], + # # "scale": self.metadata["voxel_size"], + # "input_inclusive_min": self.metadata["offset"], + # 
"input_shape": self.metadata["shape"], + # # "units": self.metadata["units"], + # }, + } + open_kwargs = { + "read": True, + "write": True, + "create": True, + "delete_existing": self.overwrite, + "dtype": self.dtype, + "shape": list(self.shape.values()), + "fill_value": self.fill_value, + "chunk_layout": tensorstore.ChunkLayout( + write_chunk_shape=self.chunk_shape + ), + # "metadata": self.metadata, + # "transaction": tensorstore.Transaction(atomic=True), + "context": self.context, + # "dimension_units": ["nm" if c != "c" else "" for c in self.axes], + } + array_future = tensorstore.open( + spec, + **open_kwargs, + ) + try: + array = array_future.result() + except ValueError as e: + Warning(e) + UserWarning("Falling back to zarr3 driver") + spec["driver"] = "zarr3" + array_future = tensorstore.open(spec, **open_kwargs) + array = array_future.result() + data = xt._TensorStoreAdapter(array) + self._array = xarray.DataArray(data=data, coords=self.full_coords) + return self._array + + @property + def chunk_shape(self) -> Sequence[int]: + """Returns the shape of the chunks for the image.""" + try: + return self._chunk_shape + except AttributeError: + self._chunk_shape = list(self.write_voxel_shape.values()) + return self._chunk_shape + + @property + def world_shape(self) -> Mapping[str, float]: + """Returns the shape of the image in world units.""" + try: + return self._world_shape + except AttributeError: + self._world_shape = { + c: self.bounding_box[c][1] - self.bounding_box[c][0] for c in self.axes + } + return self._world_shape + + @property + def shape(self) -> Mapping[str, int]: + """Returns the shape of the image in voxels.""" + try: + return self._shape + except AttributeError: + self._shape = { + c: int(self.world_shape[c] // self.scale[c]) for c in self.axes + } + return self._shape + + @property + def center(self) -> Mapping[str, float]: + """Returns the center of the image in world units.""" + try: + return self._center + except AttributeError: + center = {} + for c, (start, stop) in self.bounding_box.items(): + center[c] = start + (stop - start) / 2 + self._center = center + return self._center + + @property + def offset(self) -> Mapping[str, float]: + """Returns the offset of the image in world units.""" + try: + return self._offset + except AttributeError: + self._offset = {c: self.bounding_box[c][0] for c in self.axes} + return self._offset + + @property + def full_coords(self) -> tuple[xarray.DataArray, ...]: + """Returns the full coordinates of the image in world units.""" + try: + return self._full_coords + except AttributeError: + self._full_coords = coords_from_transforms( + axes=[ + Axis( + name=c, + type="space" if c != "c" else "channel", + unit="nm" if c != "c" else "", + ) + for c in self.axes + ], + transforms=( + VectorScale(scale=tuple(self.scale.values())), + VectorTranslation(translation=tuple(self.offset.values())), + ), + shape=tuple(self.shape.values()), + ) + return self._full_coords + + def align_coords( + self, coords: Mapping[str, tuple[Sequence, np.ndarray]] + ) -> Mapping[str, tuple[Sequence, np.ndarray]]: + """Aligns the given coordinates to the image's coordinates.""" + aligned_coords = {} + for c in self.axes: + if c in coords: + # Align each coorinate for the axis to the nearest image's coordinates + + aligned_coords[c] = ( + coords[c][0], + np.linspace( + self.offset[c], + self.offset[c] + self.world_shape[c], + len(coords[c][0]), + ), + ) + return aligned_coords + + def __setitem__( + self, + coords: Mapping[str, float] | Mapping[str, 
tuple[Sequence, np.ndarray]], + data: torch.Tensor | np.ndarray, + ) -> None: + """Writes the given data to the image at the given center or coordinates (in world units).""" + # Find vectors of coordinates in world space to write data to if necessary + if isinstance(list(coords.values())[0], int | float): + center = coords + coords = {} + for c in self.axes: + coords[c] = np.linspace( # type: ignore + center[c] - self.write_world_shape[c] / 2, # type: ignore + center[c] + self.write_world_shape[c] / 2, # type: ignore + self.write_voxel_shape[c], + ) + # if isinstance(data, torch.Tensor): + # data = data.cpu().numpy() + self.array.loc[coords] = data + + def __repr__(self) -> str: + """Returns a string representation of the ImageWriter object.""" + return f"ImageWriter({self.path}: {self.label_class} @ {self.scale.values()})" From bbce0ccd1ad55d96616155d387d5a45b7d26a421 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 30 Oct 2024 13:15:55 -0400 Subject: [PATCH 02/17] =?UTF-8?q?fix:=20=E2=9C=A8=20WriteImage=20works.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Need to do proper Zarr metadata setup. --- pyproject.toml | 3 ++- src/cellmap_data/image.py | 33 ++++++++++++++++++++------------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ddfc63e..3342689 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,9 @@ dependencies = [ "pydantic_ome_ngff", "xarray_ome_ngff", "tensorstore", + "xarray-tensorstore @ git+https://github.com/rhoadesScholar/xarray-tensorstore.git", # "xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git", - "xarray-tensorstore", + # "xarray-tensorstore", "universal_pathlib>=0.2.0", "fsspec[s3,http]", "cellpose", diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 15e863b..4613dad 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -863,17 +863,14 @@ def align_coords( """Aligns the given coordinates to the image's coordinates.""" aligned_coords = {} for c in self.axes: - if c in coords: - # Align each coorinate for the axis to the nearest image's coordinates - - aligned_coords[c] = ( - coords[c][0], - np.linspace( - self.offset[c], - self.offset[c] + self.world_shape[c], - len(coords[c][0]), - ), - ) + # Find the nearest coordinate in the image's actual coordinate grid + aligned_coords[c] = np.array( + self.array.coords[c][ + np.abs(np.array(self.array.coords[c])[:, None] - coords[c]).argmin( + axis=0 + ) + ] + ).squeeze() return aligned_coords def __setitem__( @@ -884,7 +881,7 @@ def __setitem__( """Writes the given data to the image at the given center or coordinates (in world units).""" # Find vectors of coordinates in world space to write data to if necessary if isinstance(list(coords.values())[0], int | float): - center = coords + center = self.align_coords(coords) coords = {} for c in self.axes: coords[c] = np.linspace( # type: ignore @@ -892,9 +889,19 @@ def __setitem__( center[c] + self.write_world_shape[c] / 2, # type: ignore self.write_voxel_shape[c], ) + coords = self.align_coords(coords) + else: + # coords = self.align_coords(coords) + # print( + # "Warning, setting data to specific coordinates is experimental and may be slow." + # ) + # TODO: Add support for writing to full coordinates (not just center) + raise NotImplementedError( + "Writing to specific coordinates is not yet implemented." 
+ ) # if isinstance(data, torch.Tensor): # data = data.cpu().numpy() - self.array.loc[coords] = data + self.array.loc[coords] = data.cpu().squeeze().numpy().astype(self.dtype) def __repr__(self) -> str: """Returns a string representation of the ImageWriter object.""" From bd0ce55facee7f496d38f6e16e3d288d8bd34b4b Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 30 Oct 2024 15:00:27 -0400 Subject: [PATCH 03/17] =?UTF-8?q?feat:=20=E2=9C=A8=20Allow=20write=20to=20?= =?UTF-8?q?array=20with=20scalar?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cellmap_data/image.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 4613dad..c4a2ed4 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -780,6 +780,17 @@ def array(self) -> xarray.DataArray: array = array_future.result() data = xt._TensorStoreAdapter(array) self._array = xarray.DataArray(data=data, coords=self.full_coords) + + # Set the metadata for the Zarr array + ds = zarr.open_array(self.path) + for key, value in self.metadata.items(): + ds.attrs[key] = value + ds.attrs["_ARRAY_DIMENSIONS"] = self.metadata["axes"] + ds.attrs["dimension_units"] = [ + f"{s} {u}" + for s, u in zip(self.metadata["voxel_size"], self.metadata["units"]) + ] + return self._array @property @@ -899,9 +910,10 @@ def __setitem__( raise NotImplementedError( "Writing to specific coordinates is not yet implemented." ) - # if isinstance(data, torch.Tensor): - # data = data.cpu().numpy() - self.array.loc[coords] = data.cpu().squeeze().numpy().astype(self.dtype) + if isinstance(data, torch.Tensor): + data = data.cpu() + + self.array.loc[coords] = np.array(data).squeeze().astype(self.dtype) def __repr__(self) -> str: """Returns a string representation of the ImageWriter object.""" From 7bf9e61d274a75a469776d798c3d4a8a95a0b63f Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 30 Oct 2024 15:59:04 -0400 Subject: [PATCH 04/17] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Fix=20bug=20where?= =?UTF-8?q?=20center=20slice=20isn't=20written.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cellmap_data/image.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index c4a2ed4..945490b 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -764,7 +764,6 @@ def array(self) -> xarray.DataArray: # "metadata": self.metadata, # "transaction": tensorstore.Transaction(atomic=True), "context": self.context, - # "dimension_units": ["nm" if c != "c" else "" for c in self.axes], } array_future = tensorstore.open( spec, @@ -892,7 +891,7 @@ def __setitem__( """Writes the given data to the image at the given center or coordinates (in world units).""" # Find vectors of coordinates in world space to write data to if necessary if isinstance(list(coords.values())[0], int | float): - center = self.align_coords(coords) + center = coords coords = {} for c in self.axes: coords[c] = np.linspace( # type: ignore From 77fe16c7a840ea795e341fbbf19e880cbb544f8b Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 31 Oct 2024 11:23:55 -0400 Subject: [PATCH 05/17] =?UTF-8?q?feat:=20=F0=9F=9A=91=EF=B8=8F=20Write=20m?= =?UTF-8?q?ultiscale=20image=20metadata=20with=20ImageWriter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cellmap_data/image.py | 78 
++++++++++++++++-------- src/cellmap_data/utils/__init__.py | 6 ++ src/cellmap_data/utils/metadata.py | 97 ++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 25 deletions(-) create mode 100644 src/cellmap_data/utils/metadata.py diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 945490b..51cfbb2 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -1,27 +1,25 @@ -from functools import singledispatch import os -from types import SimpleNamespace from typing import Any, Callable, Mapping, Optional, Sequence -import torch -from upath import UPath -from xarray_ome_ngff.v04.multiscale import coords_from_transforms -from pydantic_ome_ngff.v04.multiscale import GroupAttrs, MultiscaleMetadata +import numpy as np +import tensorstore +import torch +import xarray +import xarray_tensorstore as xt +import zarr from pydantic_ome_ngff.v04.axis import Axis +from pydantic_ome_ngff.v04.multiscale import GroupAttrs, MultiscaleMetadata from pydantic_ome_ngff.v04.transform import ( Scale, Translation, VectorScale, VectorTranslation, ) - -import xarray -import tensorstore -import xarray_tensorstore as xt -import numpy as np -import zarr - from scipy.spatial.transform import Rotation as rot +from upath import UPath +from xarray_ome_ngff.v04.multiscale import coords_from_transforms + +from cellmap_data.utils import create_multiscale_metadata class CellMapImage: @@ -663,6 +661,7 @@ def __init__( scale: Mapping[str, float] | Sequence[float], bounding_box: Mapping[str, list[float]], write_voxel_shape: Mapping[str, int] | Sequence[int], + scale_level: int = 0, axis_order: str = "zyx", context: Optional[tensorstore.Context] = None, overwrite: bool = False, @@ -672,8 +671,9 @@ def __init__( """Initializes an ImageWriter object. Args: - path (str): The path to the image file. + path (str): The path to the base folder of the multiscale image file. label_class (str): The label class of the image. + scale_level (int): The multiscale level of the image. Defaults to 0. scale (Mapping[str, float]): The scale of the image in physical space. bounding_box (Mapping[str, list[float]]): The total region of interest for the image in world units. Example: {"x": [12.0, 102.0], "y": [12.0, 102.0], "z": [12.0, 102.0]}. write_voxel_shape (Mapping[str, int]): The shape of data written to the image in voxels. @@ -683,7 +683,8 @@ def __init__( dtype (np.dtype, optional): The data type of the image. Defaults to np.float32. fill_value (float | int, optional): The value to fill the empty image with before values are written. Defaults to 0. 
""" - self.path = path + self.base_path = path + self.path = UPath(path) / f"s{scale_level}" self.label_class = label_class if isinstance(scale, Sequence): if len(axis_order) > len(scale): @@ -702,6 +703,7 @@ def __init__( c: write_voxel_shape[c] * scale[c] for c in axis_order } self.axes = axis_order[: len(write_voxel_shape)] + self.scale_level = scale_level self.context = context self.overwrite = overwrite self.dtype = dtype @@ -714,7 +716,7 @@ def __init__( "axes": dims, "voxel_size": list(self.scale.values()), "shape": list(self.shape.values()), - "units": ["nm"] * len(write_voxel_shape), + "units": "nanometer", "chunk_shape": list(write_voxel_shape.values()), } # array = xarray.DataArray( @@ -737,11 +739,34 @@ def array(self) -> xarray.DataArray: try: return self._array except AttributeError: + # Write multi-scale metadata + os.makedirs(UPath(self.base_path), exist_ok=True) + # Add .zgroup files + group_path = str(self.base_path).split(".zarr")[0] + ".zarr" + # print(group_path) + for group in [""] + list( + UPath(str(self.base_path).split(".zarr")[-1]).parts + )[1:]: + group_path = UPath(group_path) / group + # print(group_path) + with open(group_path / ".zgroup", "w") as f: + f.write('{"zarr_format": 2}') + create_multiscale_metadata( + ds_name=str(self.base_path), + voxel_size=self.metadata["voxel_size"], + translation=self.metadata["offset"], + units=self.metadata["units"], + axes=self.metadata["axes"], + base_scale_level=self.scale_level, + levels_to_add=0, + out_path=str(UPath(self.base_path) / ".zattrs"), + ) + # Construct an xarray with Tensorstore backend # spec = xt._zarr_spec_from_path(self.path) spec = { "driver": "zarr", - "kvstore": {"driver": "file", "path": self.path}, + "kvstore": {"driver": "file", "path": str(self.path)}, # "transform": { # "input_labels": self.metadata["axes"], # # "scale": self.metadata["voxel_size"], @@ -780,15 +805,18 @@ def array(self) -> xarray.DataArray: data = xt._TensorStoreAdapter(array) self._array = xarray.DataArray(data=data, coords=self.full_coords) + # Add .zattrs file + with open(UPath(self.path) / ".zattrs", "w") as f: + f.write("{}") # Set the metadata for the Zarr array - ds = zarr.open_array(self.path) - for key, value in self.metadata.items(): - ds.attrs[key] = value - ds.attrs["_ARRAY_DIMENSIONS"] = self.metadata["axes"] - ds.attrs["dimension_units"] = [ - f"{s} {u}" - for s, u in zip(self.metadata["voxel_size"], self.metadata["units"]) - ] + # ds = zarr.open_array(self.path) + # for key, value in self.metadata.items(): + # ds.attrs[key] = value + # ds.attrs["_ARRAY_DIMENSIONS"] = self.metadata["axes"] + # ds.attrs["dimension_units"] = [ + # f"{s} {u}" + # for s, u in zip(self.metadata["voxel_size"], self.metadata["units"]) + # ] return self._array diff --git a/src/cellmap_data/utils/__init__.py b/src/cellmap_data/utils/__init__.py index 57051d1..f53ebd6 100644 --- a/src/cellmap_data/utils/__init__.py +++ b/src/cellmap_data/utils/__init__.py @@ -1,2 +1,8 @@ from .figs import get_image_grid, get_image_dict, get_image_grid_numpy from .dtype import torch_max_value +from .metadata import ( + create_multiscale_metadata, + add_multiscale_metadata_levels, + generate_base_multiscales_metadata, + write_metadata, +) diff --git a/src/cellmap_data/utils/metadata.py b/src/cellmap_data/utils/metadata.py new file mode 100644 index 0000000..97c674f --- /dev/null +++ b/src/cellmap_data/utils/metadata.py @@ -0,0 +1,97 @@ +import json +from typing import Optional + + +def generate_base_multiscales_metadata( + ds_name: str, + scale_level: int, + 
voxel_size: list, + translation: list, + units: str, + axes: list, +): + if ".zarr" in ds_name: + ds_name = ds_name.split(".zarr")[-1] + z_attrs: dict = {"multiscales": [{}]} + z_attrs["multiscales"][0]["axes"] = [ + {"name": axis, "type": "space", "unit": units} for axis in axes + ] + z_attrs["multiscales"][0]["coordinateTransformations"] = [ + {"scale": [1.0, 1.0, 1.0], "type": "scale"} + ] + z_attrs["multiscales"][0]["datasets"] = [ + { + "coordinateTransformations": [ + {"scale": voxel_size, "type": "scale"}, + {"translation": translation, "type": "translation"}, + ], + "path": f"s{scale_level}", + } + ] + + z_attrs["multiscales"][0]["name"] = ds_name + z_attrs["multiscales"][0]["version"] = "0.4" + + return z_attrs + + +def add_multiscale_metadata_levels(multsc, base_scale_level, levels_to_add): + # store original array in a new .zarr file as an arr_name scale + z_attrs = multsc + # print(z_attrs) + base_scale = z_attrs["multiscales"][0]["datasets"][0]["coordinateTransformations"][ + 0 + ]["scale"] + base_trans = z_attrs["multiscales"][0]["datasets"][0]["coordinateTransformations"][ + 1 + ]["translation"] + for level in range(base_scale_level, base_scale_level + levels_to_add): + # print(f"{level=}") + + # break the slices up into batches, to make things easier for the dask scheduler + sn = [dim * pow(2, level) for dim in base_scale] + trn = [ + (dim * (pow(2, level - 1) - 0.5)) + tr + for (dim, tr) in zip(base_scale, base_trans) + ] + + z_attrs["multiscales"][0]["datasets"].append( + { + "coordinateTransformations": [ + {"type": "scale", "scale": sn}, + {"type": "translation", "translation": trn}, + ], + "path": f"s{level + 1}", + } + ) + + return z_attrs + + +def create_multiscale_metadata( + ds_name: str, + voxel_size: list, + translation: list, + units: str, + axes: list, + base_scale_level: int = 0, + levels_to_add: int = 0, + out_path: Optional[str] = None, +): + z_attrs = generate_base_multiscales_metadata( + ds_name, base_scale_level, voxel_size, translation, units, axes + ) + if levels_to_add > 0: + z_attrs = add_multiscale_metadata_levels( + z_attrs, base_scale_level, levels_to_add + ) + + if out_path is not None: + write_metadata(z_attrs, out_path) + else: + return z_attrs + + +def write_metadata(z_attrs, out_path): + with open(out_path, "w") as f: + f.write(json.dumps(z_attrs, indent=4)) From 3773ce623133c3a201a2add4e8023b5b45e4eaf3 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 31 Oct 2024 15:42:42 -0400 Subject: [PATCH 06/17] WIP: DatasetWriter --- src/cellmap_data/__init__.py | 1 + src/cellmap_data/dataloader.py | 2 +- src/cellmap_data/dataset_writer.py | 452 +++++++++++++++++++++++++++++ 3 files changed, 454 insertions(+), 1 deletion(-) create mode 100644 src/cellmap_data/dataset_writer.py diff --git a/src/cellmap_data/__init__.py b/src/cellmap_data/__init__.py index abf1ee5..9365857 100644 --- a/src/cellmap_data/__init__.py +++ b/src/cellmap_data/__init__.py @@ -19,6 +19,7 @@ from .dataloader import CellMapDataLoader from .datasplit import CellMapDataSplit from .dataset import CellMapDataset +from .dataset_writer import CellMapDatasetWriter from .image import CellMapImage, EmptyImage, ImageWriter from .subdataset import CellMapSubset from . 
import transforms
diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py
index eaffe6f..0b105db 100644
--- a/src/cellmap_data/dataloader.py
+++ b/src/cellmap_data/dataloader.py
@@ -113,7 +113,7 @@ def refresh(self):
 
     def collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
         """Combine a list of dictionaries from different sources into a single dictionary for output."""
-        outputs: dict[str, torch.Tensor] = {}
+        outputs = {}
         for b in batch:
             for key, value in b.items():
                 if key not in outputs:
diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py
new file mode 100644
index 0000000..75e606a
--- /dev/null
+++ b/src/cellmap_data/dataset_writer.py
@@ -0,0 +1,452 @@
+# %%
+import os
+from typing import Any, Callable, Mapping, Sequence, Optional
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import tensorstore
+from upath import UPath
+from .image import CellMapImage, EmptyImage, ImageWriter
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def split_target_path(path: str) -> tuple[str, list[str]]:
+    """Splits a path to groundtruth data into the main path string, and the classes supplied for it."""
+    try:
+        path_prefix, path_rem = path.split("[")
+        classes, path_suffix = path_rem.split("]")
+        classes = classes.split(",")
+        path_string = path_prefix + "{label}" + path_suffix
+    except ValueError:
+        path_string = path
+        classes = [path.split(os.path.sep)[-1]]
+    return path_string, classes
+
+
+# %%
+class CellMapDatasetWriter(Dataset):
+    """
+    This class is used to write a dataset to disk in a format that can be read by the CellMapDataset class. It is useful, for instance, for writing predictions from a model to disk.
+    """
+
+    def __init__(
+        self,
+        raw_path: str,  # TODO: Switch "raw_path" to "input_path"
+        target_path: str,
+        classes: Sequence[str],
+        input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]],
+        target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]],
+        target_bounds: Mapping[str, Mapping[str, list[float]]],
+        raw_value_transforms: Optional[Callable] = None,
+        axis_order: str = "zyx",
+        context: Optional[tensorstore.Context] = None,  # type: ignore
+        rng: Optional[torch.Generator] = None,
+        empty_value: float | int = 0,
+        overwrite: bool = False,
+    ) -> None:
+        """Initializes the CellMapDatasetWriter.
+
+        Args:
+
+            raw_path (str): The full path to the raw data zarr, excluding the multiscale level.
+            target_path (str): The full path to the ground truth data zarr, excluding the multiscale level and the class name.
+            classes (Sequence[str]): The classes in the dataset.
+            input_arrays (Mapping[str, Mapping[str, Sequence[int | float]]]): The input arrays to return for processing. The dictionary should have the following structure::
+
+                {
+                    "array_name": {
+                        "shape": tuple[int],
+                        "scale": Sequence[float],
+
+                        and optionally:
+                        "scale_level": int,
+                    },
+                    ...
+                }
+
+            where 'array_name' is the name of the array, 'shape' is the shape of the array in voxels, and 'scale' is the scale of the array in world units. The 'scale_level' is the multiscale level to use for the array, defaulting to 0 if not supplied.
+            target_arrays (Mapping[str, Mapping[str, Sequence[int | float]]]): The target arrays to write to disk, with format matching that for input_arrays.
+            target_bounds (Mapping[str, Mapping[str, list[float]]]): The bounding boxes for each target array, in world units. Example: {"array_1": {"x": [12.0, 102.0], "y": [12.0, 102.0], "z": [12.0, 102.0]}}. 
+ raw_value_transforms (Optional[Callable]): The value transforms to apply to the raw data. + axis_order (str): The order of the axes in the data. + context (Optional[tensorstore.Context]): The context to use for the tensorstore. + rng (Optional[torch.Generator]): The random number generator to use. + empty_value (float | int): The value to use for empty data in an array. + overwrite (bool): Whether to overwrite existing data. + """ + self.raw_path = raw_path + self.target_path = target_path + self.classes = classes + self.input_arrays = input_arrays + self.target_arrays = target_arrays + self.target_bounds = target_bounds + self.raw_value_transforms = raw_value_transforms + self.axis_order = axis_order + self.context = context + self._rng = rng + self.empty_value = empty_value + self.overwrite = overwrite + self._current_center = None + self.input_sources: dict[str, CellMapImage] = {} + for array_name, array_info in self.input_arrays.items(): + self.input_sources[array_name] = CellMapImage( + self.raw_path, + "raw", + array_info["scale"], + array_info["shape"], # type: ignore + value_transform=self.raw_value_transforms, + context=self.context, + pad=True, + pad_value=0, # inputs to the network should be padded with 0 + interpolation="linear", + ) + self.target_array_writers: dict[str, dict[str, ImageWriter]] = {} + for array_name, array_info in self.target_arrays.items(): + self.target_array_writers[array_name] = self.get_target_array_writer( + array_name, array_info + ) + + @property + def center(self) -> Mapping[str, float] | None: + """Returns the center of the dataset in world units.""" + try: + return self._center + except AttributeError: + if self.bounding_box is None: + self._center = None + else: + center = {} + for c, (start, stop) in self.bounding_box.items(): + center[c] = start + (stop - start) / 2 + self._center = center + return self._center + + # TODO + @property + def largest_voxel_sizes(self) -> Mapping[str, float]: + """Returns the largest voxel size of the dataset.""" + try: + return self._largest_voxel_sizes + except AttributeError: + largest_voxel_size = {c: 0.0 for c in self.axis_order} + for source in list(self.input_sources.values()) + list( + self.target_sources.values() + ): + if isinstance(source, dict): + for _, source in source.items(): + if not hasattr(source, "scale") or source.scale is None: # type: ignore + continue + for c, size in source.scale.items(): # type: ignore + largest_voxel_size[c] = max(largest_voxel_size[c], size) + else: + if not hasattr(source, "scale") or source.scale is None: + continue + for c, size in source.scale.items(): + largest_voxel_size[c] = max(largest_voxel_size[c], size) + self._largest_voxel_sizes = largest_voxel_size + + return self._largest_voxel_sizes + + # TODO + @property + def bounding_box(self) -> Mapping[str, list[float]]: + """Returns the bounding box of the dataset.""" + try: + return self._bounding_box + except AttributeError: + bounding_box = None + for source in list(self.input_sources.values()) + list( + self.target_sources.values() + ): + if isinstance(source, dict): + for source in source.values(): + if not hasattr(source, "bounding_box"): + continue + bounding_box = self._get_box_intersection( + source.bounding_box, bounding_box # type: ignore + ) + else: + if not hasattr(source, "bounding_box"): + continue + bounding_box = self._get_box_intersection( + source.bounding_box, bounding_box + ) + if bounding_box is None: + logger.warning( + "Bounding box is None. 
This may result in errors when trying to sample from the dataset." + ) + bounding_box = {c: [-np.inf, np.inf] for c in self.axis_order} + self._bounding_box = bounding_box + return self._bounding_box + + @property + def bounding_box_shape(self) -> Mapping[str, int]: + """Returns the shape of the bounding box of the dataset in voxels of the largest voxel size requested.""" + try: + return self._bounding_box_shape + except AttributeError: + self._bounding_box_shape = self._get_box_shape(self.bounding_box) + return self._bounding_box_shape + + # TODO + @property + def sampling_box(self) -> Mapping[str, list[float]]: + """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box).""" + try: + return self._sampling_box + except AttributeError: + sampling_box = None + for source in list(self.input_sources.values()) + list( + self.target_sources.values() + ): + if isinstance(source, dict): + for source in source.values(): + if not hasattr(source, "sampling_box"): + continue + sampling_box = self._get_box_intersection( + source.sampling_box, sampling_box # type: ignore + ) + else: + if not hasattr(source, "sampling_box"): + continue + sampling_box = self._get_box_intersection( + source.sampling_box, sampling_box + ) + if sampling_box is None: + logger.warning( + "Sampling box is None. This may result in errors when trying to sample from the dataset." + ) + sampling_box = {c: [-np.inf, np.inf] for c in self.axis_order} + self._sampling_box = sampling_box + return self._sampling_box + + @property + def sampling_box_shape(self) -> dict[str, int]: + """Returns the shape of the sampling box of the dataset in voxels of the largest voxel size requested.""" + try: + return self._sampling_box_shape + except AttributeError: + self._sampling_box_shape = self._get_box_shape(self.sampling_box) + if self.pad: + for c, size in self._sampling_box_shape.items(): + if size <= 0: + logger.warning( + f"Sampling box shape is <= 0 for axis {c} with size {size}. 
Setting to 1" + ) + self._sampling_box_shape[c] = 1 + return self._sampling_box_shape + + @property + def size(self) -> int: + """Returns the size of the dataset in voxels of the largest voxel size requested.""" + try: + return self._size + except AttributeError: + self._size = np.prod( + [stop - start for start, stop in self.bounding_box.values()] + ).astype(int) + return self._size + + # TODO: Switch this to write_indices + @property + def validation_indices(self) -> Sequence[int]: + """Returns the indices of the dataset that will produce non-overlapping tiles for use in validation, based on the largest requested voxel size.""" + try: + return self._validation_indices + except AttributeError: + chunk_size = {} + for c, size in self.bounding_box_shape.items(): + chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int) + self._validation_indices = self.get_indices(chunk_size) + return self._validation_indices + + @property + def device(self) -> torch.device: + """Returns the device for the dataset.""" + try: + return self._device + except AttributeError: + if torch.cuda.is_available(): + self._device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self._device = torch.device("mps") + else: + self._device = torch.device("cpu") + self.to(self._device) + return self._device + + def __len__(self) -> int: + """Returns the length of the dataset, determined by the number of coordinates that could be sampled as the center for an array request.""" + try: + return self._len + except AttributeError: + size = np.prod([self.sampling_box_shape[c] for c in self.axis_order]) + self._len = int(size) + return self._len + + def get_center(self, idx: int) -> dict[str, float]: + try: + center = np.unravel_index( + idx, [self.sampling_box_shape[c] for c in self.axis_order] + ) + except ValueError: + logger.error( + f"Index {idx} out of bounds for dataset {self} of length {len(self)}" + ) + logger.warning(f"Returning closest index in bounds") + # TODO: This is a hacky temprorary fix. Need to figure out why this is happening + center = [self.sampling_box_shape[c] - 1 for c in self.axis_order] + center = { + c: float(center[i] * self.largest_voxel_sizes[c] + self.sampling_box[c][0]) + for i, c in enumerate(self.axis_order) + } + return center + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + """Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index.""" + + self._current_idx = idx + self._current_center = self.get_center(idx) + outputs = {} + for array_name in self.input_arrays.keys(): + array = self.input_sources[array_name][center] # type: ignore + # TODO: Assumes 1 channel (i.e. grayscale) + if array.shape[0] != 1: + outputs[array_name] = array[None, ...] + else: + outputs[array_name] = array + outputs["idx"] = idx + + return outputs + + def __setitem__( + self, idx: int, arrays: dict[str, torch.Tensor | np.ndarray] + ) -> None: + """ + Writes the values for the given arrays at the given index. + + Args: + idx (int): The index to write the arrays to. + arrays (dict[str, torch.Tensor | np.ndarray]): The arrays to write to disk, with data either split by label class into a dictionary, or divided by class along the channel dimension of an array/tensor. The dictionary should have the following structure:: + + { + "array_name": torch.Tensor | np.ndarray | dict[str, torch.Tensor | np.ndarray], + ... 
+                }
+        """
+        self._current_idx = idx
+        self._current_center = center = self.get_center(idx)
+        for array_name, array in arrays.items():
+            if isinstance(array, dict):
+                for label, array in array.items():
+                    self.target_array_writers[array_name][label][center] = array
+            else:
+                for c, label in enumerate(self.classes):
+                    self.target_array_writers[array_name][label][center] = array[c]
+
+    def __repr__(self) -> str:
+        """Returns a string representation of the dataset."""
+        return f"CellMapDatasetWriter(\n\tRaw path: {self.raw_path}\n\tOutput path(s): {self.target_path}\n\tClasses: {self.classes})"
+
+    def get_target_array_writer(
+        self, array_name: str, array_info: Mapping[str, Sequence[int | float]]
+    ) -> dict[str, ImageWriter]:
+        """Returns a dictionary of ImageWriter for the target images (per class) for a given array."""
+        target_image_writers = {}
+        for label in self.classes:
+            target_image_writers[label] = self.get_image_writer(
+                array_name, label, array_info
+            )
+
+        return target_image_writers
+
+    def get_image_writer(
+        self,
+        array_name: str,
+        label: str,
+        array_info: Mapping[str, Sequence[int | float] | int],
+    ) -> ImageWriter:
+        return ImageWriter(
+            path=str(UPath(self.target_path) / label),
+            label_class=label,
+            scale=array_info["scale"],  # type: ignore
+            bounding_box=self.target_bounds[array_name],
+            write_voxel_shape=array_info["shape"],  # type: ignore
+            scale_level=array_info.get("scale_level", 0),  # type: ignore
+            axis_order=self.axis_order,
+            context=self.context,
+            fill_value=self.empty_value,
+            overwrite=self.overwrite,
+        )
+
+    def _get_box_shape(self, source_box: Mapping[str, list[float]]) -> dict[str, int]:
+        """Returns the shape of the box in voxels of the largest voxel size requested."""
+        box_shape = {}
+        for c, (start, stop) in source_box.items():
+            size = stop - start
+            size /= self.largest_voxel_sizes[c]
+            box_shape[c] = int(np.floor(size))
+        return box_shape
+
+    def _get_box_intersection(
+        self,
+        source_box: Mapping[str, list[float]] | None,
+        current_box: Mapping[str, list[float]] | None,
+    ) -> Mapping[str, list[float]] | None:
+        """Returns the intersection of the source and current boxes."""
+        if source_box is not None:
+            if current_box is None:
+                return source_box
+            for c, (start, stop) in source_box.items():
+                assert stop > start, f"Invalid box: {start} to {stop}"
+                current_box[c][0] = max(current_box[c][0], start)
+                current_box[c][1] = min(current_box[c][1], stop)
+        return current_box
+
+    def verify(self) -> bool:
+        """Verifies that the dataset is valid to draw samples from."""
+        # TODO: make more robust
+        try:
+            return len(self) > 0
+        except Exception as e:
+            print(f"Error: {e}")
+            return False
+
+    def get_indices(self, chunk_size: Mapping[str, int]) -> Sequence[int]:
+        """Returns the indices of the dataset that will tile the dataset according to the chunk_size."""
+        # TODO: ADD TEST
+        # Get padding per axis
+        indices_dict = {}
+        for c, size in chunk_size.items():
+            indices_dict[c] = np.arange(0, self.sampling_box_shape[c], size, dtype=int)
+
+        indices = []
+        # Generate linear indices by unraveling all combinations of axes indices
+        for i in np.ndindex(*[len(indices_dict[c]) for c in self.axis_order]):
+            index = [indices_dict[c][j] for c, j in zip(self.axis_order, i)]
+            index = np.ravel_multi_index(index, list(self.sampling_box_shape.values()))
+            indices.append(index)
+        return indices
+
+    def to(self, device: str | torch.device) -> "CellMapDatasetWriter":
+        """Sets the device for the dataset."""
+        self._device = torch.device(device)
+        for source in 
self.input_sources.values():
+            if isinstance(source, dict):
+                for source in source.values():
+                    if not hasattr(source, "to"):
+                        continue
+                    source.to(device)
+            else:
+                if not hasattr(source, "to"):
+                    continue
+                source.to(device)
+        return self
+
+    def set_raw_value_transforms(self, transforms: Callable) -> None:
+        """Sets the raw value transforms for the dataset."""
+        self.raw_value_transforms = transforms
+        for source in self.input_sources.values():
+            source.value_transform = transforms

From 7619c173ff74ea7ce8885af1f429e61984e65dfd Mon Sep 17 00:00:00 2001
From: rhoadesScholar
Date: Fri, 1 Nov 2024 10:33:47 -0400
Subject: [PATCH 07/17] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Fix=20coordinate=20?=
 =?UTF-8?q?alignment=20in=20image=20writer.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/cellmap_data/dataset_writer.py | 107 +++++++++++++----------------
 src/cellmap_data/image.py          |  35 +++++++---
 2 files changed, 73 insertions(+), 69 deletions(-)

diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py
index 75e606a..a1f5ea4 100644
--- a/src/cellmap_data/dataset_writer.py
+++ b/src/cellmap_data/dataset_writer.py
@@ -123,56 +123,40 @@ def center(self) -> Mapping[str, float] | None:
             self._center = center
         return self._center
 
-    # TODO
     @property
-    def largest_voxel_sizes(self) -> Mapping[str, float]:
-        """Returns the largest voxel size of the dataset."""
+    def smallest_voxel_sizes(self) -> Mapping[str, float]:
+        """Returns the smallest voxel size of the dataset."""
         try:
-            return self._largest_voxel_sizes
+            return self._smallest_voxel_sizes
         except AttributeError:
-            largest_voxel_size = {c: 0.0 for c in self.axis_order}
+            smallest_voxel_size = {c: np.inf for c in self.axis_order}
             for source in list(self.input_sources.values()) + list(
-                self.target_sources.values()
+                self.target_array_writers.values()
             ):
                 if isinstance(source, dict):
                     for _, source in source.items():
                         if not hasattr(source, "scale") or source.scale is None:  # type: ignore
                             continue
                         for c, size in source.scale.items():  # type: ignore
-                            largest_voxel_size[c] = max(largest_voxel_size[c], size)
+                            smallest_voxel_size[c] = min(smallest_voxel_size[c], size)
                 else:
                     if not hasattr(source, "scale") or source.scale is None:
                         continue
                     for c, size in source.scale.items():
-                        largest_voxel_size[c] = max(largest_voxel_size[c], size)
-            self._largest_voxel_sizes = largest_voxel_size
+                        smallest_voxel_size[c] = min(smallest_voxel_size[c], size)
+            self._smallest_voxel_sizes = smallest_voxel_size
 
-        return self._largest_voxel_sizes
+        return self._smallest_voxel_sizes
 
-    # TODO
     @property
     def bounding_box(self) -> Mapping[str, list[float]]:
-        """Returns the bounding box of the dataset."""
+        """Returns the bounding box inclusive of all the target images."""
         try:
             return self._bounding_box
         except AttributeError:
             bounding_box = None
-            for source in list(self.input_sources.values()) + list(
-                self.target_sources.values()
-            ):
-                if isinstance(source, dict):
-                    for source in source.values():
-                        if not hasattr(source, "bounding_box"):
-                            continue
-                        bounding_box = self._get_box_intersection(
-                            source.bounding_box, bounding_box  # type: ignore
-                        )
-                else:
-                    if not hasattr(source, "bounding_box"):
-                        continue
-                    bounding_box = self._get_box_intersection(
-                        source.bounding_box, bounding_box
-                    )
+            for current_box in self.target_bounds.values():
+                bounding_box = self._get_box_union(current_box, bounding_box)
             if bounding_box is None:
                 logger.warning(
                     "Bounding box is None. This may result in errors when trying to sample from the dataset." 
@@ -183,7 +167,7 @@ def bounding_box(self) -> Mapping[str, list[float]]: @property def bounding_box_shape(self) -> Mapping[str, int]: - """Returns the shape of the bounding box of the dataset in voxels of the largest voxel size requested.""" + """Returns the shape of the bounding box of the dataset in voxels of the smallest voxel size requested.""" try: return self._bounding_box_shape except AttributeError: @@ -199,7 +183,7 @@ def sampling_box(self) -> Mapping[str, list[float]]: except AttributeError: sampling_box = None for source in list(self.input_sources.values()) + list( - self.target_sources.values() + self.target_array_writers.values() ): if isinstance(source, dict): for source in source.values(): @@ -224,23 +208,22 @@ def sampling_box(self) -> Mapping[str, list[float]]: @property def sampling_box_shape(self) -> dict[str, int]: - """Returns the shape of the sampling box of the dataset in voxels of the largest voxel size requested.""" + """Returns the shape of the sampling box of the dataset in voxels of the smallest voxel size requested.""" try: return self._sampling_box_shape except AttributeError: self._sampling_box_shape = self._get_box_shape(self.sampling_box) - if self.pad: - for c, size in self._sampling_box_shape.items(): - if size <= 0: - logger.warning( - f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1" - ) - self._sampling_box_shape[c] = 1 + for c, size in self._sampling_box_shape.items(): + if size <= 0: + logger.warning( + f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1" + ) + self._sampling_box_shape[c] = 1 return self._sampling_box_shape @property def size(self) -> int: - """Returns the size of the dataset in voxels of the largest voxel size requested.""" + """Returns the size of the dataset in voxels of the smallest voxel size requested.""" try: return self._size except AttributeError: @@ -251,16 +234,16 @@ def size(self) -> int: # TODO: Switch this to write_indices @property - def validation_indices(self) -> Sequence[int]: - """Returns the indices of the dataset that will produce non-overlapping tiles for use in validation, based on the largest requested voxel size.""" + def writer_indices(self) -> Sequence[int]: + """Returns the indices of the dataset that will produce non-overlapping tiles for use in writer, based on the smallest requested voxel size.""" try: - return self._validation_indices + return self._writer_indices except AttributeError: chunk_size = {} for c, size in self.bounding_box_shape.items(): chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int) - self._validation_indices = self.get_indices(chunk_size) - return self._validation_indices + self._writer_indices = self.get_indices(chunk_size) + return self._writer_indices @property def device(self) -> torch.device: @@ -299,7 +282,7 @@ def get_center(self, idx: int) -> dict[str, float]: # TODO: This is a hacky temprorary fix. 
Need to figure out why this is happening center = [self.sampling_box_shape[c] - 1 for c in self.axis_order] center = { - c: float(center[i] * self.largest_voxel_sizes[c] + self.sampling_box[c][0]) + c: float(center[i] * self.smallest_voxel_sizes[c] + self.sampling_box[c][0]) for i, c in enumerate(self.axis_order) } return center @@ -322,13 +305,15 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: return outputs def __setitem__( - self, idx: int, arrays: dict[str, torch.Tensor | np.ndarray] + self, + indices: int | torch.Tensor | np.ndarray | Sequence[int], + arrays: dict[str, torch.Tensor | np.ndarray], ) -> None: """ Writes the values for the given arrays at the given index. Args: - idx (int): The index to write the arrays to. + indices (int | torch.Tensor | np.ndarray | Sequence[int]): The index or indices to write the arrays to. arrays (dict[str, torch.Tensor | np.ndarray]): The arrays to write to disk, with data either split by label class into a dictionary, or divided by class along the channel dimension of an array/tensor. The dictionary should have the following structure:: { @@ -336,15 +321,21 @@ def __setitem__( ... } """ - self._current_idx = idx - self._current_center = center = self.get_center(idx) - for array_name, array in arrays.items(): - if isinstance(array, dict): - for label, array in array.items(): - self.target_array_writers[array_name][label][center] = array - else: - for c, label in enumerate(self.classes): - self.target_array_writers[array_name][label][center] = array[c] + if isinstance(indices, int): + indices = [indices] # type: ignore + for idx in indices: # type: ignore + self._current_idx = idx + self._current_center = center = self.get_center(idx) + for array_name, array in arrays.items(): + if isinstance(array, int) or isinstance(array, float): + for c, label in enumerate(self.classes): + self.target_array_writers[array_name][label][center] = array + elif isinstance(array, dict): + for label, array in array.items(): + self.target_array_writers[array_name][label][center] = array + else: + for c, label in enumerate(self.classes): + self.target_array_writers[array_name][label][center] = array[c] def __repr__(self) -> str: """Returns a string representation of the dataset.""" @@ -382,11 +373,11 @@ def get_image_writer( ) def _get_box_shape(self, source_box: Mapping[str, list[float]]) -> dict[str, int]: - """Returns the shape of the box in voxels of the largest voxel size requested.""" + """Returns the shape of the box in voxels of the smallest voxel size requested.""" box_shape = {} for c, (start, stop) in source_box.items(): size = stop - start - size /= self.largest_voxel_sizes[c] + size /= self.smallest_voxel_sizes[c] box_shape[c] = int(np.floor(size)) return box_shape diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 51cfbb2..2d6619e 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -911,23 +911,36 @@ def align_coords( ).squeeze() return aligned_coords + def aligned_coords_from_center(self, center: Mapping[str, float]): + """Returns the aligned coordinates for the given center with linear sequential coordinates aligned to the image's reference frame.""" + coords = {} + for c in self.axes: + # Get number of voxels for the axis + num_voxels = self.write_voxel_shape[c] + + # Get index of closest start voxel to the edge of the write space, based on the center + start_requested = center[c] - self.write_world_shape[c] / 2 + start_aligned_idx = int( + np.abs(self.array.coords[c] - start_requested).argmin() 
+ ) + + # Get the aligned range of coordinates + coords[c] = self.array.coords[c][ + start_aligned_idx : start_aligned_idx + num_voxels + ] + + return coords + def __setitem__( self, coords: Mapping[str, float] | Mapping[str, tuple[Sequence, np.ndarray]], - data: torch.Tensor | np.ndarray, + data: torch.Tensor | np.typing.ArrayLike | float | int, # type: ignore ) -> None: - """Writes the given data to the image at the given center or coordinates (in world units).""" + """Writes the given data to the image at the given center or coordinates (in world units). Supports writing torch.Tensor, numpy.ndarray, and scalar data types.""" # Find vectors of coordinates in world space to write data to if necessary if isinstance(list(coords.values())[0], int | float): center = coords - coords = {} - for c in self.axes: - coords[c] = np.linspace( # type: ignore - center[c] - self.write_world_shape[c] / 2, # type: ignore - center[c] + self.write_world_shape[c] / 2, # type: ignore - self.write_voxel_shape[c], - ) - coords = self.align_coords(coords) + coords = self.aligned_coords_from_center(center) # type: ignore else: # coords = self.align_coords(coords) # print( @@ -944,4 +957,4 @@ def __setitem__( def __repr__(self) -> str: """Returns a string representation of the ImageWriter object.""" - return f"ImageWriter({self.path}: {self.label_class} @ {self.scale.values()})" + return f"ImageWriter({self.path}: {self.label_class} @ {list(self.scale.values())} {self.metadata['units']})" From 65e47f23d45b0a4789386c07573ca17346e8e9f4 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Sun, 3 Nov 2024 22:26:34 -0500 Subject: [PATCH 08/17] =?UTF-8?q?feat:=20=E2=9C=A8=20Add=20negative=20and?= =?UTF-8?q?=20vectorized=20indexing=20of=20datasets.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 5 +- src/cellmap_data/dataloader.py | 5 +- src/cellmap_data/dataset.py | 2 + src/cellmap_data/dataset_writer.py | 146 +++++++++++++++++++---------- src/cellmap_data/image.py | 73 +++++++-------- src/cellmap_data/subdataset.py | 7 +- 6 files changed, 146 insertions(+), 92 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3342689..d7b4a73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "hatchling.build" name = "cellmap-data" description = "Utility for loading CellMap data for machine learning training, utilizing PyTorch, Xarray, TensorStore, and PyDantic." 
readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" license = { text = "BSD 3-Clause License" } authors = [ { email = "rhoadesj@hhmi.org", name = "Jeff Rhoades" }, @@ -17,9 +17,8 @@ classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: BSD License", "Natural Language :: English", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Typing :: Typed", ] dynamic = ["version"] diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py index 0b105db..a4d6653 100644 --- a/src/cellmap_data/dataloader.py +++ b/src/cellmap_data/dataloader.py @@ -3,6 +3,7 @@ from .dataset import CellMapDataset from .multidataset import CellMapMultiDataset from .subdataset import CellMapSubset +from .dataset_writer import CellMapDatasetWriter from typing import Callable, Iterable, Optional, Sequence @@ -30,7 +31,9 @@ class CellMapDataLoader: def __init__( self, - dataset: CellMapMultiDataset | CellMapDataset | CellMapSubset, + dataset: ( + CellMapMultiDataset | CellMapDataset | CellMapSubset | CellMapDatasetWriter + ), classes: Iterable[str], batch_size: int = 1, num_workers: int = 0, diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index e100554..09f7482 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -403,6 +403,8 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: """Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index.""" + idx = np.array(idx) + idx[idx < 0] = len(self) + idx[idx < 0] try: center = np.unravel_index( idx, [self.sampling_box_shape[c] for c in self.axis_order] ) except ValueError: diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index a1f5ea4..67834d6 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -1,12 +1,15 @@ # %% import os -from typing import Any, Callable, Mapping, Sequence, Optional +from typing import Callable, Mapping, Sequence, Optional import numpy as np import torch from torch.utils.data import Dataset import tensorstore from upath import UPath -from .image import CellMapImage, EmptyImage, ImageWriter + +from torch.utils.data import Subset + +from .image import CellMapImage, ImageWriter import logging logger = logging.getLogger(__name__) @@ -89,6 +92,7 @@ def __init__( self.empty_value = empty_value self.overwrite = overwrite self._current_center = None + self._current_idx = None self.input_sources: dict[str, CellMapImage] = {} for array_name, array_info in self.input_arrays.items(): self.input_sources[array_name] = CellMapImage( @@ -155,8 +159,8 @@ def bounding_box(self) -> Mapping[str, list[float]]: return self._bounding_box except AttributeError: bounding_box = None - for bounding_box in self.target_bounds.values(): - ... + for current_box in self.target_bounds.values(): + bounding_box = self._get_box_union(current_box, bounding_box) if bounding_box is None: logger.warning( "Bounding box is None. This may result in errors when trying to sample from the dataset." ) @@ -174,30 +178,20 @@ def bounding_box_shape(self) -> Mapping[str, int]: self._bounding_box_shape = self._get_box_shape(self.bounding_box) return self._bounding_box_shape - # TODO @property def sampling_box(self) -> Mapping[str, list[float]]: - """Returns the sampling box of the dataset (i.e.
where centers can be drawn from and still have full samples drawn from within the bounding box).""" + """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples written to within the bounding box).""" try: return self._sampling_box except AttributeError: sampling_box = None - for source in list(self.input_sources.values()) + list( - self.target_array_writers.values() - ): - if isinstance(source, dict): - for source in source.values(): - if not hasattr(source, "sampling_box"): - continue - sampling_box = self._get_box_intersection( - source.sampling_box, sampling_box # type: ignore - ) - else: - if not hasattr(source, "sampling_box"): - continue - sampling_box = self._get_box_intersection( - source.sampling_box, sampling_box - ) + for array_name, array_info in self.target_arrays.items(): + padding = {c: np.ceil((shape * scale) / 2) for c, shape, scale in zip(self.axis_order, array_info["shape"], array_info["scale"])} # type: ignore + this_box = { + c: [bounds[0] - padding[c], bounds[1] + padding[c]] + for c, bounds in self.target_bounds[array_name].items() + } + sampling_box = self._get_box_union(this_box, sampling_box) if sampling_box is None: logger.warning( "Sampling box is None. This may result in errors when trying to sample from the dataset." ) @@ -232,19 +226,57 @@ def size(self) -> int: ).astype(int) return self._size - # TODO: Switch this to write_indices @property def writer_indices(self) -> Sequence[int]: - """Returns the indices of the dataset that will produce non-overlapping tiles for use in writer, based on the smallest requested voxel size.""" + """Returns the indices of the dataset that will produce non-overlapping tiles for use in writer, based on the smallest requested target array.""" try: return self._writer_indices except AttributeError: - chunk_size = {} - for c, size in self.bounding_box_shape.items(): - chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int) + # Get smallest target array in world units + chunk_box = None + for array_info in self.target_arrays.values(): + this_box = { + c: [0, array_info["shape"][i] * array_info["scale"][i]] + for i, c in enumerate(self.axis_order) + } + chunk_box = self._get_box_intersection(this_box, chunk_box) + + # Convert to voxel units + chunk_size = {c: stop - start for c, (start, stop) in chunk_box.items()} self._writer_indices = self.get_indices(chunk_size) return self._writer_indices + def blocks(self) -> Subset: + """A subset of the dataset, tiling it with non-overlapping blocks.""" + try: + return self._blocks + except AttributeError: + self._blocks = Subset( + self, + self.writer_indices, + ) + return self._blocks + + def loader( + self, + batch_size: int = 1, + num_workers: int = 0, + rng: Optional[torch.Generator] = None, + **kwargs, + ): + """Returns a DataLoader for the dataset.""" + from .dataloader import CellMapDataLoader + + return CellMapDataLoader( + self.blocks, + classes=self.classes, + batch_size=batch_size, + num_workers=num_workers, + is_train=False, + rng=rng, + **kwargs, + ) + @property def device(self) -> torch.device: """Returns the device for the dataset.""" @@ -270,6 +302,8 @@ def __len__(self) -> int: return self._len def get_center(self, idx: int) -> dict[str, float]: + idx = np.array(idx) + idx[idx < 0] = len(self) + idx[idx < 0] try: center = np.unravel_index( idx, [self.sampling_box_shape[c] for c in self.axis_order] ) except ValueError: # TODO: This
is a hacky temporary fix. Need to figure out why this is happening center = [self.sampling_box_shape[c] - 1 for c in self.axis_order] center = { - c: float(center[i] * self.smallest_voxel_sizes[c] + self.sampling_box[c][0]) + c: center[i] * self.smallest_voxel_sizes[c] + self.sampling_box[c][0] for i, c in enumerate(self.axis_order) } return center @@ -294,26 +328,26 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: self._current_center = self.get_center(idx) outputs = {} for array_name in self.input_arrays.keys(): - array = self.input_sources[array_name][center] # type: ignore + array = self.input_sources[array_name][self._current_center] # type: ignore # TODO: Assumes 1 channel (i.e. grayscale) if array.shape[0] != 1: outputs[array_name] = array[None, ...] else: outputs[array_name] = array - outputs["idx"] = idx + outputs["idx"] = torch.tensor(idx) return outputs def __setitem__( self, - indices: int | torch.Tensor | np.ndarray | Sequence[int], + idx: int | torch.Tensor | np.ndarray | Sequence[int], arrays: dict[str, torch.Tensor | np.ndarray], ) -> None: """ Writes the values for the given arrays at the given index. Args: - indices (int | torch.Tensor | np.ndarray | Sequence[int]): The index or indices to write the arrays to. + idx (int | torch.Tensor | np.ndarray | Sequence[int]): The index or indices to write the arrays to. arrays (dict[str, torch.Tensor | np.ndarray]): The arrays to write to disk, with data either split by label class into a dictionary, or divided by class along the channel dimension of an array/tensor. The dictionary should have the following structure:: { @@ -321,21 +355,24 @@ def __setitem__( ... } """ - if isinstance(indices, int): - indices = [indices] # type: ignore - for idx in indices: # type: ignore - self._current_idx = idx - self._current_center = center = self.get_center(idx) - for array_name, array in arrays.items(): - if isinstance(array, int) or isinstance(array, float): - for c, label in enumerate(self.classes): - self.target_array_writers[array_name][label][center] = array - elif isinstance(array, dict): - for label, array in array.items(): - self.target_array_writers[array_name][label][center] = array - else: - for c, label in enumerate(self.classes): - self.target_array_writers[array_name][label][center] = array[c] + self._current_idx = idx + self._current_center = self.get_center(self._current_idx) + for array_name, array in arrays.items(): + if isinstance(array, int) or isinstance(array, float): + for c, label in enumerate(self.classes): + self.target_array_writers[array_name][label][ + self._current_center + ] = array + elif isinstance(array, dict): + for label, label_array in array.items(): + self.target_array_writers[array_name][label][ + self._current_center + ] = label_array + else: + for c, label in enumerate(self.classes): + self.target_array_writers[array_name][label][ + self._current_center + ] = array[c] def __repr__(self) -> str: """Returns a string representation of the dataset.""" @@ -381,6 +418,21 @@ def _get_box_shape(self, source_box: Mapping[str, list[float]]) -> dict[str, int box_shape[c] = int(np.floor(size)) return box_shape + def _get_box_union( + self, + source_box: Mapping[str, list[float]] | None, + current_box: Mapping[str, list[float]] | None, + ) -> Mapping[str, list[float]] | None: + """Returns the union of the source and current boxes.""" + if source_box is not None: + if current_box is None: + return source_box + for c, (start, stop) in source_box.items(): + assert stop > start, f"Invalid box: {start}
to {stop}" + current_box[c][0] = min(current_box[c][0], start) + current_box[c][1] = max(current_box[c][1], stop) + return current_box + def _get_box_intersection( self, source_box: Mapping[str, list[float]] | None, diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 2d6619e..9d34e6a 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -501,17 +501,11 @@ def return_data( method=self.interpolation, # type: ignore ) elif self.pad: - try: - tolerance = self._tolerance - except AttributeError: - self._tolerance = np.ones(coords[self.axes[0]].shape) * np.max( - list(self.scale.values()) - ) - tolerance = self._tolerance data = self.array.reindex( **coords, method="nearest", - tolerance=tolerance, + tolerance=np.ones(coords[self.axes[0]].shape) + * np.max(list(self.scale.values())), fill_value=self.pad_value, ) else: @@ -719,19 +713,6 @@ def __init__( "units": "nanometer", "chunk_shape": list(write_voxel_shape.values()), } - # array = xarray.DataArray( - # self.fill_value, - # dims=self.metadata["axes"], - # coords=self.full_coords, - # attrs=self.metadata, - # name=self.label_class, - # ) - # dataset = array.to_dataset() - # dataset[self.label_class].encoding = {"chunks": self.chunk_shape} - # if overwrite or not UPath(path).exists(): - # dataset.to_zarr(path, mode="w", write_empty_chunks=False, compute=False) - # else: - # dataset.to_zarr(path, mode="a", write_empty_chunks=False, compute=False) @property def array(self) -> xarray.DataArray: @@ -847,7 +828,7 @@ def shape(self) -> Mapping[str, int]: return self._shape except AttributeError: self._shape = { - c: int(self.world_shape[c] // self.scale[c]) for c in self.axes + c: int(np.ceil(self.world_shape[c] / self.scale[c])) for c in self.axes } return self._shape @@ -898,6 +879,7 @@ def full_coords(self) -> tuple[xarray.DataArray, ...]: def align_coords( self, coords: Mapping[str, tuple[Sequence, np.ndarray]] ) -> Mapping[str, tuple[Sequence, np.ndarray]]: + # TODO: Deprecate this function? """Aligns the given coordinates to the image's coordinates.""" aligned_coords = {} for c in self.axes: @@ -915,9 +897,6 @@ def aligned_coords_from_center(self, center: Mapping[str, float]): """Returns the aligned coordinates for the given center with linear sequential coordinates aligned to the image's reference frame.""" coords = {} for c in self.axes: - # Get number of voxels for the axis - num_voxels = self.write_voxel_shape[c] - # Get index of closest start voxel to the edge of the write space, based on the center start_requested = center[c] - self.write_world_shape[c] / 2 start_aligned_idx = int( np.abs(self.array.coords[c] - start_requested).argmin() ) # Get the aligned range of coordinates coords[c] = self.array.coords[c][ - start_aligned_idx : start_aligned_idx + num_voxels + start_aligned_idx : start_aligned_idx + self.write_voxel_shape[c] ] return coords @@ -936,24 +915,40 @@ def __setitem__( self, coords: Mapping[str, float] | Mapping[str, tuple[Sequence, np.ndarray]], data: torch.Tensor | np.typing.ArrayLike | float | int, # type: ignore ) -> None: - """Writes the given data to the image at the given center or coordinates (in world units). Supports writing torch.Tensor, numpy.ndarray, and scalar data types.""" + """Writes the given data to the image at the given center (in world units).
Supports writing torch.Tensor, numpy.ndarray, and scalar data types, including for batches.""" + if not isinstance(data, (int, float)): + if any(data.shape[i] > self.shape[c] for i, c in enumerate(self.axes)): + raise ValueError( + f"Image {self.path} is too small to write data of shape {data.shape}. Image shape is {self.shape}." + ) # Find vectors of coordinates in world space to write data to if necessary if isinstance(list(coords.values())[0], int | float): center = coords coords = self.aligned_coords_from_center(center) # type: ignore - else: - # coords = self.align_coords(coords) - # print( - # "Warning, setting data to specific coordinates is experimental and may be slow." - # ) - # TODO: Add support for writing to full coordinates (not just center) - raise NotImplementedError( - "Writing to specific coordinates is not yet implemented." - ) - if isinstance(data, torch.Tensor): - data = data.cpu() + if isinstance(data, torch.Tensor): + data = data.cpu() + + data = np.array(data).squeeze().astype(self.dtype) - self.array.loc[coords] = np.array(data).squeeze().astype(self.dtype) + try: + self.array.loc[coords] = data + except ValueError as e: + print( + f"Writing to center {center} in image {self.path} failed. Coordinates are not all within the image's bounds; out-of-bounds data will be dropped." + ) + # Crop data to match the number of coordinates matched in the image + slices = [slice(None, len(coord)) for coord in coords.values()] + data = data[*slices] + self.array.loc[coords] = data + + else: + # Write batches + for i in range(len(coords[self.axes[0]])): + if isinstance(data, (int, float)): + this_data = data + else: + this_data = data[i] + self[{c: coords[c][i] for c in self.axes}] = this_data def __repr__(self) -> str: """Returns a string representation of the ImageWriter object.""" diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py index 2c00d75..7f1af3c 100644 --- a/src/cellmap_data/subdataset.py +++ b/src/cellmap_data/subdataset.py @@ -1,6 +1,7 @@ from typing import Callable, Sequence from torch.utils.data import Subset from .dataset import CellMapDataset + from .multidataset import CellMapMultiDataset @@ -10,11 +11,13 @@ class CellMapSubset(Subset): """ def __init__( - self, dataset: CellMapDataset | CellMapMultiDataset, indices: Sequence[int] + self, + dataset: CellMapDataset | CellMapMultiDataset, + indices: Sequence[int], ) -> None: """ Args: - dataset: CellMapDataset | CellMapMultiDataset + dataset: CellMapDataset | CellMapMultiDataset | CellMapDatasetWriter The dataset to be subsetted. indices: Sequence[int] The indices of the dataset to be used as the subset.
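Note on the indexing added in this patch: the `idx = np.array(idx); idx[idx < 0] = len(self) + idx[idx < 0]` lines give `__getitem__` and `get_center` Python-style negative and vectorized indexing. A minimal standalone sketch of that normalization, assuming a stand-in `length` for `len(self)` (the helper name and asserts are illustrative, not library code):

import numpy as np

# Negative indices wrap once, matching Python sequence semantics; scalar and
# vectorized inputs are handled identically. The library itself clamps
# out-of-range results to the nearest valid index with a logged warning
# rather than raising.
def normalize_indices(idx, length):
    idx = np.atleast_1d(np.asarray(idx))
    idx[idx < 0] = length + idx[idx < 0]
    return idx

assert normalize_indices(-1, 10).tolist() == [9]
assert normalize_indices([0, -2, 5], 10).tolist() == [0, 8, 5]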
From 62620df55e28ab0ce007d3d1764a5fad3a8df115 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 4 Nov 2024 13:35:57 -0500 Subject: [PATCH 09/17] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Correct=20image=20w?= =?UTF-8?q?riter=20tiling.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ready for merge after testing with training. --- src/cellmap_data/dataloader.py | 7 +-- src/cellmap_data/dataset.py | 2 +- src/cellmap_data/dataset_writer.py | 80 +++++++++++++++++++----------- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py index a4d6653..900d52a 100644 --- a/src/cellmap_data/dataloader.py +++ b/src/cellmap_data/dataloader.py @@ -1,8 +1,7 @@ import torch -from torch.utils.data import DataLoader, Sampler +from torch.utils.data import DataLoader, Sampler, Subset from .dataset import CellMapDataset from .multidataset import CellMapMultiDataset -from .subdataset import CellMapSubset from .dataset_writer import CellMapDatasetWriter from typing import Callable, Iterable, Optional, Sequence @@ -30,9 +31,7 @@ class CellMapDataLoader: def __init__( self, - dataset: ( - CellMapMultiDataset | CellMapDataset | CellMapSubset | CellMapDatasetWriter - ), + dataset: CellMapMultiDataset | CellMapDataset | Subset | CellMapDatasetWriter, classes: Iterable[str], batch_size: int = 1, num_workers: int = 0, diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 09f7482..58c86cb 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -410,11 +410,11 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: idx, [self.sampling_box_shape[c] for c in self.axis_order] ) except ValueError: + # TODO: This is a hacky temporary fix. Need to figure out why this is happening logger.error( f"Index {idx} out of bounds for dataset {self} of length {len(self)}" ) logger.warning(f"Returning closest index in bounds") - # TODO: This is a hacky temporary fix.
Need to figure out why this is happening center = [self.sampling_box_shape[c] - 1 for c in self.axis_order] center = { c: center[i] * self.largest_voxel_sizes[c] + self.sampling_box[c][0] diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 67834d6..59116f5 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -3,12 +3,10 @@ from typing import Callable, Mapping, Sequence, Optional import numpy as np import torch -from torch.utils.data import Dataset +from torch.utils.data import Dataset, Subset, DataLoader import tensorstore from upath import UPath -from torch.utils.data import Subset - from .image import CellMapImage, ImageWriter import logging @@ -152,6 +150,20 @@ def smallest_voxel_sizes(self) -> Mapping[str, float]: return self._smallest_voxel_sizes + @property + def smallest_target_array(self) -> Mapping[str, float]: + """Returns the smallest target array in world units.""" + try: + return self._smallest_target_array + except AttributeError: + smallest_target_array = {c: np.inf for c in self.axis_order} + for writers in self.target_array_writers.values(): + for writer in writers.values(): + for c, size in writer.write_world_shape.items(): + smallest_target_array[c] = min(smallest_target_array[c], size) + self._smallest_target_array = smallest_target_array + return self._smallest_target_array + @property def bounding_box(self) -> Mapping[str, list[float]]: """Returns the bounding box inclusive of all the target images.""" @@ -180,7 +192,7 @@ def bounding_box_shape(self) -> Mapping[str, int]: @property def sampling_box(self) -> Mapping[str, list[float]]: - """Returns the sampling box of the dataset (i.e.
where centers can be drawn from and still have full samples written to within the bounding box).""" + """Returns the sampling box of the dataset (i.e. where centers should be drawn from in order to fully sample within the bounding box).""" try: return self._sampling_box except AttributeError: sampling_box = None for array_name, array_info in self.target_arrays.items(): padding = {c: np.ceil((shape * scale) / 2) for c, shape, scale in zip(self.axis_order, array_info["shape"], array_info["scale"])} # type: ignore this_box = { - c: [bounds[0] - padding[c], bounds[1] + padding[c]] + c: [bounds[0] + padding[c], bounds[1] - padding[c]] for c, bounds in self.target_bounds[array_name].items() } sampling_box = self._get_box_union(this_box, sampling_box) if sampling_box is None: logger.warning( "Sampling box is None. This may result in errors when trying to sample from the dataset." ) @@ -232,20 +244,10 @@ def writer_indices(self) -> Sequence[int]: try: return self._writer_indices except AttributeError: - # Get smallest target array in world units - chunk_box = None - for array_info in self.target_arrays.values(): - this_box = { - c: [0, array_info["shape"][i] * array_info["scale"][i]] - for i, c in enumerate(self.axis_order) - } - chunk_box = self._get_box_intersection(this_box, chunk_box) - - # Convert to voxel units - chunk_size = {c: stop - start for c, (start, stop) in chunk_box.items()} - self._writer_indices = self.get_indices(chunk_size) + self._writer_indices = self.get_indices(self.smallest_target_array) return self._writer_indices + @property def blocks(self) -> Subset: """A subset of the dataset, tiling it with non-overlapping blocks.""" try: def loader( self, batch_size: int = 1, num_workers: int = 0, - rng: Optional[torch.Generator] = None, **kwargs, ): """Returns a DataLoader for the dataset.""" - from .dataloader import CellMapDataLoader - - return CellMapDataLoader( + return DataLoader( self.blocks, - classes=self.classes, batch_size=batch_size, num_workers=num_workers, - is_train=False, - rng=rng, + collate_fn=self.collate_fn, **kwargs, ) + def collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]: + """Combine a list of dictionaries from different sources into a single dictionary for output.""" + outputs = {} + for b in batch: + for key, value in b.items(): + if key not in outputs: + outputs[key] = [] + outputs[key].append(value) + for key, value in outputs.items(): + outputs[key] = torch.stack(value) + return outputs + @property def device(self) -> torch.device: """Returns the device for the dataset.""" @@ -376,7 +385,7 @@ def __setitem__( def __repr__(self) -> str: """Returns a string representation of the dataset.""" - return f"CellMapDataset(\n\tRaw path: {self.raw_path}\n\tOutput path(s): {self.target_path}\n\tClasses: {self.classes})" + return f"CellMapDatasetWriter(\n\tRaw path: {self.raw_path}\n\tOutput path(s): {self.target_path}\n\tClasses: {self.classes})" def get_target_array_writer( self, array_name: str, array_info: Mapping[str, Sequence[int | float]] @@ -457,14 +466,26 @@ def verify(self) -> bool: print(f"Error: {e}") return False - def get_indices(self, chunk_size: Mapping[str, int]) -> Sequence[int]: - """Returns the indices of the dataset that will tile the dataset according to the chunk_size.""" + def get_indices(self, chunk_size: Mapping[str, float]) -> Sequence[int]: + """Returns the indices of the dataset that will tile the dataset according to the chunk_size (supplied in world units).""" # TODO: ADD TEST - # Get padding per axis + + # Convert the target chunk size in world units to voxel units + chunk_size = { + c: int(size // self.smallest_voxel_sizes[c]) + for c, size in chunk_size.items() + } + indices_dict = {} for c, size
in chunk_size.items(): indices_dict[c] = np.arange(0, self.sampling_box_shape[c], size, dtype=int) + # Make sure the last index is included + if indices_dict[c][-1] != self.sampling_box_shape[c] - 1: + indices_dict[c] = np.append( + indices_dict[c], self.sampling_box_shape[c] - 1 + ) + indices = [] # Generate linear indices by unraveling all combinations of axes indices for i in np.ndindex(*[len(indices_dict[c]) for c in self.axis_order]): @@ -493,3 +514,6 @@ def set_raw_value_transforms(self, transforms: Callable) -> None: self.raw_value_transforms = transforms for source in self.input_sources.values(): source.value_transform = transforms + + +# %% From a8cdc45a87813ef388b4fa17b5633954d95227ec Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:43:19 -0500 Subject: [PATCH 10/17] Update ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dab61f6..bc1a3a1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.10', '3.11'] + python-version: ['3.11', '3.12'] platform: [ubuntu-latest, macos-latest, windows-latest] steps: From cf6171b2974b3f75d42cf9b6ac906880f509bf06 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:44:39 -0500 Subject: [PATCH 11/17] Update tests.yml --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1873b17..c844c9b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.11", "3.12"] os: [ubuntu-latest, windows-latest, macos-latest] steps: From d53ddaf2b101e3954c277ce0f7f8f825aa2a9e1f Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:51:08 -0500 Subject: [PATCH 12/17] Update mypy.yml --- .github/workflows/mypy.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 6711175..0786e51 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -1,6 +1,6 @@ name: Python mypy -on: [push, pull_request] +on: [pull_request] jobs: static-analysis: @@ -9,6 +9,8 @@ jobs: steps: - name: Setup Python uses: actions/setup-python@v5 + with: + python-version: 3.11 - name: Setup checkout uses: actions/checkout@v4 - name: mypy From 0a78eb8c68b6aaa30909bb147a35de6c8475f42b Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:53:04 -0500 Subject: [PATCH 13/17] Update mypy.yml --- .github/workflows/mypy.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 0786e51..c25ec0f 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -1,6 +1,9 @@ name: Python mypy -on: [pull_request] +on: + push: + branches: [ "main" ] + pull_request: jobs: static-analysis: From 1e89fce9bb37148b98cd5fb8092dfef07be90156 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:54:02 -0500 Subject: [PATCH 14/17] Update tests.yml --- .github/workflows/tests.yml
| 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c844c9b..dca9036 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,6 +2,8 @@ name: Test on: push: + branches: [ "main" ] + pull_request: jobs: test: From 8bb0d97ac98803dc1ca30bed70a37079f13a0638 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Mon, 4 Nov 2024 14:08:08 -0500 Subject: [PATCH 15/17] Update subdataset.py --- src/cellmap_data/subdataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py index 7f1af3c..3df9d53 100644 --- a/src/cellmap_data/subdataset.py +++ b/src/cellmap_data/subdataset.py @@ -17,7 +17,7 @@ def __init__( ) -> None: """ Args: - dataset: CellMapDataset | CellMapMultiDataset | CellMapDatasetWriter + dataset: CellMapDataset | CellMapMultiDataset The dataset to be subsetted. indices: Sequence[int] The indices of the dataset to be used as the subset. From 6371af1033a32144a934a761e9c6dbc16161e6cf Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Mon, 4 Nov 2024 14:10:04 -0500 Subject: [PATCH 16/17] Update subdataset.py --- src/cellmap_data/subdataset.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py index 3df9d53..8148596 100644 --- a/src/cellmap_data/subdataset.py +++ b/src/cellmap_data/subdataset.py @@ -10,11 +10,7 @@ class CellMapSubset(Subset): This subclasses PyTorch Subset to wrap a CellMapDataset or CellMapMultiDataset object under a common API, which can be used for dataloading. It maintains the same API as the Subset class. It retrieves raw and groundtruth data from a CellMapDataset or CellMapMultiDataset object. """ - def __init__( - self, - dataset: CellMapDataset | CellMapMultiDataset, - indices: Sequence[int], - ) -> None: + def __init__(self, dataset: CellMapDataset | CellMapMultiDataset, indices: Sequence[int], ) -> None: """ Args: dataset: CellMapDataset | CellMapMultiDataset From 86db20069e3f2662620005aac469376c3610993d Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 4 Nov 2024 14:17:24 -0500 Subject: [PATCH 17/17] =?UTF-8?q?style:=20=F0=9F=8E=A8=20Black=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cellmap_data/subdataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py index 8148596..3df9d53 100644 --- a/src/cellmap_data/subdataset.py +++ b/src/cellmap_data/subdataset.py @@ -10,7 +10,11 @@ class CellMapSubset(Subset): This subclasses PyTorch Subset to wrap a CellMapDataset or CellMapMultiDataset object under a common API, which can be used for dataloading. It maintains the same API as the Subset class. It retrieves raw and groundtruth data from a CellMapDataset or CellMapMultiDataset object. """ - def __init__(self, dataset: CellMapDataset | CellMapMultiDataset, indices: Sequence[int], ) -> None: + def __init__( + self, + dataset: CellMapDataset | CellMapMultiDataset, + indices: Sequence[int], + ) -> None: """ Args: dataset: CellMapDataset | CellMapMultiDataset
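Taken together, these patches reduce `writer_indices` to `get_indices(self.smallest_target_array)`, which tiles the sampling box with non-overlapping write-sized blocks. A condensed sketch of that tiling scheme, where `chunk_size` (world units), `voxel_size`, and `box_shape` (voxels) are illustrative stand-ins for the writer's own attributes, not the library's API:

import numpy as np

def tile_indices(chunk_size, voxel_size, box_shape, axes="zyx"):
    per_axis = {}
    for c in axes:
        step = int(chunk_size[c] // voxel_size[c])  # world units -> voxels
        starts = np.arange(0, box_shape[c], step, dtype=int)
        # Pin the final tile to the edge of the sampling box, as get_indices does
        if starts[-1] != box_shape[c] - 1:
            starts = np.append(starts, box_shape[c] - 1)
        per_axis[c] = starts
    shape = [box_shape[c] for c in axes]
    # Ravel every combination of per-axis starts into a flat dataset index
    return [
        int(np.ravel_multi_index([per_axis[c][i] for c, i in zip(axes, combo)], shape))
        for combo in np.ndindex(*[len(per_axis[c]) for c in axes])
    ]

# A (10, 10, 10)-voxel sampling box tiled with 4-voxel cubes at 1.0 world
# units per voxel yields per-axis starts [0, 4, 8, 9], i.e. 64 block indices.
print(len(tile_indices({c: 4.0 for c in "zyx"}, {c: 1.0 for c in "zyx"}, {c: 10 for c in "zyx"})))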