diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py
index 71f0273bd..fe85441ec 100644
--- a/sources/zarr/large_image_source_zarr/__init__.py
+++ b/sources/zarr/large_image_source_zarr/__init__.py
@@ -1,8 +1,12 @@
 import math
 import os
+import shutil
+import tempfile
 import threading
+import uuid
 from importlib.metadata import PackageNotFoundError
 from importlib.metadata import version as _importlib_version
+from pathlib import Path
 
 import numpy as np
 import packaging.version
@@ -10,10 +14,10 @@
 
 import large_image
 from large_image.cache_util import LruCacheMetaclass, methodcache
-from large_image.constants import TILE_FORMAT_NUMPY, SourcePriority
+from large_image.constants import NEW_IMAGE_PATH_FLAG, TILE_FORMAT_NUMPY, SourcePriority
 from large_image.exceptions import TileSourceError, TileSourceFileNotFoundError
 from large_image.tilesource import FileTileSource
-from large_image.tilesource.utilities import nearPowerOfTwo
+from large_image.tilesource.utilities import _imageToNumpy, nearPowerOfTwo
 
 try:
     __version__ = _importlib_version(__name__)
@@ -53,6 +57,13 @@ def __init__(self, path, **kwargs):
         """
         super().__init__(path, **kwargs)
 
+        if str(path).startswith(NEW_IMAGE_PATH_FLAG):
+            self._initNew(path, **kwargs)
+        else:
+            self._initOpen(**kwargs)
+        self._tileLock = threading.RLock()
+
+    def _initOpen(self, **kwargs):
         self._largeImagePath = str(self._getLargeImagePath())
         self._zarr = None
         if not os.path.isfile(self._largeImagePath) and '//:' not in self._largeImagePath:
@@ -80,7 +91,48 @@ def __init__(self, path, **kwargs):
         except Exception:
             msg = 'File cannot be opened -- not an OME NGFF file or understandable zarr file.'
             raise TileSourceError(msg)
-        self._tileLock = threading.RLock()
+
+    def _initNew(self, path, **kwargs):
+        """
+        Initialize the tile class for creating a new image.
+        """
+        self._tempdir = tempfile.TemporaryDirectory(path)
+        self._zarr_store = zarr.DirectoryStore(self._tempdir.name)
+        self._zarr = zarr.open(self._zarr_store, mode='w')
+        # Make unpickleable
+        self._unpickleable = True
+        self._largeImagePath = None
+        self._dims = {}
+        self.sizeX = self.sizeY = self.levels = 0
+        self.tileWidth = self.tileHeight = self._tileSize
+        self._cacheValue = str(uuid.uuid4())
+        self._output = None
+        self._editable = True
+        self._bandRanges = None
+        self._addLock = threading.RLock()
+        self._framecount = 0
+        self._mm_x = 0
+        self._mm_y = 0
+        self._levels = []
+
+    def __del__(self):
+        if not hasattr(self, '_derivedSource'):
+            try:
+                self._zarr.close()
+            except Exception:
+                pass
+            try:
+                self._tempdir.cleanup()
+            except Exception:
+                pass
+
+    def _checkEditable(self):
+        """
+        Raise an exception if this is not an editable image.
+        """
+        if not self._editable:
+            msg = 'Not an editable image'
+            raise TileSourceError(msg)
 
     def _getGeneralAxes(self, arr):
         """
@@ -284,6 +336,8 @@ def _validateZarr(self):
                 baseArray.shape[self._axes.get('c')] in {1, 3, 4}):
             self._bandCount = baseArray.shape[self._axes['c']]
             self._axes['s'] = self._axes.pop('c')
+        elif 's' in self._axes:
+            self._bandCount = baseArray.shape[self._axes['s']]
         self._zarrFindLevels()
         self._getScale()
         stride = 1
@@ -326,6 +380,13 @@ def getNativeMagnification(self):
             'mm_y': mm_y,
         }
 
