Skip to content

Commit

Permalink
fix: Fix get_ranges for empty channels in Image (#1136)
Browse files Browse the repository at this point in the history
fix PARTSEG-TZ

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced class initialization with list comprehension for better range
adjustments.
  - Introduced new methods for internal adjustments and mask fitting.

- **Improvements**
- Updated method signatures to use modern type hinting for improved code
clarity and maintainability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
Czaki authored Jul 10, 2024
1 parent cc9d57b commit d5f7ef8
Showing 1 changed file with 48 additions and 43 deletions.
91 changes: 48 additions & 43 deletions package/PartSegImage/image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import re
import typing
import warnings
from collections.abc import Iterable
from contextlib import suppress
from typing import Union

import numpy as np

Expand Down Expand Up @@ -33,8 +34,8 @@ def minimal_dtype(val: int):

def reduce_array(
array: np.ndarray,
components: typing.Optional[typing.Collection[int]] = None,
max_val: typing.Optional[int] = None,
components: typing.Collection[int] | None = None,
max_val: int | None = None,
dtype=None,
) -> np.ndarray:
"""
Expand Down Expand Up @@ -102,12 +103,12 @@ def __init__(
data: _IMAGE_DATA,
image_spacing: Spacing,
file_path=None,
mask: typing.Union[None, np.ndarray] = None,
mask: None | np.ndarray = None,
default_coloring=None,
ranges=None,
channel_names=None,
axes_order: typing.Optional[str] = None,
shift: typing.Optional[Spacing] = None,
axes_order: str | None = None,
shift: Spacing | None = None,
name: str = "",
):
# TODO add time distance to image spacing
Expand Down Expand Up @@ -144,18 +145,26 @@ def __init__(
self.default_coloring = [np.array(x) for x in default_coloring]

self._channel_names = self._prepare_channel_names(channel_names, self.channels)

self.ranges = self._adjust_ranges(ranges, self._channel_arrays)
self._mask_array = self._fit_mask(mask, data, axes_order)

@staticmethod
def _adjust_ranges(
ranges: list[tuple[float, float]] | None, channel_arrays: list[np.ndarray]
) -> list[tuple[float, float]]:
if ranges is None:
self.ranges = list(
zip((np.min(c) for c in self._channel_arrays), (np.max(c) for c in self._channel_arrays))
)
else:
self.ranges = ranges
self._mask_array = self._prepare_mask(mask, data, axes_order)
if self._mask_array is not None:
self._mask_array = self.fit_mask_to_image(self._mask_array)
ranges = list(zip((np.min(c) for c in channel_arrays), (np.max(c) for c in channel_arrays)))
return [(min_val, max_val) if (min_val != max_val) else (min_val, min_val + 1) for (min_val, max_val) in ranges]

def _fit_mask(self, mask, data, axes_order):
mask_array = self._prepare_mask(mask, data, axes_order)
if mask_array is not None:
mask_array = self.fit_mask_to_image(mask_array)
return mask_array

@classmethod
def _prepare_mask(cls, mask, data, axes_order) -> typing.Optional[np.ndarray]:
def _prepare_mask(cls, mask, data, axes_order) -> np.ndarray | None:
if mask is None:
return None

Expand All @@ -170,7 +179,7 @@ def _prepare_mask(cls, mask, data, axes_order) -> typing.Optional[np.ndarray]:
return cls.reorder_axes(mask, axes_order.replace("C", ""))

@staticmethod
def _prepare_channel_names(channel_names, channels_num) -> typing.List[str]:
def _prepare_channel_names(channel_names, channels_num) -> list[str]:
default_channel_names = [f"channel {i + 1}" for i in range(channels_num)]
if isinstance(channel_names, str):
channel_names = [channel_names]
Expand All @@ -182,9 +191,7 @@ def _prepare_channel_names(channel_names, channels_num) -> typing.List[str]:
return channel_names_list[:channels_num]

@classmethod
def _split_data_on_channels(
cls, data: typing.Union[np.ndarray, typing.List[np.ndarray]], axes_order: str
) -> typing.List[np.ndarray]:
def _split_data_on_channels(cls, data: np.ndarray | list[np.ndarray], axes_order: str) -> list[np.ndarray]:
if isinstance(data, list) and not axes_order.startswith("C"): # pragma: no cover
raise ValueError("When passing data as list of numpy arrays then Channel must be first axis.")
if "C" not in axes_order:
Expand All @@ -199,7 +206,7 @@ def _split_data_on_channels(

if not isinstance(data, np.ndarray):
raise TypeError("If `data` is list of arrays then `axes_order` must start with `C`") # pragma: no cover
pos: typing.List[typing.Union[slice, int]] = [slice(None) for _ in range(data.ndim)]
pos: list[slice | int] = [slice(None) for _ in range(data.ndim)]
c_pos = axes_order.index("C")
res = []
for i in range(data.shape[c_pos]):
Expand All @@ -208,9 +215,7 @@ def _split_data_on_channels(
return res

@staticmethod
def _merge_channel_names(
base_channel_names: typing.List[str], new_channel_names: typing.List[str]
) -> typing.List[str]:
def _merge_channel_names(base_channel_names: list[str], new_channel_names: list[str]) -> list[str]:
base_channel_names = base_channel_names[:]
reg = re.compile(r"channel \d+")
for name in new_channel_names:
Expand All @@ -228,7 +233,7 @@ def _merge_channel_names(
base_channel_names.append(new_name)
return base_channel_names

def merge(self, image: "Image", axis: str) -> "Image":
def merge(self, image: Image, axis: str) -> Image:
"""
Produce new image merging image data along given axis. All metadata
are obtained from self.
Expand Down Expand Up @@ -256,7 +261,7 @@ def merge(self, image: "Image", axis: str) -> "Image":
return self.substitute(data=data, ranges=self.ranges + image.ranges, channel_names=channel_names)

@property
def channel_names(self) -> typing.List[str]:
def channel_names(self) -> list[str]:
return self._channel_names[:]

@property
Expand Down Expand Up @@ -333,7 +338,7 @@ def substitute(
default_coloring=None,
ranges=None,
channel_names=None,
) -> "Image":
) -> Image:
"""Create copy of image with substitution of not None elements"""
data = self._channel_arrays if data is None else data
image_spacing = self._image_spacing if image_spacing is None else image_spacing
Expand All @@ -353,7 +358,7 @@ def substitute(
axes_order=self.axis_order,
)

