diff --git a/.gitignore b/.gitignore index ca536ae..008c4a2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.zarr.json +*.nwb .coverage diff --git a/README.md b/README.md index 1f9ffbe..879608d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ LINDI features include: - A specification for representing arbitrary HDF5 files as Zarr stores. This handles scalar datasets, references, soft links, and compound data types for datasets. - A Zarr wrapper for remote or local HDF5 files (LindiH5ZarrStore). This involves pointers to remote files for remote data chunks. - A function for generating a reference file system .zarr.json file from a Zarr store. This is inspired by [kerchunk](https://github.com/fsspec/kerchunk). -- An h5py-like interface for accessing these Zarr stores that can be used with [pynwb](https://pynwb.readthedocs.io/en/stable/). +- An h5py-like interface for accessing these Zarr stores that can be used with [pynwb](https://pynwb.readthedocs.io/en/stable/). Both read and write operations are supported. This project was inspired by [kerchunk](https://github.com/fsspec/kerchunk) and [hdmf-zarr](https://hdmf-zarr.readthedocs.io/en/latest/index.html) and depends on [zarr](https://zarr.readthedocs.io/en/stable/), [h5py](https://www.h5py.org/), [remfile](https://github.com/magland/remfile) and [numcodecs](https://numcodecs.readthedocs.io/en/stable/). diff --git a/devel/demonstrate_slow_get_chunk_info.py b/devel/demonstrate_slow_get_chunk_info.py index 5ce9449..dc9e090 100644 --- a/devel/demonstrate_slow_get_chunk_info.py +++ b/devel/demonstrate_slow_get_chunk_info.py @@ -22,7 +22,9 @@ def demonstrate_slow_get_chunk_info(): print(f"shape: {shape}") # (128000, 212, 322, 2) print(f"chunk_shape: {chunk_shape}") # (3, 53, 81, 1) chunk_coord_shape = [ - (shape[i] + chunk_shape[i] - 1) // chunk_shape[i] for i in range(len(shape)) + # the shape could be zero -- for example dandiset 000559 - acquisition/depth_video/data has shape [0, 0, 0] + (shape[i] + chunk_shape[i] - 1) // chunk_shape[i] if chunk_shape[i] != 0 else 0 + for i in range(len(shape)) ] print(f"chunk_coord_shape: {chunk_coord_shape}") # [42667, 4, 4, 2] num_chunks = np.prod(chunk_coord_shape) diff --git a/devel/test_write_nwb.py b/devel/test_write_nwb.py new file mode 100644 index 0000000..91f6ac9 --- /dev/null +++ b/devel/test_write_nwb.py @@ -0,0 +1,110 @@ +from typing import Any + +from datetime import datetime +from uuid import uuid4 +import numpy as np +from dateutil.tz import tzlocal +from pynwb import NWBHDF5IO, NWBFile, H5DataIO +from pynwb.ecephys import LFP, ElectricalSeries +import zarr +import lindi + +nwbfile: Any = NWBFile( + session_description="my first synthetic recording", + identifier=str(uuid4()), + session_start_time=datetime.now(tzlocal()), + experimenter=[ + "Baggins, Bilbo", + ], + lab="Bag End Laboratory", + institution="University of Middle Earth at the Shire", + experiment_description="I went on an adventure to reclaim vast treasures.", + session_id="LONELYMTN001", +) + +device = nwbfile.create_device( + name="array", description="the best array", manufacturer="Probe Company 9000" +) + +nwbfile.add_electrode_column(name="label", description="label of electrode") + +nshanks = 4 +nchannels_per_shank = 3 +electrode_counter = 0 + +for ishank in range(nshanks): + # create an electrode group for this shank + electrode_group = nwbfile.create_electrode_group( + name="shank{}".format(ishank), + description="electrode group for shank {}".format(ishank), + device=device, + location="brain area", + ) + 
# add electrodes to the electrode table + for ielec in range(nchannels_per_shank): + nwbfile.add_electrode( + group=electrode_group, + label="shank{}elec{}".format(ishank, ielec), + location="brain area", + ) + electrode_counter += 1 + +all_table_region = nwbfile.create_electrode_table_region( + region=list(range(electrode_counter)), # reference row indices 0 to N-1 + description="all electrodes", +) + +raw_data = np.random.randn(300000, 100) +raw_electrical_series = ElectricalSeries( + name="ElectricalSeries", + data=H5DataIO(data=raw_data, chunks=(100000, 100)), # type: ignore + electrodes=all_table_region, + starting_time=0.0, # timestamp of the first sample in seconds relative to the session start time + rate=20000.0, # in Hz +) + +nwbfile.add_acquisition(raw_electrical_series) + +lfp_data = np.random.randn(50, 12) +lfp_electrical_series = ElectricalSeries( + name="ElectricalSeries", + data=lfp_data, + electrodes=all_table_region, + starting_time=0.0, + rate=200.0, +) + +lfp = LFP(electrical_series=lfp_electrical_series) + +ecephys_module = nwbfile.create_processing_module( + name="ecephys", description="processed extracellular electrophysiology data" +) +ecephys_module.add(lfp) + +nwbfile.add_unit_column(name="quality", description="sorting quality") + +firing_rate = 20 +n_units = 10 +res = 1000 +duration = 20 +for n_units_per_shank in range(n_units): + spike_times = ( + np.where(np.random.rand((res * duration)) < (firing_rate / res))[0] / res + ) + nwbfile.add_unit(spike_times=spike_times, quality="good") + +# with tempfile.TemporaryDirectory() as tmpdir: +tmpdir = '.' +dirname = f'{tmpdir}/test.nwb' +store = zarr.DirectoryStore(dirname) +# create a top-level group +root = zarr.group(store=store, overwrite=True) +client = lindi.LindiH5pyFile.from_zarr_store(store, mode='r+') +with NWBHDF5IO(file=client, mode='w') as io: + io.write(nwbfile) # type: ignore + +store2 = zarr.DirectoryStore(dirname) +client2 = lindi.LindiH5pyFile.from_zarr_store(store2, mode='r') +with NWBHDF5IO(file=client2, mode='r') as io: + nwbfile2 = io.read() # type: ignore + print(nwbfile2) diff --git a/docs/special_zarr_annotations.md b/docs/special_zarr_annotations.md index 7f64dc3..e9487b5 100644 --- a/docs/special_zarr_annotations.md +++ b/docs/special_zarr_annotations.md @@ -44,7 +44,7 @@ HDF5 references can appear within both attributes and datasets. For attributes, ### `_COMPOUND_DTYPE: [['x', 'int32'], ['y', 'float64'], ...]` -Zarr arrays can represent compound data types from HDF5 datasets. The `_COMPOUND_DTYPE` attribute on a Zarr array indicates this, listing each field's name and data type. The array data should be JSON encoded, aligning with the specified compound structure. The `h5py.Reference` type is also supported within these structures, enabling references within compound data types. +Zarr arrays can represent compound data types from HDF5 datasets. The `_COMPOUND_DTYPE` attribute on a Zarr array indicates this, listing each field's name and data type. The array data should be JSON encoded, aligning with the specified compound structure. The `h5py.Reference` type is also supported within these structures (represented by the type string ''). 
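For illustration only (not part of this change set), a compound dataset annotated this way could be read back through the h5py-like interface roughly as follows; the file name and dataset path are hypothetical:

```python
import lindi

# Sketch: the Zarr array backing the dataset would carry an attribute like
#   {"_COMPOUND_DTYPE": [["x", "int32"], ["y", "float64"]]}
# and its data is stored JSON-encoded, one row per compound element.
client = lindi.LindiH5pyFile.from_reference_file_system("example.zarr.json")
dataset = client["path/to/compound_dataset"]  # hypothetical path
print(dataset[:])  # rows are decoded according to _COMPOUND_DTYPE
```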
## External Array Links diff --git a/examples/example_create_zarr_nwb.py b/examples/example_create_zarr_nwb.py new file mode 100644 index 0000000..eb2dd1e --- /dev/null +++ b/examples/example_create_zarr_nwb.py @@ -0,0 +1,121 @@ +from typing import Any +import shutil +import os +import zarr +import pynwb +import lindi + + +def example_create_zarr_nwb(): + zarr_dirname = 'example_nwb.zarr' + if os.path.exists(zarr_dirname): + shutil.rmtree(zarr_dirname) + + nwbfile = _create_sample_nwb_file() + + store = zarr.DirectoryStore(zarr_dirname) + zarr.group(store=store) # create a root group + with lindi.LindiH5pyFile.from_zarr_store(store, mode='r+') as client: + with pynwb.NWBHDF5IO(file=client, mode='r+') as io: + io.write(nwbfile) # type: ignore + + +def _create_sample_nwb_file(): + from datetime import datetime + from uuid import uuid4 + + import numpy as np + from dateutil.tz import tzlocal + + from pynwb import NWBFile + from pynwb.ecephys import LFP, ElectricalSeries + + nwbfile: Any = NWBFile( + session_description="my first synthetic recording", + identifier=str(uuid4()), + session_start_time=datetime.now(tzlocal()), + experimenter=[ + "Baggins, Bilbo", + ], + lab="Bag End Laboratory", + institution="University of Middle Earth at the Shire", + experiment_description="I went on an adventure to reclaim vast treasures.", + session_id="LONELYMTN001", + ) + + device = nwbfile.create_device( + name="array", description="the best array", manufacturer="Probe Company 9000" + ) + + nwbfile.add_electrode_column(name="label", description="label of electrode") + + nshanks = 4 + nchannels_per_shank = 3 + electrode_counter = 0 + + for ishank in range(nshanks): + # create an electrode group for this shank + electrode_group = nwbfile.create_electrode_group( + name="shank{}".format(ishank), + description="electrode group for shank {}".format(ishank), + device=device, + location="brain area", + ) + # add electrodes to the electrode table + for ielec in range(nchannels_per_shank): + nwbfile.add_electrode( + group=electrode_group, + label="shank{}elec{}".format(ishank, ielec), + location="brain area", + ) + electrode_counter += 1 + + all_table_region = nwbfile.create_electrode_table_region( + region=list(range(electrode_counter)), # reference row indices 0 to N-1 + description="all electrodes", + ) + + raw_data = np.random.randn(50, 12) + raw_electrical_series = ElectricalSeries( + name="ElectricalSeries", + data=raw_data, + electrodes=all_table_region, + starting_time=0.0, # timestamp of the first sample in seconds relative to the session start time + rate=20000.0, # in Hz + ) + + nwbfile.add_acquisition(raw_electrical_series) + + lfp_data = np.random.randn(50, 12) + lfp_electrical_series = ElectricalSeries( + name="ElectricalSeries", + data=lfp_data, + electrodes=all_table_region, + starting_time=0.0, + rate=200.0, + ) + + lfp = LFP(electrical_series=lfp_electrical_series) + + ecephys_module = nwbfile.create_processing_module( + name="ecephys", description="processed extracellular electrophysiology data" + ) + ecephys_module.add(lfp) + + nwbfile.add_unit_column(name="quality", description="sorting quality") + + firing_rate = 20 + n_units = 10 + res = 1000 + duration = 20 + for n_units_per_shank in range(n_units): + spike_times = ( + np.where(np.random.rand((res * duration)) < (firing_rate / res))[0] / res + ) + nwbfile.add_unit(spike_times=spike_times, quality="good") + + return nwbfile + + +if __name__ == '__main__': + example_create_zarr_nwb() diff --git a/examples/example_edit_nwb.py 
b/examples/example_edit_nwb.py new file mode 100644 index 0000000..735def0 --- /dev/null +++ b/examples/example_edit_nwb.py @@ -0,0 +1,32 @@ +import lindi +import h5py +import pynwb + + +# Define the URL for a remote .zarr.json file +url = 'https://kerchunk.neurosift.org/dandi/dandisets/000939/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/zarr.json' + +# Load the h5py-like client from the reference file system +client = lindi.LindiH5pyFile.from_reference_file_system(url, mode='r+') + +# modify the age of the subject +subject = client['general']['subject'] # type: ignore +assert isinstance(subject, h5py.Group) +del subject['age'] # type: ignore +subject.create_dataset('age', data=b'3w') + +# Create a new reference file system +rfs_new = client.to_reference_file_system() + +# Optionally write to a file +# import json +# with open('new.zarr.json', 'w') as f: +# json.dump(rfs_new, f) + +# Load a new h5py-like client from the new reference file system +client_new = lindi.LindiH5pyFile.from_reference_file_system(rfs_new) + +# Open using pynwb and verify that the subject age has been updated +with pynwb.NWBHDF5IO(file=client, mode="r") as io: + nwbfile = io.read() + print(nwbfile) diff --git a/lindi/LindiH5ZarrStore/FloatJsonEncoder.py b/lindi/LindiH5ZarrStore/FloatJsonEncoder.py deleted file mode 100644 index 0ee030b..0000000 --- a/lindi/LindiH5ZarrStore/FloatJsonEncoder.py +++ /dev/null @@ -1,41 +0,0 @@ -import json -import numpy as np - - -# From https://github.com/rly/h5tojson/blob/b162ff7f61160a48f1dc0026acb09adafdb422fa/h5tojson/h5tojson.py#L121-L156 -class FloatJSONEncoder(json.JSONEncoder): - """JSON encoder that converts NaN, Inf, and -Inf to strings.""" - - def encode(self, obj, *args, **kwargs): # type: ignore - """Convert NaN, Inf, and -Inf to strings.""" - obj = FloatJSONEncoder._convert_nan(obj) - return super().encode(obj, *args, **kwargs) - - def iterencode(self, obj, *args, **kwargs): # type: ignore - """Convert NaN, Inf, and -Inf to strings.""" - obj = FloatJSONEncoder._convert_nan(obj) - return super().iterencode(obj, *args, **kwargs) - - @staticmethod - def _convert_nan(obj): - """Convert NaN, Inf, and -Inf from a JSON object to strings.""" - if isinstance(obj, dict): - return {k: FloatJSONEncoder._convert_nan(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [FloatJSONEncoder._convert_nan(v) for v in obj] - elif isinstance(obj, float): - return FloatJSONEncoder._nan_to_string(obj) - return obj - - @staticmethod - def _nan_to_string(obj: float): - """Convert NaN, Inf, and -Inf from a float to a string.""" - if np.isnan(obj): - return "NaN" - elif np.isinf(obj): - if obj > 0: - return "Infinity" - else: - return "-Infinity" - else: - return float(obj) diff --git a/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py b/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py index 497a763..d1d1a20 100644 --- a/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py +++ b/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py @@ -7,14 +7,17 @@ import remfile from zarr.storage import Store, MemoryStore import h5py -from ._zarr_info_for_h5_dataset import _zarr_info_for_h5_dataset from ._util import ( _read_bytes, _get_chunk_byte_range, _get_byte_range_for_contiguous_dataset, + _join, + _get_chunk_names_for_dataset ) -from ._h5_attr_to_zarr_attr import _h5_attr_to_zarr_attr -from ._utils import _join, _get_chunk_names_for_dataset, _reformat_json +from ..conversion.attr_conversion import h5_to_zarr_attr +from ..conversion.reformat_json import reformat_json +from ..conversion.h5_filters_to_codecs import 
h5_filters_to_codecs +from ..conversion.create_zarr_dataset_from_h5_data import create_zarr_dataset_from_h5_data @dataclass @@ -66,9 +69,8 @@ def __init__( # Some datasets do not correspond to traditional chunked datasets. For # those datasets, we need to store the inline data so that we can return - # it when the chunk is requested. We store the inline data in a - # dictionary with the dataset name as the key. The values are the bytes. - self._inline_data_for_arrays: Dict[str, bytes] = {} + # it when the chunk is requested. + self._inline_arrays: Dict[str, InlineArray] = {} self._external_array_links: Dict[str, Union[dict, None]] = {} @@ -136,6 +138,12 @@ def __getitem__(self, key): # Otherwise, we assume it is a chunk file return self._get_chunk_file_bytes(key_parent=key_parent, key_name=key_name) + def get(self, key, default=None): + try: + return self.__getitem__(key) + except KeyError: + return default + def __contains__(self, key): """Check if a key is in the store (used by zarr).""" # it would be nice if we didn't have to repeat the logic from __getitem__ @@ -175,16 +183,23 @@ def __contains__(self, key): if external_array_link is not None: # The chunk files do not exist for external array links return False + if np.prod(h5_item.shape) == 0: + return False if h5_item.ndim == 0: return key_name == "0" chunk_name_parts = key_name.split(".") if len(chunk_name_parts) != h5_item.ndim: return False - shape = h5_item.shape - chunks = h5_item.chunks or shape - chunk_coords_shape = [ - (shape[i] + chunks[i] - 1) // chunks[i] for i in range(len(shape)) - ] + inline_array = self._get_inline_array(key, h5_item) + if inline_array.is_inline: + chunk_coords_shape = (1,) * h5_item.ndim + else: + shape = h5_item.shape + chunks = h5_item.chunks or shape + chunk_coords_shape = [ + (shape[i] + chunks[i] - 1) // chunks[i] if chunks[i] != 0 else 0 + for i in range(len(shape)) + ] chunk_coords = tuple(int(x) for x in chunk_name_parts) for i, c in enumerate(chunk_coords): if c < 0 or c >= chunk_coords_shape[i]: @@ -222,7 +237,7 @@ def _get_zattrs_bytes(self, parent_key: str): # if it's a soft link, we return a special attribute and ignore # the rest of the attributes because they should be stored in # the target of the soft link - return _reformat_json(json.dumps({ + return reformat_json(json.dumps({ "_SOFT_LINK": { "path": link.path } @@ -234,28 +249,18 @@ def _get_zattrs_bytes(self, parent_key: str): memory_store = MemoryStore() dummy_group = zarr.group(store=memory_store) for k, v in h5_item.attrs.items(): - v2 = _h5_attr_to_zarr_attr(v, label=f"{parent_key} {k}", h5f=self._h5f) + v2 = h5_to_zarr_attr(v, label=f"{parent_key} {k}", h5f=self._h5f) if v2 is not None: dummy_group.attrs[k] = v2 if isinstance(h5_item, h5py.Dataset): - if h5_item.ndim == 0: - dummy_group.attrs["_SCALAR"] = True - if h5_item.dtype.kind == "V": # compound type - compound_dtype = [ - [name, str(h5_item.dtype[name])] - for name in h5_item.dtype.names - ] - # For example: [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']] - dummy_group.attrs["_COMPOUND_DTYPE"] = compound_dtype + inline_array = self._get_inline_array(parent_key, h5_item) + for k, v in inline_array.additional_zarr_attrs.items(): + dummy_group.attrs[k] = v external_array_link = self._get_external_array_link(parent_key, h5_item) if external_array_link is not None: dummy_group.attrs["_EXTERNAL_ARRAY_LINK"] = external_array_link - zattrs_content = _reformat_json(memory_store.get(".zattrs")) - if zattrs_content is not None: - return zattrs_content - else: - # No 
attributes, so we return an empty JSON object - return "{}".encode("utf-8") + zattrs_content = reformat_json(memory_store.get(".zattrs") or "{}".encode("utf-8")) + return zattrs_content def _get_zgroup_bytes(self, parent_key: str): """Get the .zgroup JSON text for a group""" @@ -268,7 +273,13 @@ def _get_zgroup_bytes(self, parent_key: str): # from it. memory_store = MemoryStore() zarr.group(store=memory_store) - return _reformat_json(memory_store.get(".zgroup")) + return reformat_json(memory_store.get(".zgroup")) + + def _get_inline_array(self, key: str, h5_dataset: h5py.Dataset): + if key in self._inline_arrays: + return self._inline_arrays[key] + self._inline_arrays[key] = InlineArray(h5_dataset) + return self._inline_arrays[key] def _get_zarray_bytes(self, parent_key: str): """Get the .zarray JSON text for a dataset""" @@ -278,9 +289,12 @@ def _get_zarray_bytes(self, parent_key: str): if not isinstance(h5_item, h5py.Dataset): raise Exception(f"Item {parent_key} is not a dataset") # get the shape, chunks, dtype, and filters from the h5 dataset - info = _zarr_info_for_h5_dataset(h5_item) - if info.inline_data is not None: - self._inline_data_for_arrays[parent_key] = info.inline_data + inline_array = self._get_inline_array(parent_key, h5_item) + if inline_array.is_inline: + return inline_array.zarray_bytes + + filters = h5_filters_to_codecs(h5_item) + # We create a dummy zarr dataset with the appropriate shape, chunks, # dtype, and filters and then copy the .zarray JSON text from it memory_store = MemoryStore() @@ -290,16 +304,18 @@ def _get_zarray_bytes(self, parent_key: str): # to get the .zarray JSON text from the dummy group. dummy_group.create_dataset( name="dummy_array", - shape=info.shape, - chunks=info.chunks, - dtype=info.dtype, + shape=h5_item.shape, + # It's important to not have chunks be None here because that would + # let zarr choose an optimal chunking, whereas we need this to reflect + # the actual chunking in the HDF5 file. + chunks=h5_item.chunks if h5_item.chunks is not None else h5_item.shape, + dtype=h5_item.dtype, compressor=None, order="C", - fill_value=info.fill_value, - filters=info.filters, - object_codec=info.object_codec, + fill_value=h5_item.fillvalue, + filters=filters ) - zarray_text = _reformat_json(memory_store.get("dummy_array/.zarray")) + zarray_text = reformat_json(memory_store.get("dummy_array/.zarray")) return zarray_text @@ -341,18 +357,19 @@ def _get_chunk_file_bytes_data(self, key_parent: str, key_name: str): f"Chunk name {key_name} does not match dataset dimensions" ) - # Check whether we have inline data for this array. This would be set - # when we created the .zarray JSON text for the dataset. Note that this - # means that the .zarray file needs to be read before the chunk files, - # which should always be the case (I assume). - if key_parent in self._inline_data_for_arrays: - x = self._inline_data_for_arrays[key_parent] - if isinstance(x, bytes): - return None, None, x - else: + # In the case of shape 0, we raise an exception because we shouldn't be here + if np.prod(h5_item.shape) == 0: + raise Exception( + f"Chunk file {key_parent}/{key_name} is not present because the dataset has shape 0." + ) + + inline_array = self._get_inline_array(key_parent, h5_item) + if inline_array.is_inline: + if key_name != inline_array.chunk_fname: raise Exception( - f"Inline data for dataset {key_parent} is not bytes. 
It is {type(x)}" + f"Chunk name {key_name} does not match dataset dimensions for inline array {key_parent}" ) + return None, None, inline_array.chunk_bytes # If this is a scalar, then the data should have been inline if h5_item.ndim == 0: @@ -366,7 +383,7 @@ def _get_chunk_file_bytes_data(self, key_parent: str, key_name: str): for i, c in enumerate(chunk_coords): if c < 0 or c >= h5_item.shape[i]: raise Exception( - f"Chunk coordinates {chunk_coords} out of range for dataset {key_parent}" + f"Chunk coordinates {chunk_coords} out of range for dataset {key_parent} with dtype {h5_item.dtype}" ) if h5_item.chunks is not None: # Get the byte range in the file for the chunk. @@ -376,7 +393,7 @@ def _get_chunk_file_bytes_data(self, key_parent: str, key_name: str): # coordinates are (0, 0, 0, ...) if chunk_coords != (0,) * h5_item.ndim: raise Exception( - f"Chunk coordinates {chunk_coords} are not (0, 0, 0, ...) for contiguous dataset {key_parent}" + f"Chunk coordinates {chunk_coords} are not (0, 0, 0, ...) for contiguous dataset {key_parent} with dtype {h5_item.dtype} and shape {h5_item.shape}" ) # Get the byte range in the file for the contiguous dataset byte_offset, byte_count = _get_byte_range_for_contiguous_dataset(h5_item) @@ -395,7 +412,8 @@ def _get_external_array_link(self, parent_key: str, h5_item: h5py.Dataset): shape = h5_item.shape chunks = h5_item.chunks chunk_coords_shape = [ - (shape[i] + chunks[i] - 1) // chunks[i] for i in range(len(shape)) + (shape[i] + chunks[i] - 1) // chunks[i] if chunks[i] != 0 else 0 + for i in range(len(shape)) ] num_chunks = np.prod(chunk_coords_shape) if num_chunks > self._opts.num_dataset_chunks_threshold: @@ -509,8 +527,8 @@ def _process_group(key, item: h5py.Group): _process_dataset(_join(key, k)) def _process_dataset(key): - # Add the .zattrs and .zarray files for the dataset - zattrs_bytes = self.get(f"{key}/.zattrs") + # Add the .zattrs and .zarray files for the dataset= + zattrs_bytes = self[f"{key}/.zattrs"] assert zattrs_bytes is not None if zattrs_bytes != b"{}": # don't include empty zattrs _add_ref(f"{key}/.zattrs", zattrs_bytes) @@ -528,7 +546,8 @@ def _process_dataset(key): shape = zarray_dict["shape"] chunks = zarray_dict.get("chunks", None) chunk_coords_shape = [ - (shape[i] + chunks[i] - 1) // chunks[i] + # the shape could be zero -- for example dandiset 000559 - acquisition/depth_video/data has shape [0, 0, 0] + (shape[i] + chunks[i] - 1) // chunks[i] if chunks[i] != 0 else 0 for i in range(len(shape)) ] # For example, chunk_names could be ['0', '1', '2', ...] @@ -556,3 +575,78 @@ def _process_dataset(key): # Process the groups recursively starting with the root group _process_group("", self._h5f) return ret + + +class InlineArray: + def __init__(self, h5_dataset: h5py.Dataset): + self._additional_zarr_attributes = {} + if h5_dataset.shape == (): + self._additional_zarr_attributes["_SCALAR"] = True + self._is_inline = True + ... 
+ elif h5_dataset.dtype.kind in ['i', 'u', 'f']: # integer or float + self._is_inline = False + else: + self._is_inline = True + if h5_dataset.dtype.kind == "V" and h5_dataset.dtype.fields is not None: # compound type + compound_dtype = [] + for name in h5_dataset.dtype.names: + tt = h5_dataset.dtype[name] + if tt == h5py.special_dtype(ref=h5py.Reference): + tt = "" + compound_dtype.append((name, str(tt))) + # For example: [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']] + self._additional_zarr_attributes["_COMPOUND_DTYPE"] = compound_dtype + if self._is_inline: + memory_store = MemoryStore() + dummy_group = zarr.group(store=memory_store) + size_is_zero = np.prod(h5_dataset.shape) == 0 + create_zarr_dataset_from_h5_data( + zarr_parent_group=dummy_group, + name='X', + # For inline data it's important for now that we enforce a + # single chunk because the rest of the code assumes a single + # chunk for inline data. The assumption is that the inline + # arrays are not going to be very large. + h5_chunks=h5_dataset.shape if h5_dataset.shape != () and not size_is_zero else None, + label=f'{h5_dataset.name}', + h5_shape=h5_dataset.shape, + h5_dtype=h5_dataset.dtype, + h5f=h5_dataset.file, + h5_data=h5_dataset[...] + ) + self._zarray_bytes = reformat_json(memory_store['X/.zarray']) + if not size_is_zero: + if h5_dataset.ndim == 0: + chunk_fname = '0' + else: + chunk_fname = '.'.join(['0'] * h5_dataset.ndim) + self._chunk_fname = chunk_fname + self._chunk_bytes = memory_store[f'X/{chunk_fname}'] + else: + self._chunk_fname = None + self._chunk_bytes = None + else: + self._zarray_bytes = None + self._chunk_fname = None + self._chunk_bytes = None + + @property + def is_inline(self): + return self._is_inline + + @property + def additional_zarr_attrs(self): + return self._additional_zarr_attributes + + @property + def zarray_bytes(self): + return self._zarray_bytes + + @property + def chunk_fname(self): + return self._chunk_fname + + @property + def chunk_bytes(self): + return self._chunk_bytes diff --git a/lindi/LindiH5ZarrStore/_h5_attr_to_zarr_attr.py b/lindi/LindiH5ZarrStore/_h5_attr_to_zarr_attr.py deleted file mode 100644 index 054b290..0000000 --- a/lindi/LindiH5ZarrStore/_h5_attr_to_zarr_attr.py +++ /dev/null @@ -1,112 +0,0 @@ -from typing import Any -import numpy as np -import h5py - - -def _h5_attr_to_zarr_attr(attr: Any, *, label: str = '', h5f: h5py.File): - """Convert an attribute from h5py to a format that zarr can accept. 
- - bytes -> decoded utf-8 string - int, float, str -> unchanged - list -> recursively convert each element - dict -> recursively convert each value - h5py.Reference -> convert to a reference object, see _h5_ref_to_zarr_attr - - Otherwise, raise NotImplementedError - """ - - # first disallow special strings - special_strings = ['NaN', 'Infinity', '-Infinity'] - if isinstance(attr, str) and attr in special_strings: - raise ValueError(f"Special string {attr} not allowed in attribute value at {label}") - if isinstance(attr, bytes) and attr in [x.encode('utf-8') for x in special_strings]: - raise ValueError(f"Special string {attr} not allowed in attribute value at {label}") - - if attr is None: - return None - elif isinstance(attr, bytes): - return attr.decode('utf-8') - elif isinstance(attr, (int, float, str)): - return attr - elif np.issubdtype(type(attr), np.integer): - return int(attr) - elif np.issubdtype(type(attr), np.floating): - return float(attr) - elif np.issubdtype(type(attr), np.bool_): - return bool(attr) - elif type(attr) is np.bytes_: - return attr.tobytes().decode('utf-8') - elif isinstance(attr, h5py.Reference): - return _h5_ref_to_zarr_attr(attr, label=label + '._REFERENCE', h5f=h5f) - elif isinstance(attr, list): - return [_h5_attr_to_zarr_attr(x, label=label, h5f=h5f) for x in attr] - elif isinstance(attr, dict): - return {k: _h5_attr_to_zarr_attr(v, label=label, h5f=h5f) for k, v in attr.items()} - elif isinstance(attr, np.ndarray): - return _h5_attr_to_zarr_attr(attr.tolist(), label=label, h5f=h5f) - else: - print(f'Warning: attribute of type {type(attr)} not handled: {label}') - raise NotImplementedError() - - -def _h5_ref_to_zarr_attr(ref: h5py.Reference, *, label: str = '', h5f: h5py.File): - """Convert an h5py reference to a format that zarr can accept. - - The format is a dictionary with a single key, '_REFERENCE', whose value is - another dictionary with the following keys: - - 'object_id', 'path', 'source', 'source_object_id' - - * object_id is the object ID of the target object. - * path is the path of the target object. - * source is always '.', meaning that path is relative to the root of the - file (I think) - * source_object_id is the object ID of the source object. - - See - https://hdmf-zarr.readthedocs.io/en/latest/storage.html#storing-object-references-in-attributes - - Note that we will also need to handle "region" references. I would propose - another field in the value containing the region info. See - https://hdmf-zarr.readthedocs.io/en/latest/storage.html#sec-zarr-storage-references-region - """ - file_id = h5f.id - - # The get_name call can actually be quite slow. A possible way around this - # is to do an initial pass through the file and build a map of object IDs to - # paths. This would need to happen elsewhere in the code. - deref_objname = h5py.h5r.get_name(ref, file_id) - if deref_objname is None: - raise ValueError(f"Could not dereference object with reference {ref}") - deref_objname = deref_objname.decode("utf-8") - - dref_obj = h5f[deref_objname] - object_id = dref_obj.attrs.get("object_id", None) - - # Here we assume that the file has a top-level attribute called "object_id". - # This will be the case for files created by the LindiH5ZarrStore class. 
- file_object_id = h5f.attrs.get("object_id", None) - - # See https://hdmf-zarr.readthedocs.io/en/latest/storage.html#storing-object-references-in-attributes - value = { - "object_id": object_id, - "path": deref_objname, - "source": ".", # Are we always going to use the top-level object as the source? - "source_object_id": file_object_id, - } - - # This is how hdmf_zarr does it, but I would propose to use a _REFERENCE key - # instead. Note that we will also need to handle "region" references. I - # would propose another field in the value containing the region info. See - # https://hdmf-zarr.readthedocs.io/en/latest/storage.html#sec-zarr-storage-references-region - - # return { - # "zarr_dtype": "object", - # "value": value - # } - - # important to run it through _h5_attr_to_zarr_attr to handle object IDs of - # type bytes - return _h5_attr_to_zarr_attr({ - "_REFERENCE": value - }, label=label, h5f=h5f) diff --git a/lindi/LindiH5ZarrStore/_util.py b/lindi/LindiH5ZarrStore/_util.py index 2976dbc..0866292 100644 --- a/lindi/LindiH5ZarrStore/_util.py +++ b/lindi/LindiH5ZarrStore/_util.py @@ -1,4 +1,4 @@ -from typing import IO +from typing import IO, List import numpy as np import h5py @@ -20,7 +20,8 @@ def _get_chunk_byte_range(h5_dataset: h5py.Dataset, chunk_coords: tuple) -> tupl assert chunk_shape is not None chunk_coords_shape = [ - (shape[i] + chunk_shape[i] - 1) // chunk_shape[i] + # the shape could be zero -- for example dandiset 000559 - acquisition/depth_video/data has shape [0, 0, 0] + (shape[i] + chunk_shape[i] - 1) // chunk_shape[i] if chunk_shape[i] != 0 else 0 for i in range(len(shape)) ] ndim = h5_dataset.ndim @@ -55,3 +56,30 @@ def _get_byte_range_for_contiguous_dataset(h5_dataset: h5py.Dataset) -> tuple: byte_offset = dsid.get_offset() byte_count = dsid.get_storage_size() return byte_offset, byte_count + + +def _join(a: str, b: str) -> str: + if a == "": + return b + else: + return f"{a}/{b}" + + +def _get_chunk_names_for_dataset(chunk_coords_shape: List[int]) -> List[str]: + """Get the chunk names for a dataset with the given chunk coords shape. + + For example: _get_chunk_names_for_dataset([1, 2, 3]) returns + ['0.0.0', '0.0.1', '0.0.2', '0.1.0', '0.1.1', '0.1.2'] + """ + ndim = len(chunk_coords_shape) + if ndim == 0: + return ["0"] + elif ndim == 1: + return [str(i) for i in range(chunk_coords_shape[0])] + else: + names0 = _get_chunk_names_for_dataset(chunk_coords_shape[1:]) + names = [] + for i in range(chunk_coords_shape[0]): + for name0 in names0: + names.append(f"{i}.{name0}") + return names diff --git a/lindi/LindiH5ZarrStore/_utils.py b/lindi/LindiH5ZarrStore/_utils.py deleted file mode 100644 index 892fe52..0000000 --- a/lindi/LindiH5ZarrStore/_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List, Union -import json -from .FloatJsonEncoder import FloatJSONEncoder - - -def _join(a: str, b: str) -> str: - if a == "": - return b - else: - return f"{a}/{b}" - - -def _get_chunk_names_for_dataset(chunk_coords_shape: List[int]) -> List[str]: - """Get the chunk names for a dataset with the given chunk coords shape. 
- - For example: _get_chunk_names_for_dataset([1, 2, 3]) returns - ['0.0.0', '0.0.1', '0.0.2', '0.1.0', '0.1.1', '0.1.2'] - """ - ndim = len(chunk_coords_shape) - if ndim == 0: - return ["0"] - elif ndim == 1: - return [str(i) for i in range(chunk_coords_shape[0])] - else: - names0 = _get_chunk_names_for_dataset(chunk_coords_shape[1:]) - names = [] - for i in range(chunk_coords_shape[0]): - for name0 in names0: - names.append(f"{i}.{name0}") - return names - - -def _reformat_json(x: Union[bytes, None]) -> Union[bytes, None]: - """Reformat to not include whitespace and to encode NaN, Inf, and -Inf as strings.""" - if x is None: - return None - a = json.loads(x.decode("utf-8")) - return json.dumps(a, cls=FloatJSONEncoder, separators=(",", ":")).encode("utf-8") diff --git a/lindi/LindiH5ZarrStore/_zarr_info_for_h5_dataset.py b/lindi/LindiH5ZarrStore/_zarr_info_for_h5_dataset.py deleted file mode 100644 index de1105c..0000000 --- a/lindi/LindiH5ZarrStore/_zarr_info_for_h5_dataset.py +++ /dev/null @@ -1,215 +0,0 @@ -import json -import struct -from typing import Union, List, Any, Tuple -from dataclasses import dataclass -import numpy as np -import numcodecs -import h5py -from numcodecs.abc import Codec -from ._h5_attr_to_zarr_attr import _h5_ref_to_zarr_attr -from ._h5_filters_to_codecs import _h5_filters_to_codecs - - -@dataclass -class ZarrInfoForH5Dataset: - shape: Tuple[int] - chunks: Tuple[int] - dtype: Any - filters: Union[List[Codec], None] - fill_value: Any - object_codec: Union[None, Codec] - inline_data: Union[bytes, None] - - -def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset: - """Get the information needed to create a zarr dataset from an h5py dataset. - - This is the main workhorse function for LindiH5ZarrStore. It takes an h5py - dataset and returns a ZarrInfoForH5Dataset object. - - It handles the following cases: - - For non-scalars, if it is a numeric array, then the data can stay where it - is in the hdf5 file. The hdf5 filters are translated into zarr filters using - the _h5_filters_to_codecs function. - - If it is a non-scalar object array, then the inline_data will be a JSON - string and the JSON codec will be used. - - When the shape is (), we have a scalar dataset. Since zarr doesn't support - scalar datasets, we make an array of shape (1,). The _SCALAR attribute will - be set to True elsewhere to indicate that it is actually a scalar. The - inline_data attribute will be set. In the case of a numeric scalar, it will - be a bytes object with the binary representation of the value. In the case - of an object, the inline_data will be a JSON string and the JSON codec will - be used. 
- """ - shape = h5_dataset.shape - dtype = h5_dataset.dtype - - if len(shape) == 0: - # scalar dataset - value = h5_dataset[()] - # zarr doesn't support scalar datasets, so we make an array of shape (1,) - # and the _SCALAR attribute will be set to True elsewhere to indicate that - # it is a scalar dataset - - # Let's handle all the possible types explicitly - numeric_format_str = _get_numeric_format_str(dtype) - if numeric_format_str is not None: - # Handle the simple numeric types - inline_data = struct.pack(numeric_format_str, value) - return ZarrInfoForH5Dataset( - shape=(1,), - chunks=(1,), # be explicit about chunks - dtype=dtype, - filters=None, - fill_value=0, - object_codec=None, - inline_data=inline_data - ) - elif dtype == object: - # For type object, we are going to use the JSON codec - # which requires inline data of the form [[val], '|O', [1]] - if isinstance(value, (bytes, str)): - if isinstance(value, bytes): - value = value.decode() - return ZarrInfoForH5Dataset( - shape=(1,), - chunks=(1,), # be explicit about chunks - dtype=dtype, - filters=None, - fill_value=' ', - object_codec=numcodecs.JSON(), - inline_data=json.dumps([value, '|O', [1]], separators=(',', ':')).encode('utf-8') - ) - else: - raise Exception(f'Not yet implemented (1): object scalar dataset with value {value} and dtype {dtype}') - else: - raise Exception(f'Cannot handle scalar dataset {h5_dataset.name} with dtype {dtype}') - else: - # not a scalar dataset - if dtype.kind in ['i', 'u', 'f', 'b']: # integer, unsigned integer, float, boolean - # This is the normal case of a chunked dataset with a numeric (or boolean) dtype - filters = _h5_filters_to_codecs(h5_dataset) - chunks = h5_dataset.chunks - if chunks is None: - # If the dataset is not chunked, we use the entire dataset as a single chunk - # It's important to be explicit about the chunks, because I think None means that zarr could decide otherwise - chunks = shape - return ZarrInfoForH5Dataset( - shape=shape, - chunks=chunks, - dtype=dtype, - filters=filters, - fill_value=h5_dataset.fillvalue, - object_codec=None, - inline_data=None - ) - elif dtype.kind == 'O': - # For type object, we are going to use the JSON codec - # which requires inline data of the form [list, of, items, ..., '|O', [n1, n2, ...]] - object_codec = numcodecs.JSON() - data = h5_dataset[:] - data_vec_view = data.ravel() - for i, val in enumerate(data_vec_view): - if isinstance(val, bytes): - data_vec_view[i] = val.decode() - elif isinstance(val, str): - data_vec_view[i] = val - elif isinstance(val, h5py.Reference): - data_vec_view[i] = _h5_ref_to_zarr_attr(val, label=f'{h5_dataset.name}[{i}]', h5f=h5_dataset.file) - else: - raise Exception(f'Cannot handle dataset {h5_dataset.name} with dtype {dtype} and shape {shape}') - inline_data = json.dumps(data.tolist() + ['|O', list(shape)], separators=(',', ':')).encode('utf-8') - return ZarrInfoForH5Dataset( - shape=shape, - chunks=shape, # be explicit about chunks - dtype=dtype, - filters=None, - fill_value=' ', # not sure what to put here - object_codec=object_codec, - inline_data=inline_data - ) - elif dtype.kind in 'SU': # byte string or unicode string - raise Exception(f'Not yet implemented (2): dataset {h5_dataset.name} with dtype {dtype} and shape {shape}') - elif dtype.kind == 'V': # void (i.e. 
compound) - if h5_dataset.ndim == 1: - # for now we only handle the case of a 1D compound dataset - data = h5_dataset[:] - # Create an array that would be for example like this - # dtype = np.dtype([('x', np.float64), ('y', np.int32), ('weight', np.float64)]) - # array_list = [[3, 4, 5.3], [2, 1, 7.1], ...] - # where the first entry corresponds to x in the example above, the second to y, and the third to weight - # This is a more compact representation than [{'x': ...}] - # The _COMPOUND_DTYPE attribute will be set on the dataset in the zarr store - # which will be used to interpret the data - array_list = [ - [ - _json_serialize(data[name][i], dtype[name], h5_dataset) - for name in dtype.names - ] - for i in range(h5_dataset.shape[0]) - ] - object_codec = numcodecs.JSON() - inline_data = array_list + ['|O', list(shape)] - return ZarrInfoForH5Dataset( - shape=shape, - chunks=shape, # be explicit about chunks - dtype='object', - filters=None, - fill_value=' ', # not sure what to put here - object_codec=object_codec, - inline_data=json.dumps(inline_data, separators=(',', ':')).encode('utf-8') - ) - else: - raise Exception(f'More than one dimension not supported for compound dataset {h5_dataset.name} with dtype {dtype} and shape {shape}') - else: - print(dtype.kind) - raise Exception(f'Not yet implemented (3): dataset {h5_dataset.name} with dtype {dtype} and shape {shape}') - - -def _json_serialize(val: Any, dtype: np.dtype, h5_dataset: h5py.Dataset) -> Any: - if dtype.kind in ['i', 'u']: # integer, unsigned integer - return int(val) - elif dtype.kind == 'f': # float - return float(val) - elif dtype.kind == 'b': # boolean - return bool(val) - elif dtype.kind == 'S': # byte string - return val.decode() - elif dtype.kind == 'U': # unicode string - return val - elif dtype == h5py.Reference: - return _h5_ref_to_zarr_attr(val, label=f'{h5_dataset.name}', h5f=h5_dataset.file) - else: - raise Exception(f'Cannot serialize item {val} with dtype {dtype} when serializing dataset {h5_dataset.name} with compound dtype.') - - -def _get_numeric_format_str(dtype: Any) -> Union[str, None]: - """Get the format string for a numeric dtype. - - This is used to convert a scalar dataset to inline data using struct.pack. 
- """ - if dtype == np.int8: - return '" if it represents an HDF5 reference + for i in range(len(compound_dtype_obj)): + if compound_dtype_obj[i][1] == '': + compound_dtype_obj[i][1] = h5py.special_dtype(ref=h5py.Reference) # If we have a compound dtype, then create the numpy dtype self._compound_dtype = np.dtype( - [(compound_dtype_obj[i][0], compound_dtype_obj[i][1]) for i in range(len(compound_dtype_obj))] + [ + ( + compound_dtype_obj[i][0], + compound_dtype_obj[i][1] + ) + for i in range(len(compound_dtype_obj)) + ] ) else: self._compound_dtype = None @@ -49,6 +64,14 @@ def __init__(self, _dataset_object: Union[h5py.Dataset, zarr.Array], _file: "Lin else: self._is_scalar = self._dataset_object.ndim == 0 + # The self._write object handles all the writing operations + from .writers.LindiH5pyDatasetWriter import LindiH5pyDatasetWriter # avoid circular import + + if self._readonly: + self._writer = None + else: + self._writer = LindiH5pyDatasetWriter(self) + @property def id(self): if isinstance(self._dataset_object, h5py.Dataset): @@ -119,7 +142,25 @@ def attrs(self): # type: ignore attrs_type = 'zarr' else: raise Exception(f'Unexpected dataset object type: {type(self._dataset_object)}') - return LindiH5pyAttributes(self._dataset_object.attrs, attrs_type=attrs_type) + return LindiH5pyAttributes(self._dataset_object.attrs, attrs_type=attrs_type, readonly=self._file.mode == 'r') + + @property + def fletcher32(self): + if isinstance(self._dataset_object, h5py.Dataset): + return self._dataset_object.fletcher32 + elif isinstance(self._dataset_object, zarr.Array): + for f in self._dataset_object.filters: + if f.__class__.__name__ == 'Fletcher32': + return True + return False + else: + raise Exception(f'Unexpected dataset object type: {type(self._dataset_object)}') + + def __repr__(self): # type: ignore + return f"<{self.__class__.__name__}: {self.name}>" + + def __str__(self): + return f"<{self.__class__.__name__}: {self.name}>" def __getitem__(self, args, new_dtype=None): if isinstance(self._dataset_object, h5py.Dataset): @@ -128,7 +169,6 @@ def __getitem__(self, args, new_dtype=None): if new_dtype is not None: raise Exception("new_dtype is not supported for zarr.Array") ret = self._get_item_for_zarr(self._dataset_object, args) - ret = _resolve_references(ret) else: raise Exception(f"Unexpected type: {type(self._dataset_object)}") return ret @@ -178,7 +218,7 @@ def _get_item_for_zarr(self, zarr_array: zarr.Array, selection: Any): if selection != (): raise TypeError(f'Cannot slice a scalar dataset with {selection}') return zarr_array[0] - return zarr_array[selection] + return decode_references(zarr_array[selection]) def _get_external_hdf5_client(self, url: str) -> h5py.File: if url not in _external_hdf5_clients: @@ -189,25 +229,20 @@ def _get_external_hdf5_client(self, url: str) -> h5py.File: _external_hdf5_clients[url] = h5py.File(ff, "r") return _external_hdf5_clients[url] + @property + def ref(self): + if self._readonly: + raise ValueError("Cannot get ref on read-only object") + assert self._writer is not None + return self._writer.ref -def _resolve_references(x: Any): - if isinstance(x, dict): - # x should only be a dict when x represents a converted reference - if '_REFERENCE' in x: - return LindiH5pyReference(x['_REFERENCE']) - else: # pragma: no cover - raise Exception(f"Unexpected dict in selection: {x}") - elif isinstance(x, list): - # Replace any references in the list with the resolved ref in-place - for i, v in enumerate(x): - x[i] = _resolve_references(v) - elif isinstance(x, 
np.ndarray): - if x.dtype == object or x.dtype is None: - # Replace any references in the object array with the resolved ref in-place - view_1d = x.reshape(-1) - for i in range(len(view_1d)): - view_1d[i] = _resolve_references(view_1d[i]) - return x + ############################## + # Write + def __setitem__(self, args, val): + if self._readonly: + raise ValueError("Cannot set items on read-only object") + assert self._writer is not None + self._writer.__setitem__(args, val) class LindiH5pyDatasetCompoundFieldSelection: @@ -271,4 +306,4 @@ def size(self): return self._data.size def __getitem__(self, selection): - return self._data[selection] + return decode_references(self._data[selection]) diff --git a/lindi/LindiH5pyFile/LindiH5pyFile.py b/lindi/LindiH5pyFile/LindiH5pyFile.py index b032686..d25f2fb 100644 --- a/lindi/LindiH5pyFile/LindiH5pyFile.py +++ b/lindi/LindiH5pyFile/LindiH5pyFile.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Literal import json import tempfile import urllib.request @@ -14,19 +14,35 @@ class LindiH5pyFile(h5py.File): - def __init__(self, _file_object: Union[h5py.File, zarr.Group]): + def __init__(self, _file_object: Union[h5py.File, zarr.Group], *, _zarr_store: Union[ZarrStore, None] = None, _mode: Literal["r", "r+"] = "r"): """ Do not use this constructor directly. Instead, use: from_reference_file_system, from_zarr_store, from_zarr_group, or from_h5py_file """ self._file_object = _file_object + self._zarr_store = _zarr_store + self._mode: Literal['r', 'r+'] = _mode self._the_group = LindiH5pyGroup(_file_object, self) @staticmethod - def from_reference_file_system(rfs: Union[dict, str]): + def from_reference_file_system(rfs: Union[dict, str], mode: Literal["r", "r+"] = "r"): """ Create a LindiH5pyFile from a reference file system. + + Parameters + ---------- + rfs : Union[dict, str] + The reference file system. This can be a dictionary or a URL or path + to a .zarr.json file. + mode : Literal["r", "r+"], optional + The mode to open the file object in, by default "r". If the mode is + "r", the file object will be read-only. If the mode is "r+", the + file will be read-write. However, if the rfs is a string (URL or + path), the file itself will not be modified on changes, but the + internal in-memory representation will be modified. Use + to_reference_file_system() to export the updated reference file + system to the same file or a new file. """ if isinstance(rfs, str): if rfs.startswith("http") or rfs.startswith("https"): @@ -36,44 +52,88 @@ def from_reference_file_system(rfs: Union[dict, str]): with open(filename, "r") as f: data = json.load(f) assert isinstance(data, dict) # prevent infinite recursion - return LindiH5pyFile.from_reference_file_system(data) + return LindiH5pyFile.from_reference_file_system(data, mode=mode) else: with open(rfs, "r") as f: data = json.load(f) assert isinstance(data, dict) # prevent infinite recursion - return LindiH5pyFile.from_reference_file_system(data) + return LindiH5pyFile.from_reference_file_system(data, mode=mode) elif isinstance(rfs, dict): # This store does not need to be closed store = LindiReferenceFileSystemStore(rfs) - return LindiH5pyFile.from_zarr_store(store) + return LindiH5pyFile.from_zarr_store(store, mode=mode) else: raise Exception(f"Unhandled type for rfs: {type(rfs)}") @staticmethod - def from_zarr_store(zarr_store: ZarrStore): + def from_zarr_store(zarr_store: ZarrStore, mode: Literal["r", "r+"] = "r"): """ Create a LindiH5pyFile from a zarr store. 
+ + Parameters + ---------- + zarr_store : ZarrStore + The zarr store. + mode : Literal["r", "r+"], optional + The mode to open the file object in, by default "r". If the mode is + "r", the file object will be read-only. For write mode to work, the + zarr store will need to be writeable as well. """ # note that even though the function is called "open", the zarr_group # does not need to be closed - zarr_group = zarr.open(store=zarr_store, mode="r") + zarr_group = zarr.open(store=zarr_store, mode=mode) assert isinstance(zarr_group, zarr.Group) - return LindiH5pyFile.from_zarr_group(zarr_group) + return LindiH5pyFile.from_zarr_group(zarr_group, _zarr_store=zarr_store, mode=mode) @staticmethod - def from_zarr_group(zarr_group: zarr.Group): + def from_zarr_group(zarr_group: zarr.Group, *, mode: Literal["r", "r+"] = "r", _zarr_store: Union[ZarrStore, None] = None): """ Create a LindiH5pyFile from a zarr group. + + Parameters + ---------- + zarr_group : zarr.Group + The zarr group. + mode : Literal["r", "r+"], optional + The mode to open the file object in, by default "r". If the mode is + "r", the file object will be read-only. For write mode to work, the + zarr store will need to be writeable as well. + _zarr_store : Union[ZarrStore, None], optional + The zarr store, internally set for use with + to_reference_file_system(). + + See from_zarr_store(). """ - return LindiH5pyFile(zarr_group) + return LindiH5pyFile(zarr_group, _zarr_store=_zarr_store, _mode=mode) @staticmethod def from_h5py_file(h5py_file: h5py.File): """ Create a LindiH5pyFile from an h5py file. + + This is used mainly for testing and may be removed in the future. + + Parameters + ---------- + h5py_file : h5py.File + The h5py file. """ return LindiH5pyFile(h5py_file) + def to_reference_file_system(self): + """ + Export the internal in-memory representation to a reference file system. + In order to use this, the file object needs to have been created using + from_reference_file_system(). 
+ """ + if self._zarr_store is None: + raise Exception("Cannot convert to reference file system without zarr store") + if not isinstance(self._zarr_store, LindiReferenceFileSystemStore): + raise Exception(f"Unexpected type for zarr store: {type(self._zarr_store)}") + rfs = self._zarr_store.rfs + rfs_copy = json.loads(json.dumps(rfs)) + return rfs_copy + @property def attrs(self): # type: ignore if isinstance(self._file_object, h5py.File): @@ -82,7 +142,7 @@ def attrs(self): # type: ignore attrs_type = 'zarr' else: raise Exception(f'Unexpected file object type: {type(self._file_object)}') - return LindiH5pyAttributes(self._file_object.attrs, attrs_type=attrs_type) + return LindiH5pyAttributes(self._file_object.attrs, attrs_type=attrs_type, readonly=self.mode == "r") @property def filename(self): @@ -98,15 +158,9 @@ def filename(self): def driver(self): raise Exception("Getting driver is not allowed") - # @property - # def mode(self): - # if isinstance(self._file_object, h5py.File): - # return self._file_object.mode - # elif isinstance(self._file_object, zarr.Group): - # # hard-coded to read-only - # return "r" - # else: - # raise Exception(f"Unhandled type: {type(self._file_object)}") + @property + def mode(self): + return self._mode @property def libver(self): @@ -137,9 +191,54 @@ def __enter__(self): # type: ignore def __exit__(self, *args): self.close() + def __str__(self): + return f'' + def __repr__(self): return f'' + def __bool__(self): + # This is called when checking if the file is open + if isinstance(self._file_object, h5py.File): + return self._file_object.__bool__() + elif isinstance(self._file_object, zarr.Group): + return True + else: + raise Exception(f"Unexpected type for file object: {type(self._file_object)}") + + def __hash__(self): + # This is called for example when using a file as a key in a dictionary + if isinstance(self._file_object, h5py.File): + return self._file_object.__hash__() + else: + return id(self) + + def copy(self, source, dest, name=None, + shallow=False, expand_soft=False, expand_external=False, + expand_refs=False, without_attrs=False): + if shallow: + raise Exception("shallow is not implemented for copy") + if expand_soft: + raise Exception("expand_soft is not implemented for copy") + if expand_external: + raise Exception("expand_external is not implemented for copy") + if expand_refs: + raise Exception("expand_refs is not implemented for copy") + if without_attrs: + raise Exception("without_attrs is not implemented for copy") + if name is None: + raise Exception("name must be provided for copy") + src_item = self._get_item(source) + if not isinstance(src_item, (h5py.Group, h5py.Dataset)): + raise Exception(f"Unexpected type for source in copy: {type(src_item)}") + _recursive_copy(src_item, dest, name=name) + + def __delitem__(self, name): + parent_key = '/'.join(name.split('/')[:-1]) + grp = self[parent_key] + assert isinstance(grp, LindiH5pyGroup) + del grp[name.split('/')[-1]] + # Group methods def __getitem__(self, name): # type: ignore return self._get_item(name) @@ -209,6 +308,34 @@ def file(self): def name(self): return self._the_group.name + @property + def ref(self): + return self._the_group.ref + + ############################## + # write + def create_group(self, name, track_order=None): + if self._mode not in ['r+']: + raise Exception("Cannot create group in read-only mode") + if track_order is not None: + raise Exception("track_order is not supported (I don't know what it is)") + return self._the_group.create_group(name) + + def 
require_group(self, name): + if self._mode not in ['r+']: + raise Exception("Cannot require group in read-only mode") + return self._the_group.require_group(name) + + def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds): + if self._mode not in ['r+']: + raise Exception("Cannot create dataset in read-only mode") + return self._the_group.create_dataset(name, shape=shape, dtype=dtype, data=data, **kwds) + + def require_dataset(self, name, shape, dtype, exact=False, **kwds): + if self._mode not in ['r+']: + raise Exception("Cannot require dataset in read-only mode") + return self._the_group.require_dataset(name, shape, dtype, exact=exact, **kwds) + def _download_file(url: str, filename: str) -> None: headers = { @@ -218,3 +345,60 @@ def _download_file(url: str, filename: str) -> None: with urllib.request.urlopen(req) as response: with open(filename, "wb") as f: f.write(response.read()) + + +def _recursive_copy(src_item: Union[h5py.Group, h5py.Dataset], dest: h5py.File, name: str) -> None: + if isinstance(src_item, h5py.Group): + dst_item = dest.create_group(name) + for k, v in src_item.attrs.items(): + dst_item.attrs[k] = v + for k, v in src_item.items(): + _recursive_copy(v, dest, name=f'{name}/{k}') + elif isinstance(src_item, h5py.Dataset): + # Let's specially handle the case where the source and dest files + # are LindiH5pyFiles with reference file systems as the internal + # representation. In this case, we don't need to copy the actual + # data because we can copy the reference. + if isinstance(src_item.file, LindiH5pyFile) and isinstance(dest, LindiH5pyFile): + if src_item.name is None: + raise Exception("src_item.name is None") + src_item_name = _without_initial_slash(src_item.name) + src_zarr_store = src_item.file._zarr_store + dst_zarr_store = dest._zarr_store + if src_zarr_store is not None and dst_zarr_store is not None: + if isinstance(src_zarr_store, LindiReferenceFileSystemStore) and isinstance(dst_zarr_store, LindiReferenceFileSystemStore): + src_rfs = src_zarr_store.rfs + dst_rfs = dst_zarr_store.rfs + src_ref_keys = list(src_rfs['refs'].keys()) + for src_ref_key in src_ref_keys: + if src_ref_key.startswith(f'{src_item_name}/'): + dst_ref_key = f'{name}/{src_ref_key[len(src_item_name) + 1:]}' + # Even though it's not expected to be a problem, we + # do a deep copy here because a problem resulting + # from one rfs being modified affecting another + # would be very difficult to debug. 
+ dst_rfs['refs'][dst_ref_key] = _deep_copy(src_rfs['refs'][src_ref_key]) + return + + dst_item = dest.create_dataset(name, data=src_item[()]) + for k, v in src_item.attrs.items(): + dst_item.attrs[k] = v + else: + raise Exception(f"Unexpected type for src_item in _recursive_copy: {type(src_item)}") + + +def _without_initial_slash(s: str) -> str: + if s.startswith('/'): + return s[1:] + return s + + +def _deep_copy(obj): + if isinstance(obj, dict): + return {k: _deep_copy(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_deep_copy(v) for v in obj] + elif isinstance(obj, tuple): + return tuple(_deep_copy(v) for v in obj) + else: + return obj diff --git a/lindi/LindiH5pyFile/LindiH5pyGroup.py b/lindi/LindiH5pyFile/LindiH5pyGroup.py index 48b3d90..1fbe8be 100644 --- a/lindi/LindiH5pyFile/LindiH5pyGroup.py +++ b/lindi/LindiH5pyFile/LindiH5pyGroup.py @@ -20,6 +20,14 @@ class LindiH5pyGroup(h5py.Group): def __init__(self, _group_object: Union[h5py.Group, zarr.Group], _file: "LindiH5pyFile"): self._group_object = _group_object self._file = _file + self._readonly = _file.mode not in ['r+'] + + # The self._write object handles all the writing operations + from .writers.LindiH5pyGroupWriter import LindiH5pyGroupWriter # avoid circular import + if self._readonly: + self._writer = None + else: + self._writer = LindiH5pyGroupWriter(self) def __getitem__(self, name): if isinstance(self._group_object, h5py.Group): @@ -49,12 +57,12 @@ def __getitem__(self, name): soft_link = x.attrs.get('_SOFT_LINK', None) if soft_link is not None: link_path = soft_link['path'] - target_grp = self._file.get(link_path) - if not isinstance(target_grp, LindiH5pyGroup): + target_item = self._file.get(link_path) + if not isinstance(target_item, (LindiH5pyGroup, LindiH5pyDataset)): raise Exception( - f"Expected a group at {link_path} but got {type(x)}" + f"Expected a group or dataset at {link_path} but got {type(target_item)}" ) - return target_grp + return target_item return LindiH5pyGroup(x, self._file) elif isinstance(x, zarr.Array): return LindiH5pyDataset(x, self._file) @@ -115,12 +123,22 @@ def __reversed__(self): def __contains__(self, name): return self._group_object.__contains__(name) + def __str__(self): + return f'<{self.__class__.__name__}: {self.name}>' + + def __repr__(self): + return f'<{self.__class__.__name__}: {self.name}>' + @property def id(self): if isinstance(self._group_object, h5py.Group): return LindiH5pyGroupId(self._group_object.id) elif isinstance(self._group_object, zarr.Group): - return LindiH5pyGroupId(None) + # This is commented out for now because pynwb gets the id of a group + # in at least one place. But that could be avoided in the future, at + # which time, we could uncomment this. + # print('WARNING: Accessing low-level id of LindiH5pyGroup. 
This should be avoided.') + return LindiH5pyGroupId('') else: raise Exception(f'Unexpected group object type: {type(self._group_object)}') @@ -136,4 +154,49 @@ def attrs(self): # type: ignore attrs_type = 'zarr' else: raise Exception(f'Unexpected group object type: {type(self._group_object)}') - return LindiH5pyAttributes(self._group_object.attrs, attrs_type=attrs_type) + return LindiH5pyAttributes(self._group_object.attrs, attrs_type=attrs_type, readonly=self._file.mode == 'r') + + @property + def ref(self): + if self._readonly: + raise ValueError("Cannot get ref on read-only object") + assert self._writer is not None + return self._writer.ref + + ############################## + # write + def create_group(self, name, track_order=None): + if self._readonly: + raise Exception('Cannot create group in read-only mode') + assert self._writer is not None + return self._writer.create_group(name, track_order=track_order) + + def require_group(self, name): + if self._readonly: + raise Exception('Cannot require group in read-only mode') + assert self._writer is not None + return self._writer.require_group(name) + + def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds): + if self._readonly: + raise Exception('Cannot create dataset in read-only mode') + assert self._writer is not None + return self._writer.create_dataset(name, shape=shape, dtype=dtype, data=data, **kwds) + + def require_dataset(self, name, shape, dtype, exact=False, **kwds): + if self._readonly: + raise Exception('Cannot require dataset in read-only mode') + assert self._writer is not None + return self._writer.require_dataset(name, shape, dtype, exact=exact, **kwds) + + def __setitem__(self, name, obj): + if self._readonly: + raise Exception('Cannot set item in read-only mode') + assert self._writer is not None + return self._writer.__setitem__(name, obj) + + def __delitem__(self, name): + if self._readonly: + raise Exception('Cannot delete item in read-only mode') + assert self._writer is not None + return self._writer.__delitem__(name) diff --git a/lindi/LindiH5pyFile/LindiH5pyReference.py b/lindi/LindiH5pyFile/LindiH5pyReference.py index e1bda22..d12d7cd 100644 --- a/lindi/LindiH5pyFile/LindiH5pyReference.py +++ b/lindi/LindiH5pyFile/LindiH5pyReference.py @@ -3,6 +3,7 @@ class LindiH5pyReference(h5py.h5r.Reference): def __init__(self, obj: dict): + self._obj = obj self._object_id = obj["object_id"] self._path = obj["path"] self._source = obj["source"] diff --git a/lindi/LindiH5pyFile/LindiReferenceFileSystemStore.py b/lindi/LindiH5pyFile/LindiReferenceFileSystemStore.py index ccd351d..e98ce22 100644 --- a/lindi/LindiH5pyFile/LindiReferenceFileSystemStore.py +++ b/lindi/LindiH5pyFile/LindiReferenceFileSystemStore.py @@ -39,8 +39,12 @@ class LindiReferenceFileSystemStore(ZarrStore): being returned. Otherwise, the string is utf-8 encoded and returned as is. Note that a file that actually begins with "base64:" should be represented by a base64 encoded string, to avoid ambiguity. + + It is okay for rfs to be modified outside of this class, and the changes + will be reflected immediately in the store. This can be used by experimental + tools such as lindi-cloud. """ - def __init__(self, rfs: dict, mode: Literal["r"] = "r"): + def __init__(self, rfs: dict, mode: Literal["r", "r+"] = "r+"): """ Create a LindiReferenceFileSystemStore. 
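# A minimal usage sketch of the write support introduced above, assuming the entry
# points exercised by the tests later in this patch (from_reference_file_system with
# mode="r+", create_group, create_dataset). The names "group1" and "dataset1" are
# illustrative only.
import lindi

rfs = {"refs": {".zgroup": '{"zarr_format": 2}'}}
f = lindi.LindiH5pyFile.from_reference_file_system(rfs, mode="r+")
grp = f.create_group("group1")
grp.attrs["description"] = "example group"
f.create_dataset("dataset1", data=[1, 2, 3])
# Because the store now writes directly into the rfs dict, keys such as
# "group1/.zattrs" and "dataset1/.zarray" should appear in rfs["refs"].
print(sorted(rfs["refs"].keys()))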
@@ -94,13 +98,19 @@ def __getitem__(self, key: str): else: # should not happen given checks in __init__, but self.rfs is mutable # and contains mutable lists - raise Exception(f"Problem with {key}: value must be a string or a list") + raise Exception(f"Problem with {key}: value {x} must be a string or a list") def __setitem__(self, key: str, value): - raise Exception("Setting items is not allowed") + try: + # try to ascii encode the value + value = value.decode("ascii") + except UnicodeDecodeError: + # if that fails, base64 encode it + value = "base64:" + base64.b64encode(value).decode("ascii") + self.rfs["refs"][key] = value def __delitem__(self, key: str): - raise Exception("Deleting items is not allowed") + del self.rfs["refs"][key] def __iter__(self): return iter(self.rfs["refs"]) @@ -110,10 +120,10 @@ def __len__(self): # These methods are overridden from BaseStore def is_readable(self): - return self.mode in ["r"] + return self.mode in ["r", "r+"] def is_writeable(self): - return False + return self.mode in ["r+"] def is_listable(self): return True diff --git a/lindi/LindiH5pyFile/writers/LindiH5pyAttributesWriter.py b/lindi/LindiH5pyFile/writers/LindiH5pyAttributesWriter.py new file mode 100644 index 0000000..d2908ab --- /dev/null +++ b/lindi/LindiH5pyFile/writers/LindiH5pyAttributesWriter.py @@ -0,0 +1,20 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..LindiH5pyAttributes import LindiH5pyAttributes # pragma: no cover + + +class LindiH5pyAttributesWriter: + def __init__(self, p: 'LindiH5pyAttributes'): + self.p = p + + def __setitem__(self, key, value): + if self.p._readonly: + raise KeyError("Cannot set attributes on read-only object") + if self.p._attrs_type == "h5py": + self.p._attrs[key] = value + elif self.p._attrs_type == "zarr": + from ...conversion.attr_conversion import h5_to_zarr_attr # avoid circular import + self.p._attrs[key] = h5_to_zarr_attr(value, h5f=None) + else: + raise ValueError(f"Unknown attrs_type: {self.p._attrs_type}") diff --git a/lindi/LindiH5pyFile/writers/LindiH5pyDatasetWriter.py b/lindi/LindiH5pyFile/writers/LindiH5pyDatasetWriter.py new file mode 100644 index 0000000..c455e79 --- /dev/null +++ b/lindi/LindiH5pyFile/writers/LindiH5pyDatasetWriter.py @@ -0,0 +1,50 @@ +from typing import Any, TYPE_CHECKING +import h5py +import zarr +import numpy as np + +from ..LindiH5pyReference import LindiH5pyReference +from ...conversion._util import _is_numeric_dtype +from ...conversion.create_zarr_dataset_from_h5_data import h5_object_data_to_zarr_data + +if TYPE_CHECKING: + from ..LindiH5pyDataset import LindiH5pyDataset # pragma: no cover + + +class LindiH5pyDatasetWriter: + def __init__(self, p: 'LindiH5pyDataset'): + self.p = p + + def __setitem__(self, args, val): + if isinstance(self.p._dataset_object, h5py.Dataset): + self.p._dataset_object.__setitem__(args, val) + elif isinstance(self.p._dataset_object, zarr.Array): + self._set_item_for_zarr(self.p._dataset_object, args, val) + else: + raise Exception(f"Unexpected type: {type(self.p._dataset_object)}") + + def _set_item_for_zarr(self, zarr_array: zarr.Array, selection: Any, val: Any): + if self.p._compound_dtype is not None: + raise Exception("Setting compound dataset is not implemented") + if self.p.ndim == 0: + if selection != (): + raise TypeError(f'Cannot slice a scalar dataset with {selection}') + zarr_array[0] = val + else: + dtype = zarr_array.dtype + if _is_numeric_dtype(dtype) or dtype in [bool, np.bool_]: + # this is the usual numeric case + zarr_array[selection] = val + elif 
dtype.kind == 'O': + zarr_array[selection] = h5_object_data_to_zarr_data(val, h5f=None, label='') + else: + raise Exception(f'Unsupported dtype for slice setting {dtype} in {self.p.name}') + + @property + def ref(self): + return LindiH5pyReference({ + 'object_id': self.p.attrs.get('object_id', None), + 'path': self.p.name, + 'source': '.', + 'source_object_id': self.p.file.attrs.get('object_id', None) + }) diff --git a/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py b/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py new file mode 100644 index 0000000..a6ca6bf --- /dev/null +++ b/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py @@ -0,0 +1,126 @@ +from typing import TYPE_CHECKING +import h5py +import numpy as np +import zarr + +from ..LindiH5pyDataset import LindiH5pyDataset +from ..LindiH5pyReference import LindiH5pyReference + +if TYPE_CHECKING: + from ..LindiH5pyGroup import LindiH5pyGroup # pragma: no cover + +from ...conversion.create_zarr_dataset_from_h5_data import create_zarr_dataset_from_h5_data + + +class LindiH5pyGroupWriter: + def __init__(self, p: 'LindiH5pyGroup'): + self.p = p + + def create_group(self, name, track_order=None): + from ..LindiH5pyGroup import LindiH5pyGroup # avoid circular import + if track_order is not None: + raise Exception("track_order is not supported (I don't know what it is)") + if isinstance(self.p._group_object, h5py.Group): + return LindiH5pyGroup( + self.p._group_object.create_group(name), self.p._file + ) + elif isinstance(self.p._group_object, zarr.Group): + return LindiH5pyGroup( + self.p._group_object.create_group(name), self.p._file + ) + else: + raise Exception(f'Unexpected group object type: {type(self.p._group_object)}') + + def require_group(self, name): + if name in self.p: + ret = self.p[name] + if not isinstance(ret, LindiH5pyGroup): + raise Exception(f'Expected a group at {name} but got {type(ret)}') + return ret + return self.create_group(name) + + def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds): + chunks = None + for k, v in kwds.items(): + if k == 'chunks': + chunks = v + else: + raise Exception(f'Unsupported kwds in create_dataset: {k}') + + if isinstance(self.p._group_object, h5py.Group): + return LindiH5pyDataset( + self._group_object.create_dataset(name, shape=shape, dtype=dtype, data=data, chunks=chunks), # type: ignore + self.p._file + ) + elif isinstance(self.p._group_object, zarr.Group): + if isinstance(data, list): + data = np.array(data) + if shape is None: + if data is None: + raise Exception('shape or data must be provided') + if isinstance(data, np.ndarray): + shape = data.shape + else: + shape = () + if dtype is None: + if data is None: + raise Exception('dtype or data must be provided') + if isinstance(data, np.ndarray): + dtype = data.dtype + else: + dtype = np.dtype(type(data)) + ds = create_zarr_dataset_from_h5_data( + zarr_parent_group=self.p._group_object, + name=name, + label=(self.p.name or '') + '/' + name, + h5_chunks=chunks, + h5_shape=shape, + h5_dtype=dtype, + h5_data=data, + h5f=None + ) + return LindiH5pyDataset(ds, self.p._file) + else: + raise Exception(f'Unexpected group object type: {type(self.p._group_object)}') + + def require_dataset(self, name, shape, dtype, exact=False, **kwds): + if name in self.p: + ret = self.p[name] + if not isinstance(ret, LindiH5pyDataset): + raise Exception(f'Expected a dataset at {name} but got {type(ret)}') + if ret.shape != shape: + raise Exception(f'Expected shape {shape} but got {ret.shape}') + if exact: + if ret.dtype != dtype: + 
raise Exception(f'Expected dtype {dtype} but got {ret.dtype}') + else: + if not np.can_cast(ret.dtype, dtype): + raise Exception(f'Cannot cast dtype {ret.dtype} to {dtype}') + return ret + return self.create_dataset(name, *(shape, dtype), **kwds) + + def __setitem__(self, name, obj): + if isinstance(obj, h5py.SoftLink): + if isinstance(self.p._group_object, h5py.Group): + self.p._group_object[name] = obj + elif isinstance(self.p._group_object, zarr.Group): + grp = self.p.create_group(name) + grp._group_object.attrs['_SOFT_LINK'] = { + 'path': obj.path + } + else: + raise Exception(f'Unexpected group object type: {type(self.p._group_object)}') + else: + raise Exception(f'Unexpected type for obj in __setitem__: {type(obj)}') + + def __delitem__(self, name): + del self.p._group_object[name] + + @property + def ref(self): + return LindiH5pyReference({ + 'object_id': self.p.attrs.get('object_id', None), + 'path': self.p.name, + 'source': '.', + 'source_object_id': self.p.file.attrs.get('object_id', None) + }) diff --git a/lindi/LindiH5pyFile/writers/__init__.py b/lindi/LindiH5pyFile/writers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lindi/conversion/__init__.py b/lindi/conversion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lindi/conversion/_util.py b/lindi/conversion/_util.py new file mode 100644 index 0000000..721b01d --- /dev/null +++ b/lindi/conversion/_util.py @@ -0,0 +1,6 @@ +import numpy as np + + +def _is_numeric_dtype(dtype: np.dtype) -> bool: + """Return True if the dtype is a numeric dtype.""" + return np.issubdtype(dtype, np.number) diff --git a/lindi/conversion/attr_conversion.py b/lindi/conversion/attr_conversion.py new file mode 100644 index 0000000..a3f0a19 --- /dev/null +++ b/lindi/conversion/attr_conversion.py @@ -0,0 +1,162 @@ +from typing import Any, Union +import numpy as np +import h5py +from .nan_inf_ninf import encode_nan_inf_ninf +from .h5_ref_to_zarr_attr import h5_ref_to_zarr_attr + + +def h5_to_zarr_attr(attr: Any, *, label: str = '', h5f: Union[h5py.File, None]): + """Convert an attribute from h5py to a format that zarr can accept.""" + + from ..LindiH5pyFile.LindiH5pyReference import LindiH5pyReference # Avoid circular import + + # Do not allow these special strings in attributes + special_strings = ['NaN', 'Infinity', '-Infinity'] + + if isinstance(attr, list): + list_dtype = _determine_list_dtype(attr) + attr = np.array(attr, dtype=list_dtype) + + if isinstance(attr, str) and attr in special_strings: + raise ValueError(f"Special string {attr} not allowed in attribute value at {label}") + if attr is None: + raise Exception(f"Unexpected h5 attribute: None at {label}") + elif type(attr) in [int, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64]: + return int(attr) + elif isinstance(attr, (float, np.floating)): + return encode_nan_inf_ninf(float(attr)) + elif isinstance(attr, (complex, np.complexfloating)): + raise Exception(f"Complex number is not supported at {label}") + elif type(attr) in [bool, np.bool_]: + return bool(attr) + elif isinstance(attr, (bool, list, tuple, dict, set)): + raise Exception(f"Unexpected type for h5 attribute: {type(attr)} at {label}") + elif isinstance(attr, str): + return attr + elif isinstance(attr, bytes): + return attr.decode('utf-8') + elif isinstance(attr, np.ndarray): + if attr.dtype.kind in ['i', 'u']: + return attr.tolist() # this will be a nested list of type int + elif attr.dtype.kind in ['f']: + return encode_nan_inf_ninf(attr.tolist()) # this 
will be a nested list of type float + elif attr.dtype.kind in ['c']: + raise Exception(f"Arrays of complex numbers are not supported at {label}") + elif attr.dtype.kind == 'b': + return attr.tolist() # this will be a nested list of type bool + elif attr.dtype.kind == 'O': + x = attr.tolist() + if not _nested_list_has_all_strings(x): + raise Exception(f"Not allowed for attribute: numpy array with dtype=object that contains non-string elements at {label}") + return x + elif attr.dtype.kind == 'U': + return _decode_bytes_to_str_in_nested_list(attr.tolist()) + elif attr.dtype.kind == 'S': + return _decode_bytes_to_str_in_nested_list(attr.tolist()) + else: + raise Exception(f"Unexpected dtype for attribute numpy array: {attr.dtype} at {label}") + elif isinstance(attr, LindiH5pyReference): + return { + '_REFERENCE': attr._obj + } + elif isinstance(attr, h5py.Reference): + if h5f is None: + raise Exception(f"h5f cannot be None when converting h5py.Reference to zarr attribute at {label}") + return h5_ref_to_zarr_attr(attr, h5f=h5f) + else: + raise Exception(f"Unexpected type for h5 attribute: {type(attr)} at {label}") + + +def _decode_bytes_to_str_in_nested_list(x): + if isinstance(x, bytes): + return x.decode('utf-8') + elif isinstance(x, str): + return x + elif isinstance(x, list): + return [_decode_bytes_to_str_in_nested_list(y) for y in x] + else: + raise Exception("Unexpected type in _decode_bytes_to_str") + + +def zarr_to_h5_attr(attr: Any): + """Convert an attribute from zarr to a format that h5py expects.""" + if isinstance(attr, str): + return attr + elif isinstance(attr, int): + return attr + elif isinstance(attr, float): + return attr + elif isinstance(attr, bool): + return np.bool_(attr) + elif isinstance(attr, list): + if _nested_list_has_all_strings(attr): + return np.array(attr, dtype='O') + elif _nested_list_has_all_ints(attr): + return np.array(attr, dtype='int64') + elif _nested_list_has_all_floats_or_ints(attr): + return np.array(attr, dtype='float64') + elif _nested_list_has_all_bools(attr): + return np.array(attr, dtype='bool') + else: + raise Exception("Nested list contains mixed types") + else: + raise Exception(f"Unexpected type in zarr attribute: {type(attr)}") + + +def _nested_list_has_all_strings(x): + if isinstance(x, str): + return True + elif isinstance(x, list): + return all(_nested_list_has_all_strings(y) for y in x) + else: + return False + + +def _nested_list_has_all_ints(x): + if isinstance(x, int): + return True + elif isinstance(x, list): + return all(_nested_list_has_all_ints(y) for y in x) + else: + return False + + +def _nested_list_has_all_floats_or_ints(x): + if isinstance(x, (int, float)): + return True + elif isinstance(x, list): + return all(_nested_list_has_all_floats_or_ints(y) for y in x) + else: + return False + + +def _nested_list_has_all_bools(x): + if isinstance(x, bool): + return True + elif isinstance(x, list): + return all(_nested_list_has_all_bools(y) for y in x) + else: + return False + + +def _determine_list_dtype(x): + x_flattened = _flatten_list(x) + if len(x_flattened) == 0: + return np.dtype(np.int64) + if all(isinstance(i, int) for i in x_flattened): + return np.int64 + elif all(isinstance(i, float) for i in x_flattened): + return np.float64 + elif all(isinstance(i, bool) for i in x_flattened): + return np.bool_ + elif all(isinstance(i, str) for i in x_flattened): + return np.dtype('O') + else: + raise Exception("Mixed types in list") + + +def _flatten_list(x): + if isinstance(x, list): + return [a for i in x for a in _flatten_list(i)] 
+ else: + return [x] diff --git a/lindi/conversion/create_zarr_dataset_from_h5_data.py b/lindi/conversion/create_zarr_dataset_from_h5_data.py new file mode 100644 index 0000000..0e36127 --- /dev/null +++ b/lindi/conversion/create_zarr_dataset_from_h5_data.py @@ -0,0 +1,237 @@ +from typing import Union, List, Any, Tuple +from dataclasses import dataclass +import numpy as np +import numcodecs +import h5py +import zarr +from .h5_ref_to_zarr_attr import h5_ref_to_zarr_attr +from .attr_conversion import h5_to_zarr_attr +from ._util import _is_numeric_dtype + + +def create_zarr_dataset_from_h5_data( + zarr_parent_group: zarr.Group, + h5_shape: Tuple, + h5_dtype: Any, + h5_data: Union[Any, None], + h5f: Union[h5py.File, None], + name: str, + label: str, + h5_chunks: Union[Tuple, None] +): + """Create a zarr dataset from an h5py dataset. + + Parameters + ---------- + zarr_parent_group : zarr.Group + The parent group in the zarr hierarchy. The new dataset will be created + in this group. + h5_shape : tuple + The shape of the h5py dataset. + h5_dtype : numpy.dtype + The dtype of the h5py dataset. + h5_data : any + The data of the h5py dataset. If None, the dataset will be created + without data. + h5f : h5py.File + The file that the h5py dataset is in. + name : str + The name of the new dataset in the zarr hierarchy. + label : str + The name of the h5py dataset for error messages. + h5_chunks : tuple + The chunk shape of the h5py dataset. + """ + if h5_dtype is None: + raise Exception(f'No dtype in h5_to_zarr_dataset_prep for dataset {label}') + if len(h5_shape) == 0: + # scalar dataset + # zarr doesn't support scalar datasets, so we make an array of shape (1,) + # and the _SCALAR attribute will be set to True elsewhere to indicate that + # it is a scalar dataset + + if h5_data is None: + raise Exception(f'Data must be provided for scalar dataset {label}') + + if _is_numeric_dtype(h5_dtype) or h5_dtype in [bool, np.bool_]: + # Handle the simple numeric types + ds = zarr_parent_group.create_dataset( + name, + shape=(1,), + chunks=(1,), + data=[h5_data[()]] if isinstance(h5_data, h5py.Dataset) or isinstance(h5_data, np.ndarray) else [h5_data], + ) + ds.attrs['_SCALAR'] = True + return ds + elif h5_dtype.kind == 'O': + # For type object, we are going to use the JSON codec + # for encoding [scalar_value] + scalar_value = h5_data[()] if isinstance(h5_data, h5py.Dataset) or isinstance(h5_data, np.ndarray) else h5_data + if isinstance(scalar_value, (bytes, str)): + if isinstance(scalar_value, bytes): + scalar_value = scalar_value.decode() + ds = zarr_parent_group.create_dataset( + name, + shape=(1,), + chunks=(1,), + data=[scalar_value] + ) + ds.attrs['_SCALAR'] = True + return ds + else: + raise Exception(f'Unsupported scalar value type: {type(scalar_value)}') + elif h5_dtype.kind == 'S': + # byte string + if h5_data is None: + raise Exception(f'Data must be provided for scalar dataset {label}') + scalar_value = h5_data[()] if isinstance(h5_data, h5py.Dataset) or isinstance(h5_data, np.ndarray) else h5_data + ds = zarr_parent_group.create_dataset( + name, + shape=(1,), + chunks=(1,), + data=[scalar_value] + ) + ds.attrs['_SCALAR'] = True + return ds + else: + raise Exception(f'Cannot handle scalar dataset {label} with dtype {h5_dtype}') + else: + # not a scalar dataset + + if isinstance(h5_data, list): + # If we have a list, then we need to convert it to an array + h5_data = np.array(h5_data) + + if _is_numeric_dtype(h5_dtype) or h5_dtype in [bool, np.bool_]: # integer, unsigned integer, float, bool + # 
This is the normal case of a chunked dataset with a numeric (or boolean) dtype + if h5_chunks is None: + # We require that chunks be specified when writing a dataset with more + # than 1 million elements. This is because zarr may default to + # suboptimal chunking. Note that the default for h5py is to use the + # entire dataset as a single chunk. + total_size = np.prod(h5_shape) if len(h5_shape) > 0 else 1 + if total_size > 1000 * 1000: + raise Exception(f'Chunks must be specified explicitly when writing dataset of shape {h5_shape}') + # Note that we are not using the same filters as in the h5py dataset + return zarr_parent_group.create_dataset( + name, + shape=h5_shape, + chunks=h5_chunks, + dtype=h5_dtype, + data=h5_data + ) + elif h5_dtype.kind == 'O': + # For type object, we are going to use the JSON codec + if h5_data is not None: + if isinstance(h5_data, h5py.Dataset): + h5_data = h5_data[:] + zarr_data = h5_object_data_to_zarr_data(h5_data, h5f=h5f, label=label) + else: + zarr_data = None + object_codec = numcodecs.JSON() + return zarr_parent_group.create_dataset( + name, + shape=h5_shape, + chunks=h5_chunks, + dtype=h5_dtype, + data=zarr_data, + object_codec=object_codec + ) + elif h5_dtype.kind == 'S': # byte string + if h5_data is None: + raise Exception(f'Data must be provided when converting dataset {label} with dtype {h5_dtype}') + return zarr_parent_group.create_dataset( + name, + shape=h5_shape, + chunks=h5_chunks, + dtype=h5_dtype, + data=h5_data + ) + elif h5_dtype.kind == 'U': # unicode string + raise Exception(f'Array of unicode strings not supported: dataset {label} with dtype {h5_dtype} and shape {h5_shape}') + elif h5_dtype.kind == 'V' and h5_dtype.fields is not None: # compound dtype + if h5_data is None: + raise Exception(f'Data must be provided when converting compound dataset {label}') + h5_data_1d_view = h5_data.ravel() + zarr_data = np.empty(h5_shape, dtype='object') + zarr_data_1d_view = zarr_data.ravel() + for i in range(len(h5_data_1d_view)): + elmt = tuple([ + _make_json_serializable( + h5_data_1d_view[i][field_name], + h5_dtype[field_name], + label=f'{label}[{i}].{field_name}', + h5f=h5f + ) + for field_name in h5_dtype.names + ]) + zarr_data_1d_view[i] = elmt + ds = zarr_parent_group.create_dataset( + name, + shape=h5_shape, + chunks=h5_chunks, + dtype='object', + data=zarr_data, + object_codec=numcodecs.JSON() + ) + compound_dtype = [] + for name in h5_dtype.names: + tt = h5_dtype[name] + if tt == h5py.special_dtype(ref=h5py.Reference): + tt = "" + compound_dtype.append((name, str(tt))) + ds.attrs['_COMPOUND_DTYPE'] = compound_dtype + return ds + else: + raise Exception(f'Not yet implemented (3): dataset {label} with dtype {h5_dtype} and shape {h5_shape}') + + +@dataclass +class CreateZarrDatasetInfo: + shape: Tuple + dtype: Any + fill_value: Any + scalar: bool + compound_dtype: Union[List[Tuple[str, str]], None] + + +def _make_json_serializable(val: Any, dtype: np.dtype, label: str, h5f: Union[h5py.File, None]) -> Any: + if dtype.kind in ['i', 'u']: # integer, unsigned integer + return int(val) + elif dtype.kind == 'f': # float + return float(val) + elif dtype.kind == 'b': # boolean + return bool(val) + elif dtype.kind == 'S': # byte string + return val.decode() + elif dtype.kind == 'U': # unicode string + return val + elif dtype == h5py.Reference: + return h5_to_zarr_attr(val, label=label, h5f=h5f) + else: + raise Exception(f'Cannot serialize item {val} with dtype {dtype} when serializing dataset {label} with compound dtype.') + + +def 
h5_object_data_to_zarr_data(h5_data: Union[np.ndarray, list], *, h5f: Union[h5py.File, None], label: str) -> np.ndarray: + from ..LindiH5pyFile.LindiH5pyReference import LindiH5pyReference # Avoid circular import + if isinstance(h5_data, list): + h5_data = np.array(h5_data) + zarr_data = np.empty(h5_data.shape, dtype='object') + h5_data_1d_view = h5_data.ravel() + zarr_data_1d_view = zarr_data.ravel() + for i, val in enumerate(h5_data_1d_view): + if isinstance(val, bytes): + zarr_data_1d_view[i] = val.decode() + elif isinstance(val, str): + zarr_data_1d_view[i] = val + elif isinstance(val, LindiH5pyReference): + zarr_data_1d_view[i] = { + '_REFERENCE': val._obj + } + elif isinstance(val, h5py.Reference): + if h5f is None: + raise Exception(f'h5f cannot be None when converting h5py.Reference to zarr attribute at {label}') + zarr_data_1d_view[i] = h5_ref_to_zarr_attr(val, h5f=h5f) + else: + raise Exception(f'Cannot handle value of type {type(val)} in dataset {label} with dtype {h5_data.dtype} and shape {h5_data.shape}') + return zarr_data diff --git a/lindi/conversion/decode_references.py b/lindi/conversion/decode_references.py new file mode 100644 index 0000000..50e5789 --- /dev/null +++ b/lindi/conversion/decode_references.py @@ -0,0 +1,27 @@ +from typing import Any +import numpy as np + + +def decode_references(x: Any): + """Decode references in a nested structure. + + See h5_ref_to_zarr_attr() for the encoding of references. + """ + from ..LindiH5pyFile.LindiH5pyReference import LindiH5pyReference # Avoid circular import + if isinstance(x, dict): + # x should only be a dict when x represents a converted reference + if '_REFERENCE' in x: + return LindiH5pyReference(x['_REFERENCE']) + else: # pragma: no cover + raise Exception(f"Unexpected dict in selection: {x}") + elif isinstance(x, list): + # Replace any references in the list with the resolved ref in-place + for i, v in enumerate(x): + x[i] = decode_references(v) + elif isinstance(x, np.ndarray): + if x.dtype == object or x.dtype is None: + # Replace any references in the object array with the resolved ref in-place + view_1d = x.reshape(-1) + for i in range(len(view_1d)): + view_1d[i] = decode_references(view_1d[i]) + return x diff --git a/lindi/LindiH5ZarrStore/_h5_filters_to_codecs.py b/lindi/conversion/h5_filters_to_codecs.py similarity index 81% rename from lindi/LindiH5ZarrStore/_h5_filters_to_codecs.py rename to lindi/conversion/h5_filters_to_codecs.py index 8052b89..7fa9cdb 100644 --- a/lindi/LindiH5ZarrStore/_h5_filters_to_codecs.py +++ b/lindi/conversion/h5_filters_to_codecs.py @@ -3,12 +3,17 @@ import numcodecs from numcodecs.abc import Codec +# The purpose of _h5_filters_to_codecs it to translate the filters that are +# defined on an HDF5 dataset into numcodecs filters for use with Zarr so that +# the raw data chunks can stay within the the HDF5 file and be read by Zarr +# without having to copy/convert the data. 
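# A short sketch of how the translated codecs can be inspected, assuming a local
# HDF5 file (the filename "example.h5" is illustrative); the exact codec list
# depends on the dataset's filters (gzip typically maps to numcodecs.Zlib).
import h5py
import numpy as np

from lindi.conversion.h5_filters_to_codecs import h5_filters_to_codecs

with h5py.File("example.h5", "w") as f:
    f.create_dataset("X", data=np.arange(1000), chunks=(100,), compression="gzip", shuffle=True)

with h5py.File("example.h5", "r") as f:
    dset = f["X"]
    assert isinstance(dset, h5py.Dataset)
    codecs = h5_filters_to_codecs(dset)
    print(codecs)  # e.g. something like [Shuffle(elementsize=8), Zlib(level=4)]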
+ # This is adapted from _decode_filters from kerchunk source # https://github.com/fsspec/kerchunk # Copyright (c) 2020 Intake # MIT License -def _h5_filters_to_codecs(h5obj: h5py.Dataset) -> Union[List[Codec], None]: +def h5_filters_to_codecs(h5obj: h5py.Dataset) -> Union[List[Codec], None]: """Decode HDF5 filters to numcodecs filters.""" if h5obj.scaleoffset: raise RuntimeError( @@ -63,6 +68,9 @@ def _h5_filters_to_codecs(h5obj: h5py.Dataset) -> Union[List[Codec], None]: elif str(filter_id) == "shuffle": # already handled before this loop pass + elif str(filter_id) == "fletcher32": + # added by lindi (not in kerchunk) -- required by dandiset 000117 + filters.append(numcodecs.Fletcher32()) else: raise RuntimeError( f"{h5obj.name} uses filter id {filter_id} with properties {properties}," diff --git a/lindi/conversion/h5_ref_to_zarr_attr.py b/lindi/conversion/h5_ref_to_zarr_attr.py new file mode 100644 index 0000000..f6cf2ea --- /dev/null +++ b/lindi/conversion/h5_ref_to_zarr_attr.py @@ -0,0 +1,61 @@ +import h5py + + +def h5_ref_to_zarr_attr(ref: h5py.Reference, *, h5f: h5py.File): + """Convert/encode an h5py reference to a format that zarr can accept. + + Parameters + ---------- + ref : h5py.Reference + The reference to convert. + h5f : h5py.File + The file that the reference is in. + + Returns + ------- + dict + The reference in a format that zarr can accept. + + The format is a dictionary with a single key, '_REFERENCE', whose value is + another dictionary with the following keys: + + 'object_id', 'path', 'source', 'source_object_id' + + * object_id is the object ID of the target object. + * path is the path of the target object. + * source is always '.', meaning that path is relative to the root of the + file (I think) + * source_object_id is the object ID of the source object. + + See + https://hdmf-zarr.readthedocs.io/en/latest/storage.html#storing-object-references-in-attributes + + Note that we will also need to handle "region" references. I would propose + another field in the value containing the region info. See + https://hdmf-zarr.readthedocs.io/en/latest/storage.html#sec-zarr-storage-references-region + """ + dref_obj = h5f[ref] + deref_objname = dref_obj.name + + object_id = dref_obj.attrs.get("object_id", None) + + # Here we assume that the file has a top-level attribute called "object_id". + # This will be the case for files created by the LindiH5ZarrStore class. + file_object_id = h5f.attrs.get("object_id", None) + + # See https://hdmf-zarr.readthedocs.io/en/latest/storage.html#storing-object-references-in-attributes + value = { + "object_id": object_id, + "path": deref_objname, + "source": ".", # Are we always going to use the top-level object as the source? 
+ "source_object_id": file_object_id, + } + + # We need this to be json serializable + for k, v in value.items(): + if isinstance(v, bytes): + value[k] = v.decode('utf-8') + + return { + "_REFERENCE": value + } diff --git a/lindi/conversion/nan_inf_ninf.py b/lindi/conversion/nan_inf_ninf.py new file mode 100644 index 0000000..cc7595d --- /dev/null +++ b/lindi/conversion/nan_inf_ninf.py @@ -0,0 +1,34 @@ +import numpy as np + + +def decode_nan_inf_ninf(val): + if isinstance(val, list): + return [decode_nan_inf_ninf(v) for v in val] + elif isinstance(val, dict): + return {k: decode_nan_inf_ninf(v) for k, v in val.items()} + elif val == 'NaN': + return float('nan') + elif val == 'Infinity': + return float('inf') + elif val == '-Infinity': + return float('-inf') + else: + return val + + +def encode_nan_inf_ninf(val): + if isinstance(val, list): + return [encode_nan_inf_ninf(v) for v in val] + elif isinstance(val, dict): + return {k: encode_nan_inf_ninf(v) for k, v in val.items()} + elif isinstance(val, (float, np.floating)): + if np.isnan(val): + return 'NaN' + elif val == float('inf'): + return 'Infinity' + elif val == float('-inf'): + return '-Infinity' + else: + return val + else: + return val diff --git a/lindi/conversion/reformat_json.py b/lindi/conversion/reformat_json.py new file mode 100644 index 0000000..b10b1f0 --- /dev/null +++ b/lindi/conversion/reformat_json.py @@ -0,0 +1,14 @@ +from typing import Union +import json + + +def reformat_json(x: Union[bytes, None]) -> Union[bytes, None]: + """Reformat to not include whitespace and to not allow nan, inf, and ninf. + + It is assumed that float attributes nan, inf, and ninf float values have + been encoded as strings. See encode_nan_inf_ninf() and h5_to_zarr_attr(). + """ + if x is None: + return None + a = json.loads(x.decode("utf-8")) + return json.dumps(a, separators=(",", ":"), allow_nan=False).encode("utf-8") diff --git a/tests/test_copy.py b/tests/test_copy.py new file mode 100644 index 0000000..95af5a9 --- /dev/null +++ b/tests/test_copy.py @@ -0,0 +1,72 @@ +import h5py +import tempfile +import pytest +import lindi +from lindi import LindiH5ZarrStore +from utils import arrays_are_equal, assert_groups_equal + + +def test_copy_dataset(): + with tempfile.TemporaryDirectory() as tmpdir: + filename = f"{tmpdir}/test.h5" + with h5py.File(filename, "w") as f: + f.create_dataset("X", data=[1, 2, 3]) + f.create_dataset("Y", data=[4, 5, 6]) + f['X'].attrs['attr1'] = 'value1' + h5f = h5py.File(filename, "r") + with LindiH5ZarrStore.from_file(filename, url=filename) as store: + rfs = store.to_reference_file_system() + h5f_2 = lindi.LindiH5pyFile.from_reference_file_system(rfs, mode="r+") + assert "X" in h5f_2 + assert "Y" in h5f_2 + with pytest.raises(Exception): + # This one is not expected to work. Would be difficult to + # implement since this involves low-level operations on + # LindiH5pyFile. 
+ h5f.copy("X", h5f_2, "Z") + h5f_2.copy("X", h5f_2, "Z") + assert "Z" in h5f_2 + assert h5f_2["Z"].attrs['attr1'] == 'value1' # type: ignore + assert arrays_are_equal(h5f["X"][()], h5f_2["Z"][()]) # type: ignore + rfs_copy = store.to_reference_file_system() + h5f_3 = lindi.LindiH5pyFile.from_reference_file_system(rfs_copy, mode="r+") + assert "Z" not in h5f_3 + h5f_2.copy("X", h5f_3, "Z") + assert "Z" in h5f_3 + assert h5f_3["Z"].attrs['attr1'] == 'value1' # type: ignore + assert arrays_are_equal(h5f["X"][()], h5f_3["Z"][()]) # type: ignore + + +def test_copy_group(): + with tempfile.TemporaryDirectory() as tmpdir: + filename = f"{tmpdir}/test.h5" + with h5py.File(filename, "w") as f: + f.create_group("X") + f.create_group("Y") + f.create_dataset("X/A", data=[1, 2, 3]) + f.create_dataset("Y/B", data=[4, 5, 6]) + f['X'].attrs['attr1'] = 'value1' + h5f = h5py.File(filename, "r") + with LindiH5ZarrStore.from_file(filename, url=filename) as store: + rfs = store.to_reference_file_system() + h5f_2 = lindi.LindiH5pyFile.from_reference_file_system(rfs, mode="r+") + assert "X" in h5f_2 + assert "Y" in h5f_2 + with pytest.raises(Exception): + # This one is not expected to work. Would be difficult to + # implement since this involves low-level operations on + # LindiH5pyFile. + h5f.copy("X", h5f_2, "Z") + h5f_2.copy("X", h5f_2, "Z") + assert "Z" in h5f_2 + assert_groups_equal(h5f["X"], h5f_2["Z"]) # type: ignore + rfs_copy = store.to_reference_file_system() + h5f_3 = lindi.LindiH5pyFile.from_reference_file_system(rfs_copy, mode="r+") + assert "Z" not in h5f_3 + h5f_2.copy("X", h5f_3, "Z") + assert "Z" in h5f_3 + assert_groups_equal(h5f["X"], h5f_3["Z"]) # type: ignore + + +if __name__ == '__main__': + test_copy_dataset() diff --git a/tests/test_core.py b/tests/test_core.py index a3aacec..16be391 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,6 +4,7 @@ import tempfile import lindi from lindi import LindiH5ZarrStore +from utils import arrays_are_equal, lists_are_equal def test_variety(): @@ -25,7 +26,6 @@ def test_variety(): f["group1"].attrs["test_attr2"] = "attribute-of-group1" h5f = h5py.File(filename, "r") h5f_wrapped = lindi.LindiH5pyFile.from_h5py_file(h5f) - assert h5f_wrapped.id # for coverage with LindiH5ZarrStore.from_file(filename, url=filename) as store: rfs = store.to_reference_file_system() h5f_rfs = lindi.LindiH5pyFile.from_reference_file_system(rfs) @@ -34,12 +34,12 @@ def test_variety(): assert h5f_2.attrs["float1"] == h5f.attrs["float1"] assert h5f_2.attrs["str1"] == h5f.attrs["str1"] assert h5f_2.attrs["bytes1"] == h5f.attrs["bytes1"] - assert _lists_are_equal(h5f_2.attrs["list1"], h5f.attrs["list1"]) - assert _lists_are_equal(h5f_2.attrs["tuple1"], h5f.attrs["tuple1"]) - assert _arrays_are_equal(np.array(h5f_2.attrs["array1"]), h5f.attrs["array1"]) + assert lists_are_equal(h5f_2.attrs["list1"], h5f.attrs["list1"]) + assert lists_are_equal(h5f_2.attrs["tuple1"], h5f.attrs["tuple1"]) + assert arrays_are_equal(np.array(h5f_2.attrs["array1"]), h5f.attrs["array1"]) assert h5f_2["dataset1"].attrs["test_attr1"] == h5f["dataset1"].attrs["test_attr1"] # type: ignore assert h5f_2["dataset1"].id - assert _arrays_are_equal(h5f_2["dataset1"][()], h5f["dataset1"][()]) # type: ignore + assert arrays_are_equal(h5f_2["dataset1"][()], h5f["dataset1"][()]) # type: ignore assert h5f_2["group1"].attrs["test_attr2"] == h5f["group1"].attrs["test_attr2"] # type: ignore target_1 = h5f[h5f.attrs["dataset1_ref"]] target_2 = h5f_2[h5f_2.attrs["dataset1_ref"]] @@ -86,17 +86,17 @@ def 
test_soft_links(): assert isinstance(ds1, h5py.Dataset) ds2 = h5f_2['soft_link']['dataset1'] # type: ignore assert isinstance(ds2, h5py.Dataset) - assert _arrays_are_equal(ds1[()], ds2[()]) + assert arrays_are_equal(ds1[()], ds2[()]) ds1 = h5f['soft_link/dataset1'] assert isinstance(ds1, h5py.Dataset) ds2 = h5f_2['soft_link/dataset1'] assert isinstance(ds2, h5py.Dataset) - assert _arrays_are_equal(ds1[()], ds2[()]) + assert arrays_are_equal(ds1[()], ds2[()]) ds1 = h5f['group_target/dataset1'] assert isinstance(ds1, h5py.Dataset) ds2 = h5f_2['group_target/dataset1'] assert isinstance(ds2, h5py.Dataset) - assert _arrays_are_equal(ds1[()], ds2[()]) + assert arrays_are_equal(ds1[()], ds2[()]) def test_arrays_of_compound_dtype(): @@ -104,9 +104,12 @@ def test_arrays_of_compound_dtype(): filename = f"{tmpdir}/test.h5" with h5py.File(filename, "w") as f: dt = np.dtype([("x", "i4"), ("y", "f8")]) - f.create_dataset("dataset1", data=[(1, 3.14), (2, 6.28)], dtype=dt) + dataset1 = f.create_dataset("dataset1", data=[(1, 3.14), (2, 6.28)], dtype=dt) dt = np.dtype([("a", "i4"), ("b", "f8"), ("c", "S10")]) - f.create_dataset("dataset2", data=[(1, 3.14, "abc"), (2, 6.28, "def")], dtype=dt) + dataset2 = f.create_dataset("dataset2", data=[(1, 3.14, "abc"), (2, 6.28, "def")], dtype=dt) + # how about references! + dt = np.dtype([("a", "i4"), ("b", "f8"), ("c", h5py.special_dtype(ref=h5py.Reference))]) + f.create_dataset("dataset3", data=[(1, 3.14, dataset1.ref), (2, 6.28, dataset2.ref)], dtype=dt) h5f = h5py.File(filename, "r") with LindiH5ZarrStore.from_file(filename, url=filename) as store: rfs = store.to_reference_file_system() @@ -116,16 +119,27 @@ def test_arrays_of_compound_dtype(): ds1_2 = h5f_2['dataset1'] assert isinstance(ds1_2, h5py.Dataset) assert ds1_1.dtype == ds1_2.dtype - assert _arrays_are_equal(ds1_1['x'][()], ds1_2['x'][()]) # type: ignore - assert _arrays_are_equal(ds1_1['y'][()], ds1_2['y'][()]) # type: ignore + assert arrays_are_equal(ds1_1['x'][()], ds1_2['x'][()]) # type: ignore + assert arrays_are_equal(ds1_1['y'][()], ds1_2['y'][()]) # type: ignore ds2_1 = h5f['dataset2'] assert isinstance(ds2_1, h5py.Dataset) ds2_2 = h5f_2['dataset2'] assert isinstance(ds2_2, h5py.Dataset) assert ds2_1.dtype == ds2_2.dtype - assert _arrays_are_equal(ds2_1['a'][()], ds2_2['a'][()]) # type: ignore - assert _arrays_are_equal(ds2_1['b'][()], ds2_2['b'][()]) # type: ignore - assert _arrays_are_equal(ds2_1['c'][()], ds2_2['c'][()]) # type: ignore + assert arrays_are_equal(ds2_1['a'][()], ds2_2['a'][()]) # type: ignore + assert arrays_are_equal(ds2_1['b'][()], ds2_2['b'][()]) # type: ignore + assert arrays_are_equal(ds2_1['c'][()], ds2_2['c'][()]) # type: ignore + ds3_1 = h5f['dataset3'] + assert isinstance(ds3_1, h5py.Dataset) + ds3_2 = h5f_2['dataset3'] + assert isinstance(ds3_2, h5py.Dataset) + assert ds3_1.dtype == ds3_2.dtype + assert ds3_1.dtype['c'] == ds3_2.dtype['c'] + assert ds3_2.dtype['c'] == h5py.special_dtype(ref=h5py.Reference) + target1 = h5f[ds3_1['c'][0]] + assert isinstance(target1, h5py.Dataset) + target2 = h5f_2[ds3_2['c'][0]] + assert isinstance(target2, h5py.Dataset) def test_arrays_of_compound_dtype_with_references(): @@ -144,7 +158,7 @@ def test_arrays_of_compound_dtype_with_references(): ds1_2 = h5f_2['dataset1'] assert isinstance(ds1_2, h5py.Dataset) assert ds1_1.dtype == ds1_2.dtype - assert _arrays_are_equal(ds1_1['x'][()], ds1_2['x'][()]) # type: ignore + assert arrays_are_equal(ds1_1['x'][()], ds1_2['x'][()]) # type: ignore ref1 = ds1_1['y'][0] ref2 = ds1_2['y'][0] assert 
isinstance(ref1, h5py.Reference) @@ -153,7 +167,7 @@ def test_arrays_of_compound_dtype_with_references(): assert isinstance(target1, h5py.Dataset) target2 = h5f_2[ref2] assert isinstance(target2, h5py.Dataset) - assert _arrays_are_equal(target1[()], target2[()]) + assert arrays_are_equal(target1[()], target2[()]) def test_scalar_arrays(): @@ -206,7 +220,7 @@ def test_arrays_of_strings(): assert isinstance(X1, h5py.Dataset) X2 = h5f_2['X'] assert isinstance(X2, h5py.Dataset) - assert _lists_are_equal(X1[:].tolist(), [x.encode() for x in X2[:]]) # type: ignore + assert lists_are_equal(X1[:].tolist(), [x.encode() for x in X2[:]]) # type: ignore def test_numpy_arrays(): @@ -260,13 +274,13 @@ def test_nan_inf_attributes(): assert isinstance(nanval, float) and np.isnan(nanval) assert X1.attrs["inf"] == np.inf assert X1.attrs["ninf"] == -np.inf - assert _lists_are_equal(X1.attrs['float_list'], [np.nan, np.inf, -np.inf, 23]) + assert lists_are_equal(X1.attrs['float_list'], [np.nan, np.inf, -np.inf, 23]) nanval = X2.attrs["nan"] assert isinstance(nanval, float) and np.isnan(nanval) assert X2.attrs["inf"] == np.inf assert X2.attrs["ninf"] == -np.inf - assert _lists_are_equal(X2.attrs['float_list'], [np.nan, np.inf, -np.inf, 23]) + assert lists_are_equal(X2.attrs['float_list'], [np.nan, np.inf, -np.inf, 23]) for test_string in ["NaN", "Infinity", "-Infinity", "Not-illegal"]: filename = f"{tmpdir}/illegal_string.h5" @@ -294,24 +308,12 @@ def test_reference_file_system_to_file(): client = lindi.LindiH5pyFile.from_reference_file_system(rfs_fname) X = client["X"] assert isinstance(X, lindi.LindiH5pyDataset) - assert _lists_are_equal(X[()], [1, 2, 3]) + assert lists_are_equal(X[()], [1, 2, 3]) def test_lindi_reference_file_system_store(): from lindi.LindiH5pyFile.LindiReferenceFileSystemStore import LindiReferenceFileSystemStore - # test that setting items is not allowed - rfs = {"refs": {"a": "a"}} - store = LindiReferenceFileSystemStore(rfs) - with pytest.raises(Exception): - store["b"] = "b" - - # test that deleting items is not allowed - rfs = {"refs": {"a": "a"}} - store = LindiReferenceFileSystemStore(rfs) - with pytest.raises(Exception): - del store["a"] - # test for invalid rfs rfs = {"rfs_misspelled": {"a": "a"}} # misspelled with pytest.raises(Exception): @@ -347,7 +349,7 @@ def test_lindi_reference_file_system_store(): rfs = {"refs": {"a": "abc"}} store = LindiReferenceFileSystemStore(rfs) assert store.is_readable() - assert not store.is_writeable() + assert store.is_writeable() assert store.is_listable() assert not store.is_erasable() assert len(store) == 1 @@ -469,29 +471,21 @@ def test_lindi_h5_zarr_store(): assert 'scalar_dataset/1' not in store -def _lists_are_equal(a, b): - if len(a) != len(b): - return False - for aa, bb in zip(a, b): - if aa != bb: - if np.isnan(aa) and np.isnan(bb): - # nan != nan, but we want to consider them equal - continue - return False - return True - - -def _arrays_are_equal(a, b): - if a.shape != b.shape: - return False - if a.dtype != b.dtype: - return False - # if this is numeric data we need to use allclose so that we can handle NaNs - if np.issubdtype(a.dtype, np.number): - return np.allclose(a, b, equal_nan=True) - else: - return np.array_equal(a, b) +def test_numpy_array_of_byte_strings(): + with tempfile.TemporaryDirectory() as tmpdir: + filename = f"{tmpdir}/test.h5" + with h5py.File(filename, "w") as f: + f.create_dataset("X", data=np.array([b"abc", b"def", b"ghi"])) + h5f = h5py.File(filename, "r") + with LindiH5ZarrStore.from_file(filename, 
url=filename) as store: + rfs = store.to_reference_file_system() + h5f_2 = lindi.LindiH5pyFile.from_reference_file_system(rfs) + X1 = h5f['X'] + assert isinstance(X1, h5py.Dataset) + X2 = h5f_2['X'] + assert isinstance(X2, h5py.Dataset) + assert lists_are_equal(X1[:].tolist(), X2[:].tolist()) # type: ignore if __name__ == '__main__': - test_scalar_arrays() + pass diff --git a/tests/test_fletcher32.py b/tests/test_fletcher32.py new file mode 100644 index 0000000..a485e08 --- /dev/null +++ b/tests/test_fletcher32.py @@ -0,0 +1,27 @@ +import tempfile +import h5py +import lindi +import numpy as np + + +def test_fletcher32(): + with tempfile.TemporaryDirectory() as tmpdir: + filename = f'{tmpdir}/test.h5' + with h5py.File(filename, 'w') as f: + dset = f.create_dataset('dset', shape=(100,), dtype='i4', fletcher32=True) + dset[...] = range(100) + assert dset.fletcher32 + store = lindi.LindiH5ZarrStore.from_file(filename, url=filename) + rfs = store.to_reference_file_system() + client = lindi.LindiH5pyFile.from_reference_file_system(rfs) + ds0 = client['dset'] + assert isinstance(ds0, h5py.Dataset) + assert ds0.fletcher32 + data = ds0[...] + assert isinstance(data, np.ndarray) + assert data.dtype == np.dtype('int32') + assert np.all(data == np.arange(100)) + + +if __name__ == '__main__': + test_fletcher32() diff --git a/tests/test_remote_data.py b/tests/test_remote_data.py index 27af554..8ecdbea 100644 --- a/tests/test_remote_data.py +++ b/tests/test_remote_data.py @@ -1,6 +1,7 @@ import json import pytest import lindi +from utils import arrays_are_equal @pytest.mark.network @@ -34,7 +35,7 @@ def test_remote_data_2(): import pynwb # Define the URL for a remote .zarr.json file - url = 'https://kerchunk.neurosift.org/dandi/dandisets/000939/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/zarr.json' + url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/zarr.json' # Load the h5py-like client from the reference file system client = lindi.LindiH5pyFile.from_reference_file_system(url) @@ -43,3 +44,52 @@ def test_remote_data_2(): with pynwb.NWBHDF5IO(file=client, mode="r") as io: nwbfile = io.read() print(nwbfile) + + +@pytest.mark.network +def test_remote_data_rfs_copy(): + # Test that we can copy datasets and groups from one reference file system to another + # and the data itself is not copied, only the references. 
+ url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/zarr.json' + + client = lindi.LindiH5pyFile.from_reference_file_system(url) + + rfs2 = {'refs': { + '.zgroup': '{"zarr_format": 2}', + }} + client2 = lindi.LindiH5pyFile.from_reference_file_system(rfs2) + + # This first dataset is a 2D array with chunks + ds = client['processing/behavior/Position/position/data'] + assert isinstance(ds, lindi.LindiH5pyDataset) + assert ds.shape == (494315, 2) + + client.copy('processing/behavior/Position/position/data', client2, 'copied_data1') + aa = rfs2['refs']['copied_data1/.zarray'] + assert isinstance(aa, str) + assert 'copied_data1/0.0' in rfs2['refs'] + bb = rfs2['refs']['copied_data1/0.0'] + assert isinstance(bb, list) # make sure it is a reference, not the actual data + + ds2 = client2['copied_data1'] + assert isinstance(ds2, lindi.LindiH5pyDataset) + assert arrays_are_equal(ds[()], ds2[()]) # make sure the data is the same + + # This next dataset has an _EXTERNAL_ARRAY_LINK which means it has a pointer + # to a dataset in a remote h5py + ds = client['processing/ecephys/LFP/LFP/data'] + assert isinstance(ds, lindi.LindiH5pyDataset) + assert ds.shape == (17647830, 64) + + client.copy('processing/ecephys/LFP/LFP/data', client2, 'copied_data2') + aa = rfs2['refs']['copied_data2/.zarray'] + assert isinstance(aa, str) + assert 'copied_data2/0.0' not in rfs2['refs'] # make sure the chunks were not copied + + ds2 = client2['copied_data2'] + assert isinstance(ds2, lindi.LindiH5pyDataset) + assert arrays_are_equal(ds[100000:100010], ds2[100000:100010]) + + +if __name__ == "__main__": + test_remote_data_rfs_copy() diff --git a/tests/test_store.py b/tests/test_store.py index e7688f5..aff4667 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -1,6 +1,7 @@ import h5py import tempfile import lindi +from utils import lists_are_equal def test_store(): @@ -14,11 +15,11 @@ def test_store(): with lindi.LindiH5ZarrStore.from_file(filename, url=filename) as store: store.to_file(f"{tmpdir}/test.zarr.json") # for coverage a = store.listdir('') - assert _lists_are_equal(a, ['dataset1', 'group1'], ordered=False) + assert _lists_are_equal_as_sets(a, ['dataset1', 'group1']) b = store.listdir('group1') - assert _lists_are_equal(b, ['group2', 'dataset2'], ordered=False) + assert _lists_are_equal_as_sets(b, ['group2', 'dataset2']) c = store.listdir('group1/group2') - assert _lists_are_equal(c, [], ordered=False) + assert _lists_are_equal_as_sets(c, []) assert '.zattrs' in store assert '.zgroup' in store assert 'dataset1' not in store @@ -41,18 +42,10 @@ def test_store(): assert 'group1/dataset2/0' in store client = lindi.LindiH5pyFile.from_zarr_store(store) X = client["dataset1"][:] # type: ignore - assert _lists_are_equal(X, [1, 2, 3], ordered=True) + assert lists_are_equal(X, [1, 2, 3]) Y = client["group1/dataset2"][:] # type: ignore - assert _lists_are_equal(Y, [4, 5, 6], ordered=True) + assert lists_are_equal(Y, [4, 5, 6]) -def _lists_are_equal(a, b, ordered: bool): - if ordered: - if len(a) != len(b): - return False - for i in range(len(a)): - if a[i] != b[i]: - return False - return True - else: - return set(a) == set(b) +def _lists_are_equal_as_sets(a, b): + return set(a) == set(b) diff --git a/tests/test_zarr_write.py b/tests/test_zarr_write.py new file mode 100644 index 0000000..a205502 --- /dev/null +++ b/tests/test_zarr_write.py @@ -0,0 +1,105 @@ +import tempfile +import numpy as np +import zarr +import h5py +import lindi +import pytest +from utils 
import assert_groups_equal + + +def test_zarr_write(): + with tempfile.TemporaryDirectory() as tmpdir: + dirname = f'{tmpdir}/test.zarr' + store = zarr.DirectoryStore(dirname) + zarr.group(store=store) + with lindi.LindiH5pyFile.from_zarr_store(store, mode='r+') as h5f_backed_by_zarr: + write_example_h5_data(h5f_backed_by_zarr) + + store2 = zarr.DirectoryStore(dirname) + with lindi.LindiH5pyFile.from_zarr_store(store2) as h5f_backed_by_zarr: + compare_example_h5_data(h5f_backed_by_zarr, tmpdir=tmpdir) + + +def test_require_dataset(): + with tempfile.TemporaryDirectory() as tmpdir: + dirname = f'{tmpdir}/test.zarr' + store = zarr.DirectoryStore(dirname) + zarr.group(store=store) + with lindi.LindiH5pyFile.from_zarr_store(store, mode='r+') as h5f_backed_by_zarr: + h5f_backed_by_zarr.create_dataset('dset_int8', data=np.array([1, 2, 3], dtype=np.int8)) + h5f_backed_by_zarr.create_dataset('dset_int16', data=np.array([1, 2, 3], dtype=np.int16)) + h5f_backed_by_zarr.require_dataset('dset_int8', shape=(3,), dtype=np.int8) + with pytest.raises(Exception): + h5f_backed_by_zarr.require_dataset('dset_int8', shape=(4,), dtype=np.int8) + with pytest.raises(Exception): + h5f_backed_by_zarr.require_dataset('dset_int8', shape=(3,), dtype=np.int16, exact=True) + h5f_backed_by_zarr.require_dataset('dset_int8', shape=(3,), dtype=np.int16, exact=False) + with pytest.raises(Exception): + h5f_backed_by_zarr.require_dataset('dset_int16', shape=(3,), dtype=np.int8, exact=False) + ds = h5f_backed_by_zarr.require_dataset('dset_float32', shape=(3,), dtype=np.float32) + ds[:] = np.array([1.1, 2.2, 3.3]) + with pytest.raises(Exception): + h5f_backed_by_zarr.require_dataset('dset_float32', shape=(3,), dtype=np.float64, exact=True) + + +def write_example_h5_data(h5f: h5py.File): + h5f.attrs['attr_str'] = 'hello' + h5f.attrs['attr_int'] = 42 + h5f.attrs['attr_float'] = 3.14 + h5f.attrs['attr_bool'] = True + h5f.attrs['attr_list_str'] = ['a', 'b', 'c'] + h5f.attrs['attr_list_int'] = [1, 2, 3] + h5f.attrs['attr_list_float'] = [1.1, 2.2, 3.3] + h5f.attrs['attr_list_bool'] = [True, False, True] + with pytest.raises(Exception): + h5f.attrs['attr_list_mixed'] = [1, 2.2, 'c', True] + h5f.attrs['2d_array'] = np.array([[1, 2], [3, 4]]) + h5f.create_dataset('dset_int8', data=np.array([1, 2, 3], dtype=np.int8)) + h5f.create_dataset('dset_int16', data=np.array([1, 2, 3], dtype=np.int16)) + h5f.create_dataset('dset_int32', data=np.array([1, 2, 3], dtype=np.int32)) + h5f.create_dataset('dset_int64', data=np.array([1, 2, 3], dtype=np.int64)) + h5f.create_dataset('dset_uint8', data=np.array([1, 2, 3], dtype=np.uint8)) + h5f.create_dataset('dset_uint16', data=np.array([1, 2, 3], dtype=np.uint16)) + h5f.create_dataset('dset_uint32', data=np.array([1, 2, 3], dtype=np.uint32)) + h5f.create_dataset('dset_uint64', data=np.array([1, 2, 3], dtype=np.uint64)) + h5f.create_dataset('dset_float32', data=np.array([1, 2, 3], dtype=np.float32)) + h5f.create_dataset('dset_float64', data=np.array([1, 2, 3], dtype=np.float64)) + h5f.create_dataset('dset_bool', data=np.array([True, False, True], dtype=np.bool_)) + + group1 = h5f.create_group('group1') + group1.attrs['attr_str'] = 'hello' + group1.attrs['attr_int'] = 42 + group1.create_dataset('dset_with_nan', data=np.array([1, np.nan, 3], dtype=np.float64)) + group1.create_dataset('dset_with_inf', data=np.array([np.inf, 6, -np.inf], dtype=np.float64)) + + compound_dtype = np.dtype([('x', np.int32), ('y', np.float64)]) + group1.create_dataset('dset_compound', data=np.array([(1, 2.2), (3, 4.4)], 
dtype=compound_dtype)) + + group_to_delete = h5f.create_group('group_to_delete') + group_to_delete.attrs['attr_str'] = 'hello' + group_to_delete.attrs['attr_int'] = 42 + group_to_delete.create_dataset('dset_to_delete', data=np.array([1, 2, 3], dtype=np.int8)) + del h5f['group_to_delete'] + + another_group_to_delete = group1.create_group('another_group_to_delete') + another_group_to_delete.attrs['attr_str'] = 'hello' + another_group_to_delete.attrs['attr_int'] = 42 + another_group_to_delete.create_dataset('dset_to_delete', data=np.array([1, 2, 3], dtype=np.int8)) + del group1['another_group_to_delete'] + + yet_another_group_to_delete = group1.create_group('yet_another_group_to_delete') + yet_another_group_to_delete.attrs['attr_str'] = 'hello' + yet_another_group_to_delete.attrs['attr_int'] = 42 + yet_another_group_to_delete.create_dataset('dset_to_delete', data=np.array([1, 2, 3], dtype=np.int8)) + del h5f['group1/yet_another_group_to_delete'] + + +def compare_example_h5_data(h5f: h5py.File, tmpdir: str): + with h5py.File(f'{tmpdir}/for_comparison.h5', 'w') as h5f2: + write_example_h5_data(h5f2) + with h5py.File(f'{tmpdir}/for_comparison.h5', 'r') as h5f2: + assert_groups_equal(h5f, h5f2) + + +if __name__ == '__main__': + test_require_dataset() diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..eb07b38 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,82 @@ +from typing import Union +import numpy as np +import h5py +from lindi.conversion.attr_conversion import h5_to_zarr_attr + + +def assert_groups_equal(h5f: h5py.Group, h5f2: h5py.Group): + print(f'Comparing groups: {h5f.name}') + assert_attrs_equal(h5f, h5f2) + for k in h5f.keys(): + X1 = h5f[k] + X2 = h5f2[k] + if isinstance(X1, h5py.Group): + assert isinstance(X2, h5py.Group) + assert_groups_equal(X1, X2) + elif isinstance(X1, h5py.Dataset): + assert isinstance(X2, h5py.Dataset) + assert_datasets_equal(X1, X2) + else: + raise Exception(f'Unexpected type: {type(X1)}') + + for k in h5f2.keys(): + if k not in h5f: + raise Exception(f'Key {k} not found in h5f') + + +def assert_datasets_equal(h5d1: h5py.Dataset, h5d2: h5py.Dataset): + print(f'Comparing datasets: {h5d1.name}') + assert h5d1.shape == h5d2.shape, f'h5d1.shape: {h5d1.shape}, h5d2.shape: {h5d2.shape}' + assert h5d1.dtype == h5d2.dtype, f'h5d1.dtype: {h5d1.dtype}, h5d2.dtype: {h5d2.dtype}' + if h5d1.dtype.kind == 'V': + for name in h5d1.dtype.names: + data1 = h5d1[name][()] + data2 = h5d2[name][()] + assert arrays_are_equal(data1, data2), f'data1: {data1}, data2: {data2}' + else: + data1 = h5d1[()] + data2 = h5d2[()] + assert arrays_are_equal(data1, data2), f'data1: {data1}, data2: {data2}' + + +def arrays_are_equal(a, b): + if a.shape != b.shape: + return False + if a.dtype != b.dtype: + return False + # if this is numeric data we need to use allclose so that we can handle NaNs + if np.issubdtype(a.dtype, np.number): + return np.allclose(a, b, equal_nan=True) + else: + return np.array_equal(a, b) + + +def assert_attrs_equal( + h5f1: Union[h5py.Group, h5py.Dataset], + h5f2: Union[h5py.Group, h5py.Dataset] +): + attrs1 = h5f1.attrs + attrs2 = h5f2.attrs + keys1 = set(attrs1.keys()) + keys2 = set(attrs2.keys()) + assert keys1 == keys2, f'keys1: {keys1}, keys2: {keys2}' + for k1, v1 in attrs1.items(): + assert_attr_equal(v1, attrs2[k1]) + + +def assert_attr_equal(v1, v2): + v1_normalized = h5_to_zarr_attr(v1, h5f=None) + v2_normalized = h5_to_zarr_attr(v2, h5f=None) + assert v1_normalized == v2_normalized, f'v1_normalized: {v1_normalized}, v2_normalized: 
{v2_normalized}'
+
+
+def lists_are_equal(a, b):
+    if len(a) != len(b):
+        return False
+    for aa, bb in zip(a, b):
+        if aa != bb:
+            if isinstance(aa, float) and isinstance(bb, float) and np.isnan(aa) and np.isnan(bb):
+                # nan != nan, but we want to consider them equal
+                continue
+            return False
+    return True
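# A quick usage sketch of the comparison helpers defined above (run from the tests
# directory so that "utils" is importable); the values are illustrative only.
import numpy as np

from utils import arrays_are_equal, lists_are_equal

assert lists_are_equal([1.0, float("nan"), 3.0], [1.0, float("nan"), 3.0])  # NaNs compare as equal
assert not lists_are_equal([1.0, 2.0], [1.0, 3.0])
assert arrays_are_equal(np.array([1.0, np.nan]), np.array([1.0, np.nan]))  # uses allclose with equal_nan=True
assert not arrays_are_equal(np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int64))  # dtypes must match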