Skip to content

Commit

Permalink
fix: Fix checking if channel requested by MeasurementProfile exists (#…
Browse files Browse the repository at this point in the history
…1165)

fix PARTSEG-VA

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **New Features**
- Introduced a method to check the existence of channels, enhancing
validation during measurements.
- Updated handling of channel identifiers to accept both `Channel`
objects and strings for improved flexibility.
- Improved dialog functionality by allowing it to have a parent widget
for better user experience.
- Enhanced the measurement profile structure for clarity and
functionality, focusing on ROI measurements.

- **Bug Fixes**
- Enhanced error messaging for channel validity to align with user
input, improving usability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
Czaki authored Jul 23, 2024
1 parent d87bd21 commit 501574e
Show file tree
Hide file tree
Showing 17 changed files with 467 additions and 192 deletions.
2 changes: 1 addition & 1 deletion package/PartSeg/_roi_analysis/advanced_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def import_measurement_profiles(self):
if err:
QMessageBox.warning(self, "Import error", "error during importing, part of data were filtered.")
measurement_dict = self.settings.measurement_profiles
imp = ImportDialog(stat, measurement_dict, StringViewer, MeasurementProfile)
imp = ImportDialog(stat, measurement_dict, StringViewer, MeasurementProfile, parent=self)
if not imp.exec_():
return
for original_name, final_name in imp.get_import_list():
Expand Down
4 changes: 2 additions & 2 deletions package/PartSeg/_roi_analysis/measurement_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,11 @@ def append_measurement_result(self):
units = self.units_choose.currentEnum()

for num in compute_class.get_channels_num():
if num >= self.settings.image.channels:
if not self.settings.image.has_channel(num):
QMessageBox.warning(
self,
"Measurement error",
f"Cannot calculate this measurement because image do not have channel {num+1}",
f"Cannot calculate this measurement because image do not have channel {num}",
)
return

Expand Down
2 changes: 1 addition & 1 deletion package/PartSeg/common_gui/napari_image_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def mask_opacity(self) -> float:
def mask_color(self) -> ColorInfo:
"""Get mask marking color"""
color = Color(np.divide(self.settings.get_from_profile("mask_presentation_color", [255, 255, 255]), 255))
return {0: (0, 0, 0, 0), 1: color.rgba}
return {0: (0, 0, 0, 0), 1: color.rgba, None: (0, 0, 0, 0)}

def get_image(self, image: Optional[Image]) -> Image:
if image is not None:
Expand Down
7 changes: 5 additions & 2 deletions package/PartSeg/plugins/napari_widgets/measurement_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from PartSeg.plugins.napari_widgets._settings import get_settings
from PartSeg.plugins.napari_widgets.utils import NapariFormDialog, generate_image
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Channel