def set_mask(self, mask: typing.Optional[np.ndarray], axes: typing.Optional[str] = None):
def set_mask(self, mask: np.ndarray | None, axes: str | None = None):
"""
Set mask for image, check if it has proper shape.
Expand All @@ -374,7 +379,7 @@ def get_data(self) -> np.ndarray:
return self._channel_arrays[0]

@property
def mask(self) -> typing.Optional[np.ndarray]:
def mask(self) -> np.ndarray | None:
return self._mask_array[:] if self._mask_array is not None else None

@staticmethod
Expand Down Expand Up @@ -430,7 +435,7 @@ def get_image_for_save(self) -> np.ndarray:
)
return self._reorder_axes(self._channel_arrays[0], self.axis_order, "TZCYX")

def get_mask_for_save(self) -> typing.Optional[np.ndarray]:
def get_mask_for_save(self) -> np.ndarray | None:
"""
:return: if image has mask then return mask with axes in proper order
"""
Expand Down Expand Up @@ -469,7 +474,7 @@ def times(self) -> int:
return self._channel_arrays[0].shape[self.time_pos]

@property
def plane_shape(self) -> typing.Tuple[int, int]:
def plane_shape(self) -> tuple[int, int]:
"""y,x size of image"""
return self._channel_arrays[0].shape[self.y_pos], self._channel_arrays[0].shape[self.x_pos]

Expand All @@ -487,15 +492,15 @@ def swap_time_and_stack(self):
return self.substitute(data=self._image_data_normalize(image_array_list))

@classmethod
def get_axis_positions(cls) -> typing.Dict[str, int]:
def get_axis_positions(cls) -> dict[str, int]:
"""
:return: dict with mapping axis to its position
:rtype: dict
"""
return {letter: i for i, letter in enumerate(cls.axis_order)}

@classmethod
def get_array_axis_positions(cls) -> typing.Dict[str, int]:
def get_array_axis_positions(cls) -> dict[str, int]:
"""
:return: dict with mapping axis to its position for array fitted to image
:rtype: dict
Expand All @@ -510,7 +515,7 @@ def get_data_by_axis(self, **kwargs) -> np.ndarray:
:return:
:rtype:
"""
slices: typing.List[typing.Union[int, slice]] = [slice(None) for _ in range(len(self.array_axis_order))]
slices: list[int | slice] = [slice(None) for _ in range(len(self.array_axis_order))]
axis_pos = self.get_array_axis_positions()
if "c" in kwargs:
kwargs["C"] = kwargs.pop("c")
Expand All @@ -533,7 +538,7 @@ def get_data_by_axis(self, **kwargs) -> np.ndarray:
return self._channel_arrays[channel][slices_t]
return np.stack([x[slices_t] for x in self._channel_arrays[channel]], axis=axis_order.index("C"))

def clip_array(self, array: np.ndarray, **kwargs: typing.Union[int, slice]) -> np.ndarray:
def clip_array(self, array: np.ndarray, **kwargs: int | slice) -> np.ndarray:
"""
Clip array by axis. Axis is selected by single letter from :py:attr:`axis_order`
Expand All @@ -542,14 +547,14 @@ def clip_array(self, array: np.ndarray, **kwargs: typing.Union[int, slice]) -> n
:return: clipped array
"""
array = self.fit_array_to_image(array)
slices: typing.List[typing.Union[int, slice]] = [slice(None) for _ in range(len(self.array_axis_order))]
slices: list[int | slice] = [slice(None) for _ in range(len(self.array_axis_order))]
axis_pos = self.get_array_axis_positions()
for name in kwargs:
if (n := name.upper()) in axis_pos:
slices[axis_pos[n]] = kwargs[name]
return array[tuple(slices)]

def get_channel(self, num: Union[int, str, Channel]) -> np.ndarray:
def get_channel(self, num: int | str | Channel) -> np.ndarray:
"""
Alias for :py:func:`get_sub_data` with argument ``c=num``
Expand Down Expand Up @@ -611,7 +616,7 @@ def set_spacing(self, value: Spacing):
self._image_spacing = tuple(value)

