From 5767ee034f143466d2d86434f8f2805aa63a5b50 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 13 Dec 2024 17:01:11 -0500 Subject: [PATCH] feat: Add device parameter to CellMapDatasetWriter for improved device management and update import for multiscale attributes --- src/cellmap_data/dataset_writer.py | 5 +++++ src/cellmap_data/image.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 1a523fe..59906eb 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -46,6 +46,7 @@ def __init__( rng: Optional[torch.Generator] = None, empty_value: float | int = 0, overwrite: bool = False, + device: Optional[str | torch.device] = None, ) -> None: """Initializes the CellMapDatasetWriter. @@ -76,6 +77,7 @@ def __init__( rng (Optional[torch.Generator]): The random number generator to use. empty_value (float | int): The value to use for empty data in an array. overwrite (bool): Whether to overwrite existing data. + device (Optional[str | torch.device]): The device to use for the dataset. If None, will default to "cuda" if available, then "mps", otherwise "cpu". """ self.raw_path = raw_path self.target_path = target_path @@ -109,6 +111,9 @@ def __init__( self.target_array_writers[array_name] = self.get_target_array_writer( array_name, array_info ) + if device is not None: + self._device = device + self.to(device) @property def center(self) -> Mapping[str, float] | None: diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 7d8c0fd..bcd3f37 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -8,7 +8,7 @@ import xarray_tensorstore as xt import zarr from pydantic_ome_ngff.v04.axis import Axis -from pydantic_ome_ngff.v04.multiscale import GroupAttrs, MultiscaleMetadata +from pydantic_ome_ngff.v04.multiscale import MultiscaleGroupAttrs, MultiscaleMetadata from pydantic_ome_ngff.v04.transform import ( Scale, Translation, @@ -197,7 +197,7 @@ def multiscale_attrs(self) -> MultiscaleMetadata: try: return self._multiscale_attrs except AttributeError: - self._multiscale_attrs: MultiscaleMetadata = GroupAttrs( + self._multiscale_attrs: MultiscaleMetadata = MultiscaleGroupAttrs( multiscales=self.group.attrs["multiscales"] ).multiscales[0] return self._multiscale_attrs