+    def getState(self):
+        # Use the _cacheValue to avoid caching the source and tiles if we are
+        # creating something new.
+        if not hasattr(self, '_cacheValue'):
+            return super().getState()
+        return super().getState() + ',%s' % (self._cacheValue, )
+
     def getMetadata(self):
         """
         Return a dictionary of metadata containing levels, sizeX, sizeY,
@@ -333,6 +394,8 @@ def getMetadata(self):
 
         :returns: metadata dictionary.
         """
+        if self._levels is None:
+            self._validateZarr()
         result = super().getMetadata()
         if self._framecount > 1:
             result['frames'] = frames = []
@@ -397,6 +460,9 @@ def _getAssociatedImage(self, imageKey):
 
     @methodcache()
     def getTile(self, x, y, z, pilImageAllowed=False, numpyAllowed=False, **kwargs):
+        if self._levels is None:
+            self._validateZarr()
+
         frame = self._getFrame(**kwargs)
         self._xyzInRange(x, y, z, frame, self._framecount)
         x0, y0, x1, y1, step = self._xyzToCorners(x, y, z)
@@ -439,6 +505,221 @@ def getTile(self, x, y, z, pilImageAllowed=False, numpyAllowed=False, **kwargs):
         return self._outputTile(tile, TILE_FORMAT_NUMPY, x, y, z,
                                 pilImageAllowed, numpyAllowed, **kwargs)
 