@staticmethod
def _frame_array(array: typing.Optional[np.ndarray], index_to_add: typing.List[int], frame=FRAME_THICKNESS):
def _frame_array(array: np.ndarray | None, index_to_add: list[int], frame=FRAME_THICKNESS):
if array is None: # pragma: no cover
return array
result_shape = list(array.shape)
Expand All @@ -626,7 +631,7 @@ def _frame_array(array: typing.Optional[np.ndarray], index_to_add: typing.List[i
return data

@staticmethod
def calc_index_to_frame(array_axis: str, important_axis: str) -> typing.List[int]:
def calc_index_to_frame(array_axis: str, important_axis: str) -> list[int]:
"""
calculate in which axis frame should be added
Expand All @@ -650,15 +655,15 @@ def _frame_cut_area(self, cut_area: typing.Iterable[slice], frame: int):

def _cut_image_slices(
self, cut_area: typing.Iterable[slice], frame: int
) -> typing.Tuple[typing.List[np.ndarray], typing.Optional[np.ndarray]]:
) -> tuple[list[np.ndarray], np.ndarray | None]:
new_mask = None
cut_area = self._frame_cut_area(cut_area, frame)
new_image = [x[tuple(cut_area)] for x in self._channel_arrays]
if self._mask_array is not None:
new_mask = self._mask_array[tuple(cut_area)]
return new_image, new_mask

def _roi_to_slices(self, roi: np.ndarray) -> typing.List[slice]:
def _roi_to_slices(self, roi: np.ndarray) -> list[slice]:
cut_area = self.fit_array_to_image(roi)
points = np.nonzero(cut_area)
lower_bound = np.min(points, axis=1)
Expand Down Expand Up @@ -689,11 +694,11 @@ def _cut_with_roi(self, cut_area: np.ndarray, replace_mask: bool, frame: int):

def cut_image(
self,
cut_area: typing.Union[np.ndarray, typing.Iterable[slice]],
cut_area: np.ndarray | typing.Iterable[slice],
replace_mask=False,
frame: int = FRAME_THICKNESS,
zero_out_cut_area: bool = True,
) -> "Image":
) -> Image:
"""
Create new image base on mask or list of slices
:param bool replace_mask: if cut area is represented by mask array,
Expand Down Expand Up @@ -771,7 +776,7 @@ def get_um_shift(self) -> Spacing:
"""image spacing in micrometers"""
return tuple(float(x * 10**6) for x in self.shift)

def get_ranges(self) -> typing.List[typing.Tuple[float, float]]:
def get_ranges(self) -> list[tuple[float, float]]:
"""image brightness ranges for each channel"""
return self.ranges[:]

Expand Down

0 comments on commit d5f7ef8

Please sign in to comment.