Skip to content

Commit

Permalink
Merge pull request #1508 from girder/zarr-sink-metadata
Browse files Browse the repository at this point in the history
Zarr sink metadata management
  • Loading branch information
annehaley authored Jun 18, 2024
2 parents 4f6f053 + 1bbc0b8 commit f7ba4bf
Show file tree
Hide file tree
Showing 2 changed files with 317 additions and 27 deletions.
209 changes: 182 additions & 27 deletions sources/zarr/large_image_source_zarr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, path, **kwargs):
def _initOpen(self, **kwargs):
self._largeImagePath = str(self._getLargeImagePath())
self._zarr = None
self._editable = False
if not os.path.isfile(self._largeImagePath) and '//:' not in self._largeImagePath:
raise TileSourceFileNotFoundError(self._largeImagePath) from None
try:
Expand Down Expand Up @@ -126,7 +127,11 @@ def _initNew(self, path, **kwargs):
self._framecount = 0
self._mm_x = 0
self._mm_y = 0
self._channelNames = []
self._channelColors = []
self._imageDescription = None
self._levels = []
self._associatedImages = {}

def __del__(self):
if not hasattr(self, '_derivedSource'):
Expand Down Expand Up @@ -312,6 +317,8 @@ def _validateZarr(self):
Validate that we can read tiles from the zarr parent group in
self._zarr. Set up the appropriate class variables.
"""
if self._editable:
self._writeInternalMetadata()
found = self._scanZarrGroup(self._zarr)
if found['best'] is None:
msg = 'No data array that can be used.'
Expand All @@ -324,8 +331,10 @@ def _validateZarr(self):
msg = 'Conflicting xy axis data.'
raise TileSourceError(msg)
self._channels = found['channels']
self._associatedImages = [
(g, a) for g, a in found['associated'] if not any(g is gb for gb, _ in self._series)]
self._associatedImages = {
g.name.replace('/', ''): (g, a)
for g, a in found['associated'] if not any(g is gb for gb, _ in self._series)
}
self.sizeX = baseArray.shape[self._axes['x']]
self.sizeY = baseArray.shape[self._axes['y']]
self.tileWidth = (
Expand Down Expand Up @@ -429,6 +438,8 @@ def getInternalMetadata(self, **kwargs):
:returns: a dictionary of data or None.
"""
if self._editable:
self._writeInternalMetadata()
result = {}
result['zarr'] = {
'base': self._zarr.attrs.asdict(),
Expand All @@ -445,7 +456,7 @@ def getAssociatedImagesList(self):
:return: the list of image keys.
"""
return [f'image_{idx}' for idx in range(len(self._associatedImages))]
return list(self._associatedImages.keys())

def _getAssociatedImage(self, imageKey):
"""
Expand All @@ -454,15 +465,9 @@ def _getAssociatedImage(self, imageKey):
:param imageKey: the key of the associated image.
:return: the image in PIL format or None.
"""
if not imageKey.startswith('image_'):
return
try:
idx = int(imageKey[6:])
except Exception:
return
if idx < 0 or idx >= len(self._associatedImages):
if imageKey not in self._associatedImages:
return
group, arr = self._associatedImages[idx]
group, arr = self._associatedImages[imageKey]
axes = self._getGeneralAxes(arr)
trans = [idx for idx in range(len(arr.shape))
if idx not in axes.values()] + [axes['y'], axes['x']]
Expand Down Expand Up @@ -565,8 +570,6 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
:param kwargs: start locations for any additional axes. Note that
``level`` is a reserved word and not permitted for an axis name.
"""
# TODO: improve band bookkeeping

self._checkEditable()
store_path = str(kwargs.pop('level', 0))
placement = {
Expand Down Expand Up @@ -629,20 +632,8 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
path=store_path,
)

# If base data changed, update large_image attributes and OME metadata
# If base data changed, update large_image attributes
if store_path == '0':
self._zarr.attrs.update({
'multiscales': [{
'version': '0.5-dev',
'axes': [{
'name': a,
'type': 'space' if a in ['x', 'y'] else 'other',
} for a in axes],
'datasets': [{'path': 0}],
}],
'omero': {'version': '0.5-dev'},
})

self._dtype = tile.dtype
self._bandCount = new_dims.get(axes[-1]) # last axis is assumed to be bands
self.sizeX = new_dims.get('x')
Expand All @@ -657,6 +648,113 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
self.levels = int(max(1, math.ceil(math.log(max(
self.sizeX / self.tileWidth, self.sizeY / self.tileHeight)) / math.log(2)) + 1))

def addAssociatedImage(self, image, imageKey=None):
"""
Add an associated image to this source.
:param image: a numpy array, PIL Image, or a binary string
with an image. The numpy array can have 2 or 3 dimensions.
"""
data, _ = _imageToNumpy(image)
with self._addLock:
if imageKey is None:
# Each associated image should be in its own group
num_existing = len(self.getAssociatedImagesList())
imageKey = f'image_{num_existing + 1}'
group = self._zarr.require_group(imageKey)
arr = zarr.array(
data,
store=self._zarr_store,
path=f'{imageKey}/image',
)
self._associatedImages[imageKey] = (group, arr)

def _writeInternalMetadata(self):
self._checkEditable()
with self._addLock:
name = str(self._tempdir.name).split('/')[-1]
arrays = dict(self._zarr.arrays())
channel_axis = self._axes.get('s') or self._axes.get('c')
datasets = []
axes = []
channels = []
rdefs = {'model': 'color' if len(self._channelColors) else 'greyscale'}
sorted_axes = [a[0] for a in sorted(self._axes.items(), key=lambda item: item[1])]
for arr_name in arrays:
level = int(arr_name)
scale = [1.0 for a in sorted_axes]
scale[self._axes.get('x')] = self._mm_x * (2 ** level)
scale[self._axes.get('y')] = self._mm_y * (2 ** level)
dataset_metadata = {
'path': arr_name,
'coordinateTransformations': [{
'type': 'scale',
'scale': scale,
}],
}
datasets.append(dataset_metadata)
for a in sorted_axes:
axis_metadata = {'name': a}
if a in ['x', 'y']:
axis_metadata['type'] = 'space'
axis_metadata['unit'] = 'millimeter'
elif a in ['s', 'c']:
axis_metadata['type'] = 'channel'
elif a == 't':
rdefs['defaultT'] = 0
elif a == 'z':
rdefs['defaultZ'] = 0
axes.append(axis_metadata)
if channel_axis and len(arrays) > 0:
base_array = list(arrays.values())[0]
base_shape = base_array.shape
for c in range(base_shape[channel_axis]):
channel_metadata = {
'active': True,
'coefficient': 1,
'color': 'FFFFFF',
'family': 'linear',
'inverted': False,
'label': f'Band {c + 1}',
}
channel_data = base_array[..., c]
channel_min = np.min(channel_data)
channel_max = np.max(channel_data)
channel_metadata['window'] = {
'end': channel_max,
'max': channel_max,
'min': channel_min,
'start': channel_min,
}
if len(self._channelNames) > c:
channel_metadata['label'] = self._channelNames[c]
if len(self._channelColors) > c:
channel_metadata['color'] = self._channelColors[c]
channels.append(channel_metadata)
# Guidelines from https://ngff.openmicroscopy.org/latest/
self._zarr.attrs.update({
'multiscales': [{
'version': '0.5',
'name': name,
'axes': axes,
'datasets': datasets,
'metadata': {
'description': self._imageDescription or '',
'kwargs': {
'multichannel': (channel_axis is not None),
},
},
}],
'omero': {
'id': 1,
'version': '0.5',
'name': name,
'channels': channels,
'rdefs': rdefs,
},
'bioformats2raw.layout': 3,
})

@property
def crop(self):
"""
Expand All @@ -682,6 +780,59 @@ def crop(self, value):
raise TileSourceError(msg)
self._crop = (x, y, w, h)

@property
def mm_x(self):
return self._mm_x

@mm_x.setter
def mm_x(self, value):
self._checkEditable()
value = float(value) if value is not None else None
if value is not None and value <= 0:
msg = 'mm_x must be positive or None'
raise TileSourceError(msg)
self._mm_x = value

@property
def mm_y(self):
return self._mm_y

@mm_y.setter
def mm_y(self, value):
self._checkEditable()
value = float(value) if value is not None else None
if value is not None and value <= 0:
msg = 'mm_y must be positive or None'
raise TileSourceError(msg)
self._mm_y = value

@property
def imageDescription(self):
return self._imageDescription

@imageDescription.setter
def imageDescription(self, description):
self._checkEditable()
self._imageDescription = description

@property
def channelNames(self):
return self._channelNames

@channelNames.setter
def channelNames(self, names):
self._checkEditable()
self._channelNames = names

@property
def channelColors(self):
return self._channelColors

@channelColors.setter
def channelColors(self, colors):
self._checkEditable()
self._channelColors = colors

def _generateDownsampledLevels(self, resample_method):
self._checkEditable()
current_arrays = dict(self._zarr.arrays())
Expand All @@ -705,6 +856,7 @@ def _generateDownsampledLevels(self, resample_method):
width=4096 + tile_overlap['x'],
height=4096 + tile_overlap['y'],
)
sorted_axes = [a[0] for a in sorted(self._axes.items(), key=lambda item: item[1])]
for level in range(1, self.levels):
scale_factor = 2 ** level
iterator_output = dict(
Expand Down Expand Up @@ -739,7 +891,7 @@ def _generateDownsampledLevels(self, resample_method):
x=x,
y=y,
**frame_position,
axes=list(self._axes.keys()),
axes=sorted_axes,
level=level,
)
self._validateZarr() # refresh self._levels before continuing
Expand Down Expand Up @@ -799,6 +951,8 @@ def write(
**frame_position,
)

source._writeInternalMetadata()

if suffix in ['.zarr', '.db', '.sqlite', '.zip']:
if resample is None:
resample = (
Expand All @@ -807,6 +961,7 @@ def write(
else ResampleMethod.NP_NEAREST
)
source._generateDownsampledLevels(resample)
source._writeInternalMetadata() # rewrite with new level datasets

if suffix == '.zarr':
shutil.copytree(source._tempdir.name, path)
Expand Down
Loading

0 comments on commit f7ba4bf

Please sign in to comment.