diff --git a/package/PartSegImage/image.py b/package/PartSegImage/image.py index 25080fb4e..49aed35d4 100644 --- a/package/PartSegImage/image.py +++ b/package/PartSegImage/image.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import re import typing import warnings from collections.abc import Iterable from contextlib import suppress -from typing import Union import numpy as np @@ -33,8 +34,8 @@ def minimal_dtype(val: int): def reduce_array( array: np.ndarray, - components: typing.Optional[typing.Collection[int]] = None, - max_val: typing.Optional[int] = None, + components: typing.Collection[int] | None = None, + max_val: int | None = None, dtype=None, ) -> np.ndarray: """ @@ -102,12 +103,12 @@ def __init__( data: _IMAGE_DATA, image_spacing: Spacing, file_path=None, - mask: typing.Union[None, np.ndarray] = None, + mask: None | np.ndarray = None, default_coloring=None, ranges=None, channel_names=None, - axes_order: typing.Optional[str] = None, - shift: typing.Optional[Spacing] = None, + axes_order: str | None = None, + shift: Spacing | None = None, name: str = "", ): # TODO add time distance to image spacing @@ -144,18 +145,26 @@ def __init__( self.default_coloring = [np.array(x) for x in default_coloring] self._channel_names = self._prepare_channel_names(channel_names, self.channels) + + self.ranges = self._adjust_ranges(ranges, self._channel_arrays) + self._mask_array = self._fit_mask(mask, data, axes_order) + + @staticmethod + def _adjust_ranges( + ranges: list[tuple[float, float]] | None, channel_arrays: list[np.ndarray] + ) -> list[tuple[float, float]]: if ranges is None: - self.ranges = list( - zip((np.min(c) for c in self._channel_arrays), (np.max(c) for c in self._channel_arrays)) - ) - else: - self.ranges = ranges - self._mask_array = self._prepare_mask(mask, data, axes_order) - if self._mask_array is not None: - self._mask_array = self.fit_mask_to_image(self._mask_array) + ranges = list(zip((np.min(c) for c in channel_arrays), (np.max(c) for c in channel_arrays))) + return [(min_val, max_val) if (min_val != max_val) else (min_val, min_val + 1) for (min_val, max_val) in ranges] + + def _fit_mask(self, mask, data, axes_order): + mask_array = self._prepare_mask(mask, data, axes_order) + if mask_array is not None: + mask_array = self.fit_mask_to_image(mask_array) + return mask_array @classmethod - def _prepare_mask(cls, mask, data, axes_order) -> typing.Optional[np.ndarray]: + def _prepare_mask(cls, mask, data, axes_order) -> np.ndarray | None: if mask is None: return None @@ -170,7 +179,7 @@ def _prepare_mask(cls, mask, data, axes_order) -> typing.Optional[np.ndarray]: return cls.reorder_axes(mask, axes_order.replace("C", "")) @staticmethod - def _prepare_channel_names(channel_names, channels_num) -> typing.List[str]: + def _prepare_channel_names(channel_names, channels_num) -> list[str]: default_channel_names = [f"channel {i + 1}" for i in range(channels_num)] if isinstance(channel_names, str): channel_names = [channel_names] @@ -182,9 +191,7 @@ def _prepare_channel_names(channel_names, channels_num) -> typing.List[str]: return channel_names_list[:channels_num] @classmethod - def _split_data_on_channels( - cls, data: typing.Union[np.ndarray, typing.List[np.ndarray]], axes_order: str - ) -> typing.List[np.ndarray]: + def _split_data_on_channels(cls, data: np.ndarray | list[np.ndarray], axes_order: str) -> list[np.ndarray]: if isinstance(data, list) and not axes_order.startswith("C"): # pragma: no cover raise ValueError("When passing data as list of numpy arrays then Channel must be first axis.") if "C" not in axes_order: @@ -199,7 +206,7 @@ def _split_data_on_channels( if not isinstance(data, np.ndarray): raise TypeError("If `data` is list of arrays then `axes_order` must start with `C`") # pragma: no cover - pos: typing.List[typing.Union[slice, int]] = [slice(None) for _ in range(data.ndim)] + pos: list[slice | int] = [slice(None) for _ in range(data.ndim)] c_pos = axes_order.index("C") res = [] for i in range(data.shape[c_pos]): @@ -208,9 +215,7 @@ def _split_data_on_channels( return res @staticmethod - def _merge_channel_names( - base_channel_names: typing.List[str], new_channel_names: typing.List[str] - ) -> typing.List[str]: + def _merge_channel_names(base_channel_names: list[str], new_channel_names: list[str]) -> list[str]: base_channel_names = base_channel_names[:] reg = re.compile(r"channel \d+") for name in new_channel_names: @@ -228,7 +233,7 @@ def _merge_channel_names( base_channel_names.append(new_name) return base_channel_names - def merge(self, image: "Image", axis: str) -> "Image": + def merge(self, image: Image, axis: str) -> Image: """ Produce new image merging image data along given axis. All metadata are obtained from self. @@ -256,7 +261,7 @@ def merge(self, image: "Image", axis: str) -> "Image": return self.substitute(data=data, ranges=self.ranges + image.ranges, channel_names=channel_names) @property - def channel_names(self) -> typing.List[str]: + def channel_names(self) -> list[str]: return self._channel_names[:] @property @@ -333,7 +338,7 @@ def substitute( default_coloring=None, ranges=None, channel_names=None, - ) -> "Image": + ) -> Image: """Create copy of image with substitution of not None elements""" data = self._channel_arrays if data is None else data image_spacing = self._image_spacing if image_spacing is None else image_spacing @@ -353,7 +358,7 @@ def substitute( axes_order=self.axis_order, ) - def set_mask(self, mask: typing.Optional[np.ndarray], axes: typing.Optional[str] = None): + def set_mask(self, mask: np.ndarray | None, axes: str | None = None): """ Set mask for image, check if it has proper shape. @@ -374,7 +379,7 @@ def get_data(self) -> np.ndarray: return self._channel_arrays[0] @property - def mask(self) -> typing.Optional[np.ndarray]: + def mask(self) -> np.ndarray | None: return self._mask_array[:] if self._mask_array is not None else None @staticmethod @@ -430,7 +435,7 @@ def get_image_for_save(self) -> np.ndarray: ) return self._reorder_axes(self._channel_arrays[0], self.axis_order, "TZCYX") - def get_mask_for_save(self) -> typing.Optional[np.ndarray]: + def get_mask_for_save(self) -> np.ndarray | None: """ :return: if image has mask then return mask with axes in proper order """ @@ -469,7 +474,7 @@ def times(self) -> int: return self._channel_arrays[0].shape[self.time_pos] @property - def plane_shape(self) -> typing.Tuple[int, int]: + def plane_shape(self) -> tuple[int, int]: """y,x size of image""" return self._channel_arrays[0].shape[self.y_pos], self._channel_arrays[0].shape[self.x_pos] @@ -487,7 +492,7 @@ def swap_time_and_stack(self): return self.substitute(data=self._image_data_normalize(image_array_list)) @classmethod - def get_axis_positions(cls) -> typing.Dict[str, int]: + def get_axis_positions(cls) -> dict[str, int]: """ :return: dict with mapping axis to its position :rtype: dict @@ -495,7 +500,7 @@ def get_axis_positions(cls) -> typing.Dict[str, int]: return {letter: i for i, letter in enumerate(cls.axis_order)} @classmethod - def get_array_axis_positions(cls) -> typing.Dict[str, int]: + def get_array_axis_positions(cls) -> dict[str, int]: """ :return: dict with mapping axis to its position for array fitted to image :rtype: dict @@ -510,7 +515,7 @@ def get_data_by_axis(self, **kwargs) -> np.ndarray: :return: :rtype: """ - slices: typing.List[typing.Union[int, slice]] = [slice(None) for _ in range(len(self.array_axis_order))] + slices: list[int | slice] = [slice(None) for _ in range(len(self.array_axis_order))] axis_pos = self.get_array_axis_positions() if "c" in kwargs: kwargs["C"] = kwargs.pop("c") @@ -533,7 +538,7 @@ def get_data_by_axis(self, **kwargs) -> np.ndarray: return self._channel_arrays[channel][slices_t] return np.stack([x[slices_t] for x in self._channel_arrays[channel]], axis=axis_order.index("C")) - def clip_array(self, array: np.ndarray, **kwargs: typing.Union[int, slice]) -> np.ndarray: + def clip_array(self, array: np.ndarray, **kwargs: int | slice) -> np.ndarray: """ Clip array by axis. Axis is selected by single letter from :py:attr:`axis_order` @@ -542,14 +547,14 @@ def clip_array(self, array: np.ndarray, **kwargs: typing.Union[int, slice]) -> n :return: clipped array """ array = self.fit_array_to_image(array) - slices: typing.List[typing.Union[int, slice]] = [slice(None) for _ in range(len(self.array_axis_order))] + slices: list[int | slice] = [slice(None) for _ in range(len(self.array_axis_order))] axis_pos = self.get_array_axis_positions() for name in kwargs: if (n := name.upper()) in axis_pos: slices[axis_pos[n]] = kwargs[name] return array[tuple(slices)] - def get_channel(self, num: Union[int, str, Channel]) -> np.ndarray: + def get_channel(self, num: int | str | Channel) -> np.ndarray: """ Alias for :py:func:`get_sub_data` with argument ``c=num`` @@ -611,7 +616,7 @@ def set_spacing(self, value: Spacing): self._image_spacing = tuple(value) @staticmethod - def _frame_array(array: typing.Optional[np.ndarray], index_to_add: typing.List[int], frame=FRAME_THICKNESS): + def _frame_array(array: np.ndarray | None, index_to_add: list[int], frame=FRAME_THICKNESS): if array is None: # pragma: no cover return array result_shape = list(array.shape) @@ -626,7 +631,7 @@ def _frame_array(array: typing.Optional[np.ndarray], index_to_add: typing.List[i return data @staticmethod - def calc_index_to_frame(array_axis: str, important_axis: str) -> typing.List[int]: + def calc_index_to_frame(array_axis: str, important_axis: str) -> list[int]: """ calculate in which axis frame should be added @@ -650,7 +655,7 @@ def _frame_cut_area(self, cut_area: typing.Iterable[slice], frame: int): def _cut_image_slices( self, cut_area: typing.Iterable[slice], frame: int - ) -> typing.Tuple[typing.List[np.ndarray], typing.Optional[np.ndarray]]: + ) -> tuple[list[np.ndarray], np.ndarray | None]: new_mask = None cut_area = self._frame_cut_area(cut_area, frame) new_image = [x[tuple(cut_area)] for x in self._channel_arrays] @@ -658,7 +663,7 @@ def _cut_image_slices( new_mask = self._mask_array[tuple(cut_area)] return new_image, new_mask - def _roi_to_slices(self, roi: np.ndarray) -> typing.List[slice]: + def _roi_to_slices(self, roi: np.ndarray) -> list[slice]: cut_area = self.fit_array_to_image(roi) points = np.nonzero(cut_area) lower_bound = np.min(points, axis=1) @@ -689,11 +694,11 @@ def _cut_with_roi(self, cut_area: np.ndarray, replace_mask: bool, frame: int): def cut_image( self, - cut_area: typing.Union[np.ndarray, typing.Iterable[slice]], + cut_area: np.ndarray | typing.Iterable[slice], replace_mask=False, frame: int = FRAME_THICKNESS, zero_out_cut_area: bool = True, - ) -> "Image": + ) -> Image: """ Create new image base on mask or list of slices :param bool replace_mask: if cut area is represented by mask array, @@ -771,7 +776,7 @@ def get_um_shift(self) -> Spacing: """image spacing in micrometers""" return tuple(float(x * 10**6) for x in self.shift) - def get_ranges(self) -> typing.List[typing.Tuple[float, float]]: + def get_ranges(self) -> list[tuple[float, float]]: """image brightness ranges for each channel""" return self.ranges[:]