if TYPE_CHECKING:
from PartSegCore.analysis.measurement_calculation import MeasurementProfile, MeasurementResult
Expand Down Expand Up @@ -58,11 +59,13 @@ def append_measurement_result(self):
if self.channels_chose.value is None:
return
for name in compute_class.get_channels_num():
if name not in self.napari_viewer.layers:
if name.value not in self.napari_viewer.layers:
show_info(f"Cannot calculate this measurement because image do not have layer {name}")
return
units = self.units_choose.currentEnum()
image = generate_image(self.napari_viewer, self.channels_chose.value.name, *compute_class.get_channels_num())
image = generate_image(
self.napari_viewer, Channel(self.channels_chose.value.name), *compute_class.get_channels_num()
)
if self.mask_chose.value is not None:
image.set_mask(self.mask_chose.value.data)
roi_info = ROIInfo(self.roi_chose.value.data).fit_to_image(image)
Expand Down
4 changes: 3 additions & 1 deletion package/PartSeg/plugins/napari_widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ def __init__(self, *args, **kwargs):
def generate_image(viewer: Viewer, *layer_names):
axis_order = Image.axis_order.replace("C", "")
image_list = []
if isinstance(layer_names[0], str):
layer_names = [Channel(el) for el in layer_names]
for name in dict.fromkeys(layer_names):
image_layer = viewer.layers[name]
image_layer = viewer.layers[name.value]
data_scale = image_layer.scale[-3:] / UNIT_SCALE[Units.nm.value]
image_list.append(
Image(
Expand Down
8 changes: 7 additions & 1 deletion package/PartSegCore/analysis/measurement_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,13 @@ def calculate_property(

@classmethod
def get_starting_leaf(cls) -> Leaf:
"""This leaf is put on default list"""
"""This leaf is put on a default list"""
if (
hasattr(cls, "__argument_class__")
and cls.__argument_class__ is not None
and cls.__argument_class__ is not BaseModel
):
return Leaf(name=cls._display_name(), parameters=cls.__argument_class__())
return Leaf(name=cls._display_name())

@classmethod
Expand Down
15 changes: 13 additions & 2 deletions package/PartSegImage/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,11 @@ def get_data_by_axis(self, **kwargs) -> np.ndarray:
axis_pos = self.get_array_axis_positions()
if "c" in kwargs:
kwargs["C"] = kwargs.pop("c")
if "C" in kwargs and isinstance(kwargs["C"], str):
kwargs["C"] = self.channel_names.index(kwargs["C"])
if "C" in kwargs:
if isinstance(kwargs["C"], Channel):
kwargs["C"] = kwargs["C"].value
if isinstance(kwargs["C"], str):
kwargs["C"] = self.channel_names.index(kwargs["C"])

channel = kwargs.pop("C", slice(None) if "C" in self.axis_order else 0)
if isinstance(channel, Channel):
Expand Down Expand Up @@ -571,6 +574,14 @@ def get_channel(self, num: int | str | Channel) -> np.ndarray:
"""
return self.get_data_by_axis(c=num)

def has_channel(self, num: int | str | Channel) -> bool:
if isinstance(num, Channel):
num = num.value

if isinstance(num, str):
return num in self.channel_names
return 0 <= num < self.channels

def get_layer(self, time: int, stack: int) -> np.ndarray:
"""
return single layer contains data for all channel
Expand Down
25 changes: 20 additions & 5 deletions package/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from PartSegCore.algorithm_describe_base import ROIExtractionProfile
from PartSegCore.analysis import ProjectTuple, SegmentationPipeline, SegmentationPipelineElement
from PartSegCore.analysis.measurement_base import AreaType, MeasurementEntry, PerComponent
from PartSegCore.analysis.measurement_calculation import ComponentsNumber, MeasurementProfile, Volume
from PartSegCore.analysis.measurement_calculation import (
ColocalizationMeasurement,
ComponentsNumber,
MeasurementProfile,
Volume,
)
from PartSegCore.image_operations import RadiusType
from PartSegCore.mask.io_functions import MaskProjectTuple
from PartSegCore.mask_create import MaskProperty
Expand Down Expand Up @@ -69,15 +74,15 @@ def image2d(tmp_path):

@pytest.fixture()
def stack_image():
data = np.zeros([20, 40, 40], dtype=np.uint8)
data = np.zeros([20, 40, 40, 2], dtype=np.uint8)
for x, y in itertools.product([0, 20], repeat=2):
data[1:-1, x + 2 : x + 18, y + 2 : y + 18] = 100
for x, y in itertools.product([0, 20], repeat=2):
data[3:-3, x + 4 : x + 16, y + 4 : y + 16] = 120
for x, y in itertools.product([0, 20], repeat=2):
data[5:-5, x + 6 : x + 14, y + 6 : y + 14] = 140

return MaskProjectTuple("test_path", Image(data, (2, 1, 1), axes_order="ZYX", file_path="test_path"))
return MaskProjectTuple("test_path", Image(data, (2, 1, 1), axes_order="ZYXC", file_path="test_path"))


@pytest.fixture()
Expand Down Expand Up @@ -201,8 +206,18 @@ def measurement_profiles():
calculation_tree=Volume.get_starting_leaf().replace_(area=AreaType.Mask, per_component=PerComponent.No),
),
]
return MeasurementProfile(name="statistic1", chosen_fields=statistics), MeasurementProfile(
name="statistic2", chosen_fields=statistics + statistics2
statistics3 = [
MeasurementEntry(
name="Colocalisation",
calculation_tree=ColocalizationMeasurement.get_starting_leaf().replace_(
per_component=PerComponent.No, area=AreaType.ROI
),
),
]
return (
MeasurementProfile(name="statistic1", chosen_fields=statistics),
MeasurementProfile(name="statistic2", chosen_fields=statistics + statistics2),
MeasurementProfile(name="statistic3", chosen_fields=statistics + statistics2 + statistics3),
)


Expand Down
8 changes: 4 additions & 4 deletions package/tests/test_PartSeg/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@


@pytest.fixture()
def base_settings(image, tmp_path, measurement_profiles, qapp):
def base_settings(image, tmp_path, measurement_profiles):
settings = BaseSettings(tmp_path)
settings.image = image
return settings


@pytest.fixture()
def part_settings(image, tmp_path, measurement_profiles, qapp):
def part_settings(image, tmp_path, measurement_profiles):
settings = PartSettings(tmp_path)
settings.image = image
for el in measurement_profiles:
Expand All @@ -28,7 +28,7 @@ def part_settings(image, tmp_path, measurement_profiles, qapp):


@pytest.fixture()
def stack_settings(tmp_path, image, qapp):
def stack_settings(tmp_path, image):
settings = StackSettings(tmp_path)
settings.image = image
chose = ChosenComponents()
Expand All @@ -38,7 +38,7 @@ def stack_settings(tmp_path, image, qapp):


@pytest.fixture()
def part_settings_with_project(image, analysis_segmentation2, tmp_path, qapp):
def part_settings_with_project(image, analysis_segmentation2, tmp_path):
settings = PartSettings(tmp_path)
settings.image = image
settings.set_project_info(analysis_segmentation2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ def test_base_steep(self, qtbot, part_settings):
widget.profile_name.setText("test")
assert widget.save_butt.isEnabled()

assert len(part_settings.measurement_profiles) == 2
assert len(part_settings.measurement_profiles) == 3
with qtbot.waitSignal(widget.save_butt.clicked):
widget.save_butt.click()
assert len(part_settings.measurement_profiles) == 3
assert len(part_settings.measurement_profiles) == 4

with qtbot.waitSignal(widget.profile_name.textChanged):
widget.profile_name.setText("")
Expand Down
23 changes: 20 additions & 3 deletions package/tests/test_PartSeg/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_missed_mask(self, qmessagebox_path, qtbot, analysis_segmentation, part_
widget = MeasurementWidget(part_settings)
qtbot.addWidget(widget)

assert widget.measurement_type.count() == 3
assert widget.measurement_type.count() == 4
part_settings.set_project_info(analysis_segmentation)

with qtbot.waitSignal(widget.measurement_type.currentIndexChanged):
Expand All @@ -34,7 +34,7 @@ def test_base(self, qtbot, analysis_segmentation, part_settings):
widget = MeasurementWidget(part_settings)
qtbot.addWidget(widget)

assert widget.measurement_type.count() == 3
assert widget.measurement_type.count() == 4
part_settings.set_project_info(analysis_segmentation)
widget.measurement_type.setCurrentIndex(1)
assert widget.recalculate_button.isEnabled()
Expand All @@ -52,7 +52,7 @@ def test_base2(self, qtbot, analysis_segmentation2, part_settings):
widget = MeasurementWidget(part_settings)
qtbot.addWidget(widget)

assert widget.measurement_type.count() == 3
assert widget.measurement_type.count() == 4
part_settings.set_project_info(analysis_segmentation2)
widget.measurement_type.setCurrentIndex(2)
assert widget.recalculate_button.isEnabled()
Expand All @@ -64,6 +64,23 @@ def test_base2(self, qtbot, analysis_segmentation2, part_settings):
assert widget.info_field.columnCount() == 3
assert widget.info_field.rowCount() == 2

@pytest.mark.enablethread()
@pytest.mark.enabledialog()
def test_base_channels(self, qtbot, analysis_segmentation2, part_settings):
widget = MeasurementWidget(part_settings)
qtbot.addWidget(widget)

part_settings.set_project_info(analysis_segmentation2)
widget.measurement_type.setCurrentIndex(3)
assert widget.recalculate_button.isEnabled()
widget.recalculate_button.click()
assert widget.info_field.columnCount() == 2
assert widget.info_field.rowCount() == 4
assert widget.info_field.item(1, 1).text() == "4"
widget.horizontal_measurement_present.setChecked(True)
assert widget.info_field.columnCount() == 4
assert widget.info_field.rowCount() == 2


class TestSimpleMeasurementsWidget:
@pytest.mark.enablethread()
Expand Down
15 changes: 13 additions & 2 deletions package/tests/test_PartSeg/test_napari_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import gc
import json
from importlib.metadata import version
from unittest.mock import patch
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -223,16 +223,23 @@ def test_simple_measurement_create(make_napari_viewer, qtbot):

@pytest.mark.enablethread()
@pytest.mark.enabledialog()
def test_measurement_create(make_napari_viewer, qtbot, bundle_test_dir):
@pytest.mark.usefixtures("qtbot")
def test_measurement_create(make_napari_viewer, bundle_test_dir, monkeypatch):
from PartSeg.plugins.napari_widgets.measurement_widget import Measurement

monkeypatch.setattr(
"PartSeg.plugins.napari_widgets.measurement_widget.show_info",
Mock(side_effect=RuntimeError("should not be called")),
)

data = np.zeros((10, 10), dtype=np.uint8)
data[2:5, 2:-2] = 1
data[5:-2, 2:-2] = 2

viewer = make_napari_viewer()
viewer.add_labels(data, name="label")
viewer.add_image(data, name="image")
viewer.add_image(data, name="image2")
measurement = Measurement(viewer)
viewer.window.add_dock_widget(measurement)
measurement.reset_choices()
Expand All @@ -243,7 +250,11 @@ def test_measurement_create(make_napari_viewer, qtbot, bundle_test_dir):
assert measurement.measurement_widget.measurement_type.currentText() == "test"
assert measurement.measurement_widget.recalculate_button.isEnabled()
assert measurement.measurement_widget.check_if_measurement_can_be_calculated("test") == "test"
assert measurement.measurement_widget.info_field.rowCount() == 0
assert measurement.measurement_widget.info_field.columnCount() == 3
measurement.measurement_widget.append_measurement_result()
assert measurement.measurement_widget.info_field.rowCount() == 8
assert measurement.measurement_widget.info_field.columnCount() == 3


def test_update_properties():
Expand Down
2 changes: 1 addition & 1 deletion package/tests/test_PartSeg/test_roi_analysis_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def check_measurement(measurement: prepare_plan_widget.MeasurementCalculate):

widget = prepare_plan_widget.SelectMeasurementOp(part_settings)
qtbot.addWidget(widget)
assert widget.measurements_list.count() == 2
assert widget.measurements_list.count() == 3
with qtbot.assert_not_emitted(widget.set_of_measurement_add):
widget._measurement_add()
widget.measurements_list.setCurrentRow(0)
Expand Down
6 changes: 3 additions & 3 deletions package/tests/test_PartSeg/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ def test_base(self, image, analysis_segmentation2, tmp_path):
assert len(viewer.layers) == 2
settings.image = analysis_segmentation2.image
viewer.create_initial_layers(True, True, True, True)
assert len(viewer.layers) == 1
assert len(viewer.layers) == 2
settings.roi = analysis_segmentation2.roi_info.roi
viewer.create_initial_layers(True, True, True, True)
assert len(viewer.layers) == 2
assert len(viewer.layers) == 3
settings.mask = analysis_segmentation2.mask
viewer.create_initial_layers(True, True, True, True)
assert len(viewer.layers) == 3
assert len(viewer.layers) == 4
viewer.close()

def test_points(self, image, tmp_path, qtbot):
Expand Down
8 changes: 4 additions & 4 deletions package/tests/test_PartSegCore/test_napari_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@
def test_project_to_layers_analysis(analysis_segmentation):
analysis_segmentation.roi_info.alternative["test"] = np.zeros(analysis_segmentation.image.shape, dtype=np.uint8)
res = project_to_layers(analysis_segmentation)
assert len(res) == 3
assert len(res) == 4
l1 = Layer.create(*res[0])
assert isinstance(l1, Image)
assert l1.name == "channel 1"
assert np.allclose(l1.scale[1:] / 1e9, analysis_segmentation.image.spacing)
l2 = Layer.create(*res[1])
l2 = Layer.create(*res[2])
assert isinstance(l2, Labels)
assert l2.name == "ROI"
assert np.allclose(l2.scale[1:] / 1e9, analysis_segmentation.image.spacing)
l3 = Layer.create(*res[2])
l3 = Layer.create(*res[3])
assert isinstance(l3, Labels)
assert l3.name == "test"
assert np.allclose(l3.scale[1:] / 1e9, analysis_segmentation.image.spacing)
Expand All @@ -55,7 +55,7 @@ def test_project_to_layers_roi():

def test_project_to_layers_mask(stack_segmentation1):
res = project_to_layers(stack_segmentation1)
assert len(res) == 2
assert len(res) == 3
assert res[0][2] == "image"


Expand Down
10 changes: 9 additions & 1 deletion package/tests/test_PartSegImage/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,17 @@ def test_get_channel(self):
image = self.image_class(
np.zeros((1, 10, 20, 30, 3), np.uint8), (1, 1, 1), "", axes_order="TZYXC", channel_names=["a", "b", "c"]
)
assert image.has_channel(1)
assert image.has_channel(Channel(1))
assert not image.has_channel(5)
assert not image.has_channel(Channel(5))
channel = image.get_channel(1)
assert channel.shape == self.mask_shape((1, 10, 20, 30), "TZYX")
channel = image.get_channel("b")
assert image.has_channel("b")
assert image.has_channel(Channel("b"))
assert not image.has_channel("d")
assert not image.has_channel(Channel("d"))
channel = image.get_channel(Channel("b"))
assert channel.shape == self.mask_shape((1, 10, 20, 30), "TZYX")
with pytest.raises(IndexError):
image.get_channel(4)
Expand Down
Loading

0 comments on commit 501574e

Please sign in to comment.