Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add option to combine channels using sum and max #1159

Merged
merged 3 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion package/PartSeg/_roi_analysis/partseg_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ def set_project_info(self, data: typing.Union[ProjectTuple, MaskInfo, PointsInfo
return
if not isinstance(data, ProjectTuple):
return
if self.image.file_path == data.image.file_path and self.image.shape == data.image.shape:
if (
self.image.file_path == data.image.file_path
and self.image.shape == data.image.shape
and self.image.channels == data.image.channels
):
if data.roi_info.roi is not None:
try:
self.image.fit_array_to_image(data.roi_info.roi)
Expand Down
2 changes: 1 addition & 1 deletion package/PartSeg/common_gui/image_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, image: Image, transform_dict: Optional[Dict[str, TransformBas
for key, val in transform_dict.items():
self.choose.addItem(key)
initial_values = val.calculate_initial(image)
form_widget = FormWidget(val.get_fields_per_dimension(list(image.get_dimension_letters())), initial_values)
form_widget = FormWidget(val.get_fields_per_dimension(image), initial_values)
self.stacked.addWidget(form_widget)

self.choose.currentIndexChanged.connect(self.stacked.setCurrentIndex)
Expand Down
5 changes: 3 additions & 2 deletions package/PartSegCore/image_transforming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from PartSegCore.algorithm_describe_base import Register
from PartSegCore.image_transforming.combine_channels import CombineChannels
from PartSegCore.image_transforming.image_projection import ImageProjection
from PartSegCore.image_transforming.interpolate_image import InterpolateImage
from PartSegCore.image_transforming.swap_time_stack import SwapTimeStack
from PartSegCore.image_transforming.transform_base import TransformBase

image_transform_dict = Register(InterpolateImage, SwapTimeStack, ImageProjection)
image_transform_dict = Register(CombineChannels, ImageProjection, InterpolateImage, SwapTimeStack)

__all__ = ("image_transform_dict", "InterpolateImage", "TransformBase", "ImageProjection")
__all__ = ("image_transform_dict", "ImageProjection", "InterpolateImage", "SwapTimeStack", "TransformBase")
60 changes: 60 additions & 0 deletions package/PartSegCore/image_transforming/combine_channels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from enum import Enum, auto
from typing import Callable, List, Optional, Tuple, Union

import numpy as np

from PartSegCore.algorithm_describe_base import AlgorithmProperty
from PartSegCore.image_transforming.transform_base import TransformBase
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Image


class CombineMode(Enum):
Max = auto()
Sum = auto()


class CombineChannels(TransformBase):
@classmethod
def get_fields(cls):
return [AlgorithmProperty("combine_mode", "Combine Mode", CombineMode.Sum)]

@classmethod
def get_fields_per_dimension(cls, image: Image) -> List[Union[str, AlgorithmProperty]]:
return [
AlgorithmProperty("combine_mode", "Combine Mode", CombineMode.Sum),
*[AlgorithmProperty(f"channel_{i}", f"Channel {i}", False) for i in range(image.channels)],
]

@classmethod
def get_name(cls):
return "Combine channels"

@classmethod
def transform(
cls,
image: Image,
roi_info: Optional[ROIInfo],
arguments: dict,
callback_function: Optional[Callable[[str, int], None]] = None,
) -> Tuple[Image, Optional[ROIInfo]]:
channels = [i for i, x in enumerate(x for x in arguments.items() if x[0].startswith("channel")) if x[1]]
if not channels:
return image, roi_info
channel_array = [image.get_channel(i) for i in channels]
if arguments["combine_mode"] == CombineMode.Max:
new_channel = np.max(channel_array, axis=0)

Check warning on line 46 in package/PartSegCore/image_transforming/combine_channels.py

View check run for this annotation

Codecov / codecov/patch

package/PartSegCore/image_transforming/combine_channels.py#L41-L46

Added lines #L41 - L46 were not covered by tests
else:
new_channel = np.sum(channel_array, axis=0)
all_channels = [image.get_channel(i) for i in range(image.channels)]
all_channels.append(new_channel)
channel_names = [*image.channel_names, "combined"]
contrast_limits = [*image.get_ranges(), (np.min(new_channel), np.max(new_channel))]
return image.substitute(data=all_channels, channel_names=channel_names, ranges=contrast_limits), roi_info

Check warning on line 53 in package/PartSegCore/image_transforming/combine_channels.py

View check run for this annotation

Codecov / codecov/patch

package/PartSegCore/image_transforming/combine_channels.py#L48-L53

Added lines #L48 - L53 were not covered by tests

@classmethod
def calculate_initial(cls, image: Image):
min_val = min(image.spacing)
return {
f"scale_{letter}": x / min_val for x, letter in zip(image.spacing, image.get_dimension_letters().lower())
}
4 changes: 2 additions & 2 deletions package/PartSegCore/image_transforming/image_projection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Callable, List, Optional, Tuple
from typing import Callable, Optional, Tuple

import numpy as np
from pydantic import Field
Expand Down Expand Up @@ -67,7 +67,7 @@ def transform(
)

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]):
def get_fields_per_dimension(cls, image: Image):
return cls.__argument_class__

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion package/PartSegCore/image_transforming/interpolate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def get_fields(cls):
return ["It can be very slow.", AlgorithmProperty("scale", "Scale", 1.0)]

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]) -> List[Union[str, AlgorithmProperty]]:
def get_fields_per_dimension(cls, image: Image) -> List[Union[str, AlgorithmProperty]]:
component_list = list(image.get_dimension_letters())
return [
"it can be very slow",
*[AlgorithmProperty(f"scale_{i.lower()}", f"Scale {i}", 1.0) for i in reversed(component_list)],
Expand Down
2 changes: 1 addition & 1 deletion package/PartSegCore/image_transforming/swap_time_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def transform(
return image.swap_time_and_stack(), None

@classmethod
def get_fields_per_dimension(cls, component_list: typing.List[str]):
def get_fields_per_dimension(cls, image: Image):
return cls.get_fields()

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion package/PartSegCore/image_transforming/transform_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def transform(
raise NotImplementedError

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]) -> List[Union[str, AlgorithmProperty]]:
def get_fields_per_dimension(cls, image: Image) -> List[Union[str, AlgorithmProperty]]:
raise NotImplementedError

@classmethod
Expand Down
Loading