Skip to content

Commit

Permalink
feat: Add device parameter to CellMapDatasetWriter for improved devic…
Browse files Browse the repository at this point in the history
…e management and update import for multiscale attributes
  • Loading branch information
rhoadesScholar committed Dec 13, 2024
1 parent eda089e commit 5767ee0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/cellmap_data/dataset_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
rng: Optional[torch.Generator] = None,
empty_value: float | int = 0,
overwrite: bool = False,
device: Optional[str | torch.device] = None,
) -> None:
"""Initializes the CellMapDatasetWriter.
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
rng (Optional[torch.Generator]): The random number generator to use.
empty_value (float | int): The value to use for empty data in an array.
overwrite (bool): Whether to overwrite existing data.
device (Optional[str | torch.device]): The device to use for the dataset. If None, will default to "cuda" if available, then "mps", otherwise "cpu".
"""
self.raw_path = raw_path
self.target_path = target_path
Expand Down Expand Up @@ -109,6 +111,9 @@ def __init__(
self.target_array_writers[array_name] = self.get_target_array_writer(
array_name, array_info
)
if device is not None:
self._device = device
self.to(device)

@property
def center(self) -> Mapping[str, float] | None:
Expand Down
4 changes: 2 additions & 2 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import xarray_tensorstore as xt
import zarr
from pydantic_ome_ngff.v04.axis import Axis
from pydantic_ome_ngff.v04.multiscale import GroupAttrs, MultiscaleMetadata
from pydantic_ome_ngff.v04.multiscale import MultiscaleGroupAttrs, MultiscaleMetadata
from pydantic_ome_ngff.v04.transform import (
Scale,
Translation,
Expand Down Expand Up @@ -197,7 +197,7 @@ def multiscale_attrs(self) -> MultiscaleMetadata:
try:
return self._multiscale_attrs
except AttributeError:
self._multiscale_attrs: MultiscaleMetadata = GroupAttrs(
self._multiscale_attrs: MultiscaleMetadata = MultiscaleGroupAttrs(
multiscales=self.group.attrs["multiscales"]
).multiscales[0]
return self._multiscale_attrs
Expand Down

0 comments on commit 5767ee0

Please sign in to comment.