+    def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
+        """
+        Add a numpy or image tile to the image, expanding the image as needed
+        to accommodate it. Note that x and y can be negative. If so, the
+        output image (and internal memory access of the image) will act as if
+        the 0, 0 point is the most negative position. Cropping is applied
+        after this offset.
+
+        :param tile: a numpy array, PIL Image, or a binary string
+            with an image. The numpy array can have 2 or 3 dimensions.
+        :param x: location in destination for upper-left corner.
+        :param y: location in destination for upper-left corner.
+        :param mask: a 2-d numpy array (or 3-d if the last dimension is 1).
+            If specified, areas where the mask is false will not be altered.
+        :param axes: a string or list of strings specifying the names of axes
+            in the same order as the tile dimensions
+        :param kwargs: start locations for any additional axes
+        """
+        # TODO: improve band bookkeeping
+
+        self._checkEditable()
+        placement = {
+            'x': x,
+            'y': y,
+            **kwargs,
+        }
+        if not isinstance(tile, np.ndarray) or axes is None:
+            axes = 'yxs'
+            tile, mode = _imageToNumpy(tile)
+        elif not isinstance(axes, str) and not isinstance(axes, list):
+            err = 'Invalid type for axes. Must be str or list[str].'
+            raise ValueError(err)
+        axes = [x.lower() for x in axes]
+        if axes[-1] != 's':
+            axes.append('s')
+        if mask is not None and len(axes) - 1 == len(mask.shape):
+            mask = mask[:, :, np.newaxis]
+        if 'x' not in axes or 'y' not in axes:
+            err = 'Invalid value for axes. Must contain "y" and "x".'
+            raise ValueError(err)
+        for k in placement:
+            if k not in axes:
+                axes[0:0] = [k]
+        with self._addLock:
+            self._axes = {k: i for i, k in enumerate(axes)}
+            while len(tile.shape) < len(axes):
+                tile = np.expand_dims(tile, axis=0)
+            while mask is not None and len(mask.shape) < len(axes):
+                mask = np.expand_dims(mask, axis=0)
+
+            new_dims = {
+                a: max(
+                    self._dims.get(a, 0),
+                    placement.get(a, 0) + tile.shape[i],
+                )
+                for a, i in self._axes.items()
+            }
+            placement_slices = tuple([
+                slice(placement.get(a, 0), placement.get(a, 0) + tile.shape[i], 1)
+                for i, a in enumerate(axes)
+            ])
+
+            current_arrays = dict(self._zarr.arrays())
+            chunking = None
+            if 'root' not in current_arrays:
+                root = np.zeros(tuple(new_dims.values()), dtype=tile.dtype)
+                chunking = tuple([
+                    self._tileSize if a in ['x', 'y'] else
+                    new_dims.get('s') if a == 's' else 1
+                    for a in axes
+                ])
+            else:
+                root = current_arrays['root']
+                root.resize(*tuple(new_dims.values()))
+                if root.chunks[-1] != new_dims.get('s'):
+                    # rechunk if length of samples axis changes
+                    chunking = tuple([
+                        self._tileSize if a in ['x', 'y'] else
+                        new_dims.get('s') if a == 's' else 1
+                        for a in axes
+                    ])
+
+            if mask is not None:
+                root[placement_slices] = np.where(mask, tile, root[placement_slices])
+            else:
+                root[placement_slices] = tile
+            if chunking:
+                self._zarr.create_dataset('root', data=root[:], chunks=chunking, overwrite=True)
+
+            # Edit OME metadata
+            self._zarr.attrs.update({
+                'multiscales': [{
+                    'version': '0.5-dev',
+                    'axes': [{
+                        'name': a,
+                        'type': 'space' if a in ['x', 'y'] else 'other',
+                    } for a in axes],
+                    'datasets': [{'path': 0}],
+                }],
+                'omero': {'version': '0.5-dev'},
+            })
+
+            # Edit large_image attributes
+            self._dims = new_dims
+            self._dtype = tile.dtype
+            self._bandCount = new_dims.get(axes[-1])  # last axis is assumed to be bands
+            self.sizeX = new_dims.get('x')
+            self.sizeY = new_dims.get('y')
+            self._framecount = math.prod([
+                length
+                for axis, length in new_dims.items()
+                if axis in axes[:-3]
+            ])
+            self._cacheValue = str(uuid.uuid4())
+            self._levels = None
+            self.levels = int(max(1, math.ceil(math.log(max(
+                self.sizeX / self.tileWidth, self.sizeY / self.tileHeight)) / math.log(2)) + 1))
+
+    @property
+    def crop(self):
+        """
+        Crop only applies to the output file, not the internal data access.
+
+        It consists of x, y, w, h in pixels.
+        """
+        return getattr(self, '_crop', None)
+
+    @crop.setter
+    def crop(self, value):
+        self._checkEditable()
+        if value is None:
+            self._crop = None
+            return
+        x, y, w, h = value
+        x = int(x)
+        y = int(y)
+        w = int(w)
+        h = int(h)
+        if x < 0 or y < 0 or w <= 0 or h <= 0:
+            msg = 'Crop must have non-negative x, y and positive w, h'
+            raise TileSourceError(msg)
+        self._crop = (x, y, w, h)
+
+    def write(
+        self,
+        path,
+        lossy=True,
+        alpha=True,
+        overwriteAllowed=True,
+    ):
+        """
+        Output the current image to a file.
+
+        :param path: output path.
+        :param lossy: if false, emit a lossless file.
+        :param alpha: True if an alpha channel is allowed.
+        :param overwriteAllowed: if False, raise an exception if the output
+            path exists.
+        """
+        if os.path.exists(path):
+            if overwriteAllowed:
+                if os.path.isdir(path):
+                    shutil.rmtree(path)
+                else:
+                    os.remove(path)
+            else:
+                raise TileSourceError('Output path exists (%s).' % str(path))
+
+        # TODO: compute half, quarter, etc. resolutions
+        self._validateZarr()
+        suffix = Path(path).suffix
+        data_dir = self._tempdir
+        data_store = self._zarr_store
+
+        if self.crop:
+            x, y, w, h = self.crop
+            current_arrays = dict(self._zarr.arrays())
+            # create new temp storage for cropped data
+            data_dir = tempfile.TemporaryDirectory()
+            data_store = zarr.DirectoryStore(data_dir.name)
+            cropped_zarr = zarr.open(data_store, mode='w')
+            for arr_name in current_arrays:
+                arr = np.array(current_arrays[arr_name])
+                cropped_arr = arr.take(
+                    indices=range(x, x + w),
+                    axis=self._axes.get('x'),
+                ).take(
+                    indices=range(y, y + h),
+                    axis=self._axes.get('y'),
+                )
+                cropped_zarr.create_dataset(arr_name, data=cropped_arr, overwrite=True)
+            cropped_zarr.attrs.update(self._zarr.attrs)
+
+        if suffix == '.zarr':
+            shutil.copytree(data_dir.name, path)
+
+        elif suffix in ['.db', '.sqlite']:
+            sqlite_store = zarr.SQLiteStore(path)
+            zarr.copy_store(data_store, sqlite_store, if_exists='replace')
+            sqlite_store.close()
+
+        elif suffix == '.zip':
+            zip_store = zarr.ZipStore(path)
+            zarr.copy_store(data_store, zip_store, if_exists='replace')
+            zip_store.close()
+
+        else:
+            from large_image_converter import convert
+
+            attrs_path = Path(data_dir.name) / '.zattrs'
+            convert(str(attrs_path), path, overwrite=overwriteAllowed)
+
+        if self.crop:
+            data_dir.cleanup()
+
 
 def open(*args, **kwargs):
     """
@@ -452,3 +733,11 @@ def canRead(*args, **kwargs):
     Check if an input can be read by the module class.
     """
     return ZarrFileTileSource.canRead(*args, **kwargs)
