Skip to content

Commit

Permalink
fix: Workaround for lack of zsd support in czifile (#1142)
Browse files Browse the repository at this point in the history
Workaround for bug reported cgohlke/czifile#10

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
  - Added support for new compression formats in image reading.
- Introduced `max_workers` parameter for improved image processing
performance.

- **Tests**
  - Added unit tests for reading compressed CZI files.
  - Implemented fixture to set maximum workers for CZI processing.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

fix PARTSEG-V0
  • Loading branch information
Czaki authored Jul 15, 2024
1 parent e13f990 commit 052bf57
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 4 deletions.
80 changes: 76 additions & 4 deletions package/PartSegImage/image_reader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import inspect
import os.path
import typing
from abc import abstractmethod
from contextlib import suppress
from importlib.metadata import version
from io import BytesIO
from pathlib import Path
from threading import Lock

import imagecodecs
import numpy as np
import tifffile
from czifile.czifile import CziFile
from czifile.czifile import DECOMPRESS, CziFile
from defusedxml import ElementTree
from oiffile import OifFile
from packaging.version import parse as parse_version

from PartSegImage.image import Image

Expand All @@ -21,6 +25,75 @@
from xml.etree.ElementTree import Element # nosec


CZI_MAX_WORKERS = None


class ZSTD1Header(typing.NamedTuple):
"""
ZSTD1 header structure
based on:
https://github.com/ZEISS/libczi/blob/4a60e22200cbf0c8ff2a59f69a81ef1b2b89bf4f/Src/libCZI/decoder_zstd.cpp#L19
"""

header_size: int
hiLoByteUnpackPreprocessing: bool


def parse_zstd1_header(data: bytes, size: int) -> ZSTD1Header: # pragma: no cover
"""
Parse ZSTD header
https://github.com/ZEISS/libczi/blob/4a60e22200cbf0c8ff2a59f69a81ef1b2b89bf4f/Src/libCZI/decoder_zstd.cpp#L84
"""
if size < 1:
return ZSTD1Header(0, False)

if data[0] == 1:
return ZSTD1Header(1, False)

if data[0] == 3 and size < 3:
return ZSTD1Header(0, False)

if data[1] == 1:
return ZSTD1Header(3, bool(data[2] & 1))

return ZSTD1Header(0, False)


def _get_dtype():
return inspect.currentframe().f_back.f_back.f_locals["de"].dtype


def decode_zstd1(data: bytes) -> np.ndarray:
"""
Decode ZSTD1 data
"""
header = parse_zstd1_header(data, len(data))
dtype = _get_dtype()
if header.hiLoByteUnpackPreprocessing:
array_ = np.fromstring(imagecodecs.zstd_decode(data[header.header_size :]), np.uint8)
half_size = array_.size // 2
array = np.empty(half_size, np.uint16)
array[:] = array_[:half_size] + (array_[half_size:].astype(np.uint16) << 8)
array = array.view(dtype)
else:
array = np.fromstring(imagecodecs.zstd_decode(data[header.header_size :]), dtype)
return array


def decode_zstd0(data: bytes) -> np.ndarray:
"""
Decode ZSTD0 data
"""
dtype = _get_dtype()
return np.fromstring(imagecodecs.zstd_decode(data), dtype)


if parse_version(version("czifile")) == parse_version("2019.7.2"):
DECOMPRESS[5] = decode_zstd0
DECOMPRESS[6] = decode_zstd1


def _empty(_, __):
"""Empty function for callback"""

Expand Down Expand Up @@ -146,8 +219,7 @@ def update_array_shape(cls, array: np.ndarray, axes: str):
axes_li[1] = "Z"
i = 0
while i < len(axes_li):
name = axes_li[i]
if name not in final_mapping_dict and array.shape[i] == 1:
if array.shape[i] == 1:
array = array.take(0, i)
axes_li.pop(i)
else:
Expand Down Expand Up @@ -270,7 +342,7 @@ class CziImageReader(BaseImageReaderBuffer):

def read(self, image_path: typing.Union[str, BytesIO, Path], mask_path=None, ext=None) -> Image:
image_file = CziFile(image_path)
image_data = image_file.asarray()
image_data = image_file.asarray(max_workers=CZI_MAX_WORKERS)
image_data = self.update_array_shape(image_data, image_file.axes)
metadata = image_file.metadata(False)
with suppress(KeyError):
Expand Down
25 changes: 25 additions & 0 deletions package/tests/test_PartSegImage/test_image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,24 @@
import os.path
import shutil
from glob import glob
from importlib.metadata import version
from io import BytesIO

import numpy as np
import pytest
import tifffile
from packaging.version import parse as parse_version

import PartSegData
from PartSegImage import CziImageReader, GenericImageReader, Image, ObsepImageReader, OifImagReader, TiffImageReader


@pytest.fixture(autouse=True)
def _set_max_workers_czi(monkeypatch):
# set max workers to 1 to get exception in case of problems
monkeypatch.setattr("PartSegImage.image_reader.CZI_MAX_WORKERS", 1)


class TestImageClass:
def test_tiff_image_read(self):
image = TiffImageReader.read_image(PartSegData.segmentation_mask_default_image)
Expand All @@ -28,18 +36,35 @@ def test_tiff_image_read_buffer(self):

def test_czi_file_read(self, data_test_dir):
image = CziImageReader.read_image(os.path.join(data_test_dir, "test_czi.czi"))
assert np.count_nonzero(image.get_channel(0))
assert image.channels == 4
assert image.layers == 1

assert image.file_path == os.path.join(data_test_dir, "test_czi.czi")

assert np.all(np.isclose(image.spacing, (7.752248561753867e-08,) * 2))

@pytest.mark.skipif(
parse_version(version("czifile")) < parse_version("2019.7.2"),
reason="There is no patch for czifile before 2019.7.2",
)
@pytest.mark.parametrize("file_name", ["test_czi_zstd0.czi", "test_czi_zstd1.czi", "test_czi_zstd1_hilo.czi"])
def test_czi_file_read_compressed(self, data_test_dir, file_name):
image = CziImageReader.read_image(os.path.join(data_test_dir, file_name))
assert np.count_nonzero(image.get_channel(0))
assert image.channels == 4
assert image.layers == 1

assert image.file_path == os.path.join(data_test_dir, file_name)

assert np.all(np.isclose(image.spacing, (7.752248561753867e-08,) * 2))

def test_czi_file_read_buffer(self, data_test_dir):
with open(os.path.join(data_test_dir, "test_czi.czi"), "rb") as f_p:
buffer = BytesIO(f_p.read())

image = CziImageReader.read_image(buffer)
assert np.count_nonzero(image.get_channel(0))
assert image.channels == 4
assert image.layers == 1
assert image.file_path == ""
Expand Down

0 comments on commit 052bf57

Please sign in to comment.