Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compute() on plate grid #299

Merged
merged 9 commits into from
Sep 29, 2023
42 changes: 15 additions & 27 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import dask.array as da
import numpy as np
from dask import delayed

from .axes import Axes
from .format import format_from_version
Expand Down Expand Up @@ -420,39 +419,34 @@ def __init__(self, node: Node) -> None:
self.img_metadata = image_node.metadata
self.img_pyramid_shapes = [d.shape for d in image_node.data]

def get_field(tile_name: str, level: int) -> np.ndarray:
def get_field(row: int, col: int, level: int) -> da.core.Array:
"""Return the field at grid position (row, col) for pyramid level `level`, or zeros if it is missing or fails to load"""
row, col = (int(n) for n in tile_name.split(","))
field_index = (column_count * row) + col
image_path = image_paths[field_index]
path = f"{image_path}/{level}"
LOGGER.debug("LOADING tile... %s", path)
data = None
try:
data = self.zarr.load(path)
# handle e.g. 2x2 grid with only 3 images/fields
if field_index < len(image_paths):
image_path = image_paths[field_index]
path = f"{image_path}/{level}"
data = self.zarr.load(path)
except ValueError:
LOGGER.error("Failed to load %s", path)
data = np.zeros(self.img_pyramid_shapes[level], dtype=self.numpy_type)
if data is None:
data = da.zeros(self.img_pyramid_shapes[level], dtype=self.numpy_type)
return data

lazy_reader = delayed(get_field)

def get_lazy_well(level: int, tile_shape: tuple) -> da.Array:
lazy_rows = []
for row in range(row_count):
lazy_row: List[da.Array] = []
for col in range(column_count):
tile_name = f"{row},{col}"
LOGGER.debug(
"creating lazy_reader. row: %s col: %s level: %s",
row,
col,
level,
)
lazy_tile = da.from_delayed(
lazy_reader(tile_name, level),
shape=tile_shape,
dtype=self.numpy_type,
)
lazy_tile = get_field(row, col, level)
lazy_row.append(lazy_tile)
lazy_rows.append(da.concatenate(lazy_row, axis=x_index))
return da.concatenate(lazy_rows, axis=y_index)
Expand Down Expand Up @@ -535,31 +529,25 @@ def get_tile_path(self, level: int, row: int, col: int) -> str:
def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array:
LOGGER.debug("get_stitched_grid() level: %s, tile_shape: %s", level, tile_shape)

def get_tile(tile_name: str) -> np.ndarray:
def get_tile(row: int, col: int) -> da.core.Array:
"""Return the tile at grid position (row, col), or zeros of tile_shape if it fails to load"""
row, col = (int(n) for n in tile_name.split(","))
path = self.get_tile_path(level, row, col)
LOGGER.debug("LOADING tile... %s with shape: %s", path, tile_shape)
LOGGER.debug("creating tile... %s with shape: %s", path, tile_shape)

try:
# this is a dask array - data not loaded from source yet
data = self.zarr.load(path)
except ValueError:
LOGGER.exception("Failed to load %s", path)
data = np.zeros(tile_shape, dtype=self.numpy_type)
data = da.zeros(tile_shape, dtype=self.numpy_type)
return data

lazy_reader = delayed(get_tile)

lazy_rows = []
# For level 0, return whole image for each tile
for row in range(self.row_count):
lazy_row: List[da.Array] = []
for col in range(self.column_count):
tile_name = f"{row},{col}"
lazy_tile = da.from_delayed(
lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type
)
lazy_row.append(lazy_tile)
lazy_row.append(get_tile(row, col))
lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1))
return da.concatenate(lazy_rows, axis=len(self.axes) - 2)

Expand Down
23 changes: 20 additions & 3 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dask.array as da
import numpy as np
import pytest
import zarr
from numpy import ones, zeros

from ome_zarr.data import create_zarr
from ome_zarr.io import parse_url
from ome_zarr.reader import Node, Plate, Reader
from ome_zarr.reader import Node, Plate, Reader, Well
from ome_zarr.writer import write_image, write_plate_metadata, write_well_metadata


Expand Down Expand Up @@ -95,16 +96,32 @@ def test_multiwells_plate(self, field_paths):
nodes = list(reader())
# reading plate labels is currently disabled, so expect only 1 node
assert len(nodes) == 1

plate_node = nodes[0]
assert len(plate_node.specs) == 1
assert isinstance(plate_node.specs[0], Plate)
# data should be a Dask array
pyramid = plate_node.data
assert isinstance(pyramid[0], da.Array)
# if we compute(), expect to get numpy array
result = pyramid[0].compute()
assert isinstance(result, np.ndarray)

# Get the plate node's array. It should be fused from the first field of all
# well arrays (which in this test are non-zero), with zero values for wells
# that failed to load (not expected) or the surplus area not filled by a well.
expected_num_pixels = (
len(well_paths) * len(field_paths[:1]) * np.prod((1, 1, 1, 256, 256))
)
pyramid_0 = plate_node.data[0]
pyramid_0 = pyramid[0]
assert np.asarray(pyramid_0).sum() == expected_num_pixels
# assert len(nodes[1].specs) == 1

# assert isinstance(nodes[1].specs[0], PlateLabels)

reader = Reader(parse_url(f"{self.path}/{well_paths[0]}"))
nodes = list(reader())
assert isinstance(nodes[0].specs[0], Well)
pyramid = nodes[0].data
assert isinstance(pyramid[0], da.Array)
result = pyramid[0].compute()
assert isinstance(result, np.ndarray)
Loading