+
+
+def new(*args, **kwargs):
+    """
+    Create a new image, collecting the results from patches of numpy arrays or
+    smaller images.
+    """
+    return ZarrFileTileSource(NEW_IMAGE_PATH_FLAG + str(uuid.uuid4()), *args, **kwargs)
diff --git a/test/test_examples.py b/test/test_examples.py
index b6801b327..07d464124 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -85,8 +85,8 @@ def test_sum_squares_import():
 
 @pytest.mark.parametrize(('sink', 'outname', 'openpath'), [
     ('multivips', 'sample', 'sample/results.yml'),
-    # ('zarr', 'sample.zip', 'sample.zip'),
-    # ('multizarr', 'sample', 'sample/results.yml'),
+    ('zarr', 'sample.zip', 'sample.zip'),
+    ('multizarr', 'sample', 'sample/results.yml'),
 ])
 def test_algorithm_progression(sink, outname, openpath, tmp_path):
     import large_image
diff --git a/test/test_sink.py b/test/test_sink.py
index 316a2c3ea..f6bccf26c 100644
--- a/test/test_sink.py
+++ b/test/test_sink.py
@@ -1,195 +1,179 @@
-import pathlib
-import random
-import tempfile
+import large_image_source_test
+import large_image_source_zarr
 
 import numpy as np
 import pytest
 
 import large_image
 
-possible_axes = {
-    'x': [1, 10],
-    'y': [1, 10],
-    'c': [1, 40],
-    'z': [1, 40],
-    't': [1, 40],
-    'p': [1, 20],
-    'q': [1, 20],
-    's': [3, 3],
-}
-
-include_axes = {
-    'c': False,
-    'z': False,
-    't': False,
-    'p': False,
-    'q': False,
-}
-
-possible_data_ranges = [
-    [0, 1, 2, float],
-    [0, 1, 2, np.float16],
-    [0, 1, 2, np.float32],
-    [0, 1, 2, np.float64],
-    [0, 2**8, -1, np.uint8],
-    [0, 2**8, -1, float],
-    [0, 2**8, -1, int],
-    [0, 2**16, -2, np.uint16],
-    [0, 2**16, -2, float],
-    [0, 2**32, -4, int],
-    [-2**7, 2**7, -1, np.int8],
-    [-2**7, 2**7, -1, float],
-    [-2**7, 2**7, -1, int],
-    [-2**15, 2**15, -2, np.int16],
-    [-2**15, 2**15, -2, float],
-    [-2**15, 2**15, -2, int],
-    [-2**31, 2**31, -4, np.int32],
-    [-2**31, 2**31, -4, float],
-    [-2**31, 2**31, -4, int],
-    [-1, 1, 2, float],
-    [-1, 1, 2, np.float16],
-    [-1, 1, 2, np.float32],
-    [-1, 1, 2, np.float64],
+TMP_DIR = 'tmp/zarr_sink'
+FILE_TYPES = [
+    'tiff',
+    'sqlite',
+    'db',
+    'zip',
+    'zarr',
+    # "dz",
+    # 'svi',
+    # 'svs',
 ]
 
-max_tile_size = 100
-tile_overlap_ratio = 0.5
-
-
-# https://stackoverflow.com/questions/18915378/rounding-to-significant-figures-in-numpy
-def signif(x, minval, maxval, digits):
-    if x == 0:
-        return 0
-    return max(min(round(x, digits), max(1, maxval - 1)), minval)
-
-
-def get_dims(x, y, s, max=False):
-    tile_shape = [x, y]
-    for axis_name, include in include_axes.items():
-        if include:
-            axis_min_max = possible_axes[axis_name]
-            if max:
-                tile_shape.append(axis_min_max[1])
-            else:
-                tile_shape.append(random.randint(*axis_min_max))
-    # s is last axis
-    tile_shape.append(s)
-    return tile_shape
-
-
-def random_tile(data_range):
-    tile_shape = get_dims(
-        random.randint(1, max_tile_size),
-        random.randint(1, max_tile_size),
-        random.randint(*possible_axes['s']),
-        include_axes,
+
+def copyFromSource(source, sink):
+    metadata = source.getMetadata()
+    for frame in metadata.get('frames', []):
+        for tile in source.tileIterator(frame=frame['Frame'], format='numpy'):
+            t = tile['tile']
+            x, y = tile['x'], tile['y']
+            kwargs = {
+                'z': frame['IndexZ'],
+                'c': frame['IndexC'],
+            }
+            sink.addTile(t, x=x, y=y, **kwargs)
+
+
+def testNew():
+    sink = large_image_source_zarr.new()
+    assert sink.metadata['levels'] == 0
+    assert sink.getRegion(format='numpy')[0].shape[:2] == (0, 0)
+
+
+def testBasicAddTile():
+    sink = large_image_source_zarr.new()
+    sink.addTile(np.random.random((10, 10)), 0, 0)
+    sink.addTile(np.random.random((10, 10, 2)), 10, 0)
+
+    metadata = sink.getMetadata()
+    assert metadata.get('levels') == 1
+    assert metadata.get('sizeX') == 20
+    assert metadata.get('sizeY') == 10
+    assert metadata.get('bandCount') == 2
+    assert metadata.get('dtype') == 'float64'
+
+
+def testAddTileWithMask():
+    sink = large_image_source_zarr.new()
+    tile0 = np.random.random((10, 10))
+    sink.addTile(tile0, 0, 0)
+    orig = sink.getRegion(format='numpy')[0]
+    tile1 = np.random.random((10, 10))
+    sink.addTile(tile1, 0, 0, mask=np.random.random((10, 10)) > 0.5)
+    cur = sink.getRegion(format='numpy')[0]
+    assert (tile0 == orig[:, :, 0]).all()
+    assert not (tile1 == orig[:, :, 0]).all()
+    assert not (tile0 == cur[:, :, 0]).all()
+    assert not (tile1 == cur[:, :, 0]).all()
+
+
+def testExtraAxis():
+    sink = large_image_source_zarr.new()
+    sink.addTile(np.random.random((256, 256)), 0, 0, z=1)
+    metadata = sink.getMetadata()
+    assert metadata.get('bandCount') == 1
+    assert len(metadata.get('frames')) == 2
+
+
+@pytest.mark.parametrize('file_type', FILE_TYPES)
+def testCrop(file_type, tmp_path):
+    output_file = tmp_path / f'test.{file_type}'
+    sink = large_image_source_zarr.new()
+
+    # add tiles with some overlap
+    sink.addTile(np.random.random((10, 10)), 0, 0)
+    sink.addTile(np.random.random((10, 10)), 8, 0)
+    sink.addTile(np.random.random((10, 10)), 0, 8)
+    sink.addTile(np.random.random((10, 10)), 8, 8)
+
+    region, _ = sink.getRegion(format='numpy')
+    shape = region.shape
+    assert shape == (18, 18, 1)
+
+    sink.crop = (2, 2, 10, 10)
+
+    # crop only applies when using write
+    sink.write(output_file)
+    if file_type == 'zarr':
+        output_file /= '.zattrs'
+    written = large_image.open(output_file)
+    region, _ = written.getRegion(format='numpy')
+    shape = region.shape
+    assert shape == (10, 10, 1)
+
+
+@pytest.mark.parametrize('file_type', FILE_TYPES)
+def testImageCopySmall(file_type, tmp_path):
+    output_file = tmp_path / f'test.{file_type}'
+    sink = large_image_source_zarr.new()
+    source = large_image_source_test.TestTileSource(
+        fractal=True,
+        tileWidth=128,
+        tileHeight=128,
+        sizeX=512,
+        sizeY=1024,
+        frames='c=2,z=3',
+    )
+    copyFromSource(source, sink)
+
+    metadata = sink.getMetadata()
+    assert metadata.get('sizeX') == 512
+    assert metadata.get('sizeY') == 1024
+    assert metadata.get('dtype') == 'uint8'
+    assert metadata.get('levels') == 2
+    assert metadata.get('bandCount') == 3
+    assert len(metadata.get('frames')) == 6
+
+    # TODO: fix these failures; unexpected metadata when reading it back
+    sink.write(output_file)
+    if file_type == 'zarr':
+        output_file /= '.zattrs'
+    written = large_image.open(output_file)
+    new_metadata = written.metadata
+
+    assert new_metadata.get('sizeX') == 512
+    assert new_metadata.get('sizeY') == 1024
+    assert new_metadata.get('dtype') == 'uint8'
+    assert new_metadata.get('levels') == 2 or new_metadata.get('levels') == 3
+    assert new_metadata.get('bandCount') == 3
+    assert len(new_metadata.get('frames')) == 6
+
+
+@pytest.mark.parametrize('file_type', FILE_TYPES)
+def testImageCopySmallMultiband(file_type, tmp_path):
+    output_file = tmp_path / f'test.{file_type}'
+    sink = large_image_source_zarr.new()
+    bands = (
+        'red=400-12000,green=0-65535,blue=800-4000,'
+        'ir1=200-24000,ir2=200-22000,gray=100-10000,other=0-65535'
     )
-    tile = np.random.rand(*tile_shape)
-    tile *= (data_range[1] - data_range[0])
-    tile += data_range[0]
-    tile = tile.astype(data_range[3])  # apply dtype
-    mask = np.random.randint(2, size=tile_shape[:-1])
-    return (tile, mask)
-
-
-def frame_with_zeros(data, desired_size, start_location=None):
-    if len(desired_size) == 0:
-        return data
-    if not start_location or len(start_location) == 0:
-        start_location = [0]
-    framed = [
-        frame_with_zeros(
-            data[x - start_location[0]],
-            desired_size[1:],
-            start_location=start_location[1:],
-        )
-        if (  # frame with zeros if x>=start and x<end
-            x >= start_location[0] and
-            x < data.shape[0] + start_location[0]
-        )  # fill with zeros otherwise
-        else np.zeros(desired_size[1:])
-        for x in range(desired_size[0])
-    ]
-    return np.array(framed)
-
-
-@pytest.mark.parametrize('data_range', possible_data_ranges)
-def testImageGeneration(data_range):
-    source = large_image.new()
-    tile_grid = [
-        int(random.randint(*possible_axes['x'])),
-        int(random.randint(*possible_axes['y'])),
-    ]
-    if data_range is None:
-        data_range = random.choice(possible_data_ranges)
-
-    # create comparison matrix at max size and fill with zeros
-    expected_shape = get_dims(
-        tile_grid[1] * max_tile_size, tile_grid[0] * max_tile_size, 4, True,
+    source = large_image_source_test.TestTileSource(
+        fractal=True,
+        tileWidth=128,
+        tileHeight=128,
+        sizeX=512,
+        sizeY=1024,
+        frames='c=2,z=3',
+        bands=bands,
     )
-    expected = np.ndarray(expected_shape)
-    expected.fill(0)
-    max_x, max_y = 0, 0
-
-    print(
-        f'placing {tile_grid[0] * tile_grid[1]} random tiles in available space: {expected_shape}')
-    print('tile overlap ratio:', tile_overlap_ratio)
-    print('data range:', data_range)
-    for x in range(tile_grid[0]):
-        for y in range(tile_grid[1]):
-            start_location = [
-                int(x * max_tile_size * tile_overlap_ratio),
-                int(y * max_tile_size * tile_overlap_ratio),
-            ]
-            tile, mask = random_tile(data_range)
-            tile_shape = tile.shape
-            source.addTile(tile, *start_location, mask=mask)
-            max_x = max(max_x, start_location[1] + tile_shape[0])
-            max_y = max(max_y, start_location[0] + tile_shape[1])
-
-            framed_tile = np.array(frame_with_zeros(
-                tile,
-                expected.shape,
-                start_location=start_location[::-1],
-            ))
-            framed_mask = np.array(frame_with_zeros(
-                mask.repeat(tile_shape[-1], -1).reshape(tile_shape),
-                expected.shape,
-                start_location=start_location[::-1],
-            ))
-
-            np.putmask(expected, framed_mask, framed_tile)
-
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        # TODO: make destination use mdf5 extension
-        destination = pathlib.Path(tmp_dir, 'sample.tiff')
-        source.write(destination, lossy=False)
-        result, _ = source.getRegion(format='numpy')
-
-        # trim unused space from expected
-        expected = expected[:max_x, :max_y]
-
-        # round to specified precision
-        precision_vector = np.vectorize(signif)
-        expected = precision_vector(expected, data_range[0], data_range[1], data_range[2])
-        result = precision_vector(result, data_range[0], data_range[1], data_range[2])
-
-        # ignore alpha values for now
-        expected = expected.take(indices=range(-1), axis=-1)
-        result = result.take(indices=range(-1), axis=-1)
-
-        # For debugging
-        # difference = numpy.subtract(result, expected)
-        # print(difference)
-        # print(expected[numpy.nonzero(difference)])
-        # print(result[numpy.nonzero(difference)])
-
-        assert np.array_equal(result, expected)
-        # resultFromFile, _ = large_image.open(destination).getRegion(format='numpy')
-        # print(resultFromFile.shape, result.shape)
-        # assert numpy.array_equal(result, resultFromFile)
-        print(f'Success; result matrix {result.shape} equals expected matrix {expected.shape}.')
-
-
-if __name__ == '__main__':
-    testImageGeneration(None)
+    copyFromSource(source, sink)
+
+    metadata = sink.getMetadata()
+    assert metadata.get('sizeX') == 512
+    assert metadata.get('sizeY') == 1024
+    assert metadata.get('dtype') == 'uint16'
+    assert metadata.get('levels') == 2
+    assert metadata.get('bandCount') == 7
+    assert len(metadata.get('frames')) == 6
+
+    # TODO: fix these failures; unexpected metadata when reading it back
+    sink.write(output_file)
+    if file_type == 'zarr':
+        output_file /= '.zattrs'
+    written = large_image.open(output_file)
+    new_metadata = written.getMetadata()
+
+    assert new_metadata.get('sizeX') == 512
+    assert new_metadata.get('sizeY') == 1024
+    assert new_metadata.get('dtype') == 'uint16'
+    assert new_metadata.get('levels') == 2 or new_metadata.get('levels') == 3
+    assert new_metadata.get('bandCount') == 7
+    assert len(new_metadata.get('frames')) == 6