Skip to content

Commit

Permalink
fix: Fix napari 0.5.0 compatybility (#1116)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced compatibility with Napari version 5.0, including conditional
logic for colormap and layer viewing.

- **Bug Fixes**
- Improved color handling logic for label highlighting in various
widgets.
- Corrected rendering image logic in tests to ensure accurate
assertions.

- **Dependencies**
  - Updated `numpy` dependency to version >= 1.18.5 but < 2.

- **Tests**
- Introduced version-based conditional logic in tests for better Napari
compatibility.
- Adjusted test assertions for image rendering and colormap type
validations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
Czaki authored Jul 2, 2024
1 parent ecf008e commit 1eb2297
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 19 deletions.
10 changes: 9 additions & 1 deletion package/PartSeg/_roi_mask/image_view.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from importlib.metadata import version

from packaging.version import parse as parse_version
from vispy.app import MouseEvent

from PartSeg._roi_mask.stack_settings import StackSettings
from PartSeg.common_gui.channel_control import ChannelProperty
from PartSeg.common_gui.napari_image_view import ImageInfo, ImageView, LabelEnum

_napari_ge_0_5_0 = parse_version(version("napari")) >= parse_version("0.5.0a1")


class StackImageView(ImageView):
"""
Expand All @@ -14,7 +19,10 @@ class StackImageView(ImageView):

def __init__(self, settings: StackSettings, channel_property: ChannelProperty, name: str):
super().__init__(settings, channel_property, name)
self.viewer_widget.canvas.events.mouse_press.connect(self.component_click)
if _napari_ge_0_5_0:
self.viewer_widget.canvas._scene_canvas.events.mouse_press.connect(self.component_click)
else:
self.viewer_widget.canvas.events.mouse_press.connect(self.component_click)

def refresh_selected(self):
if (
Expand Down
5 changes: 5 additions & 0 deletions package/PartSeg/common_backend/base_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from napari.utils import Colormap
from napari.utils.theme import get_theme
from napari.utils.theme import template as napari_template
from packaging.version import parse as parse_version
from qtpy.QtCore import QObject, Signal
from qtpy.QtWidgets import QMessageBox, QWidget

Expand All @@ -37,6 +38,8 @@
from napari.settings import NapariSettings
logger = logging.getLogger(__name__)

_napari_ge_5 = parse_version(napari.__version__) >= parse_version("0.5.0a1")

DIR_HISTORY = "io.dir_location_history"
FILE_HISTORY = "io.files_open_history"
MULTIPLE_FILES_OPEN_HISTORY = "io.multiple_files_open_history"
Expand Down Expand Up @@ -268,6 +271,8 @@ def theme_name(self) -> str:
@property
def theme(self):
"""Theme as structure."""
if _napari_ge_5:
return get_theme(self.theme_name)
try:
return get_theme(self.theme_name, as_dict=False)
except TypeError: # pragma: no cover
Expand Down
8 changes: 7 additions & 1 deletion package/PartSeg/common_gui/error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import traceback
import typing
from contextlib import suppress
from importlib.metadata import version

import numpy as np
import requests
import sentry_sdk
from napari.settings import get_settings
from napari.utils.theme import get_theme
from packaging.version import parse as parse_version
from qtpy.QtGui import QIcon
from qtpy.QtWidgets import (
QApplication,
Expand Down Expand Up @@ -53,6 +55,7 @@
_FEEDBACK_URL = "https://sentry.io/api/0/projects/{organization_slug}/{project_slug}/user-feedback/".format(
organization_slug="cent", project_slug="partseg"
)
_napari_ge_5 = parse_version(version("napari")) >= parse_version("0.5.0a1")


def _print_traceback(exception, file_):
Expand Down Expand Up @@ -82,7 +85,10 @@ def __init__(self, exception: Exception, description: str, additional_notes: str
self.create_issue_btn = QPushButton("Create issue")
self.cancel_btn = QPushButton("Cancel")
self.error_description = QTextEdit()
theme = get_theme(get_settings().appearance.theme, as_dict=False)
if _napari_ge_5:
theme = get_theme(get_settings().appearance.theme)
else:
theme = get_theme(get_settings().appearance.theme, as_dict=False)
self._highlight = Pylighter(self.error_description.document(), "python", theme.syntax_style)
self.traceback_summary = additional_info
if additional_info is None:
Expand Down
24 changes: 23 additions & 1 deletion package/PartSeg/common_gui/napari_image_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@
from napari._qt.widgets.qt_viewer_buttons import QtViewerPushButton as QtViewerPushButton_
_napari_ge_4_13 = parse_version(napari.__version__) >= parse_version("0.4.13a1")
_napari_ge_4_17 = parse_version(napari.__version__) >= parse_version("0.4.17a1")
_napari_ge_5 = parse_version(napari.__version__) >= parse_version("0.5.0a1")


def get_highlight_colormap():
cmap_dict = {0: (0, 0, 0, 0), 1: "white", None: (0, 0, 0, 0)}
if _napari_ge_5:
from napari.utils.colormaps import DirectLabelColormap

return {"colormap": DirectLabelColormap(color_dict=cmap_dict)}

return {"color": cmap_dict}


class QtViewerPushButton(QtViewerPushButton_):
Expand Down Expand Up @@ -838,8 +849,8 @@ def _mark_layer(self, num: int, flash: bool, image_info: ImageInfo):
component_mark,
scale=image_info.roi.scale,
blending="translucent",
color={0: (0, 0, 0, 0), 1: "white"},
opacity=0.7,
**get_highlight_colormap(),
)
self.viewer.layers.selection.active = active_layer
else:
Expand Down Expand Up @@ -965,6 +976,17 @@ def closeEvent(self, event):
self.close()
super().closeEvent(event)

def _render(self):
if _napari_ge_5:
return self.canvas._scene_canvas.render()
return self.canvas.render()

if _napari_ge_5:

@property
def view(self):
return self.canvas.view


class SearchComponentModal(QtPopup):
def __init__(self, image_view: ImageView, search_type: SearchType, component_num: int, max_components):
Expand Down
10 changes: 9 additions & 1 deletion package/PartSeg/plugins/napari_widgets/lables_control.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from importlib.metadata import version
from typing import List, Sequence

from napari import Viewer
from napari.layers import Labels
from packaging.version import parse as parse_version
from qtpy.QtWidgets import QHBoxLayout, QPushButton, QTabWidget

from PartSeg.common_backend.base_settings import BaseSettings
from PartSeg.common_gui.label_create import LabelChoose, LabelEditor, LabelShow
from PartSeg.plugins.napari_widgets._settings import get_settings

NAPARI_GE_5_0 = parse_version(version("napari")) >= parse_version("0.5.0a1")


class NapariLabelShow(LabelShow):
def __init__(self, viewer: Viewer, name: str, label: List[Sequence[float]], removable, parent=None):
Expand Down Expand Up @@ -36,7 +40,11 @@ def apply_label(self):
):
max_val = layer.data.max()
labels = {i + 1: [x / 255 for x in self.label[i % len(self.label)]] for i in range(max_val + 5)}
layer.color = labels
labels[None] = [0, 0, 0, 0]
if NAPARI_GE_5_0:
layer.colormap = labels
else:
layer.color = labels


class NaparliLabelChoose(LabelChoose):
Expand Down
4 changes: 2 additions & 2 deletions package/PartSeg/plugins/napari_widgets/search_label_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from qtpy.QtCore import QTimer
from vispy.geometry import Rect

from PartSeg.common_gui.napari_image_view import SearchType
from PartSeg.common_gui.napari_image_view import SearchType, get_highlight_colormap
from PartSegCore.roi_info import ROIInfo

HIGHLIGHT_LABEL_NAME = ".Highlight"
Expand Down Expand Up @@ -77,8 +77,8 @@ def _highlight(self):
name=HIGHLIGHT_LABEL_NAME,
scale=labels.scale,
blending="translucent",
color={0: (0, 0, 0, 0), 1: "white"},
opacity=0.7,
**get_highlight_colormap(),
)

def flash_fun(layer_=layer):
Expand Down
10 changes: 5 additions & 5 deletions package/tests/test_PartSeg/test_channel_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,12 @@ def check_parameters(name, index):
@pytest.mark.windows_ci_skip()
def test_image_view_integration(self, qtbot, tmp_path, ch_property, image_view):
image_view.viewer_widget.screenshot(flash=False)
image1 = image_view.viewer_widget.canvas.render()
image1 = image_view.viewer_widget._render()
assert np.any(image1 != 255)
ch_property.minimum_value.setValue(100)
ch_property.maximum_value.setValue(10000)
ch_property.filter_radius.setValue(0.5)
image2 = image_view.viewer_widget.canvas.render()
image2 = image_view.viewer_widget._render()
assert np.any(image2 != 255)

assert np.all(image1 == image2)
Expand All @@ -328,13 +328,13 @@ def check_parameters(name, index):
):
ch_property.fixed.setChecked(True)

image1 = image_view.viewer_widget.canvas.render()
image1 = image_view.viewer_widget._render()
assert np.any(image1 != 255)
with qtbot.waitSignal(image_view.channel_control.coloring_update), qtbot.waitSignal(
image_view.channel_control.change_channel, check_params_cb=check_parameters
):
ch_property.minimum_value.setValue(20)
image2 = image_view.viewer_widget.canvas.render()
image2 = image_view.viewer_widget._render()
assert np.any(image2 != 255)
assert np.any(image1 != image2)

Expand Down Expand Up @@ -392,7 +392,7 @@ def check_parameters(name, index):
):
ch_property.fixed.setChecked(True)
image_view.viewer_widget.screenshot(flash=False)
image1 = image_view.viewer_widget.canvas.render()
image1 = image_view.viewer_widget._render()
with qtbot.waitSignal(image_view.channel_control.coloring_update), qtbot.waitSignal(
image_view.channel_control.change_channel, check_params_cb=check_parameters
):
Expand Down
6 changes: 2 additions & 4 deletions package/tests/test_PartSeg/test_common_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,8 @@ def test_safe_repr(self):
assert PartSegCore.utils.safe_repr(np.arange(3)) == "array([0, 1, 2])"

def test_safe_repr_napari_image(self):
assert (
PartSegCore.utils.safe_repr(napari.layers.Image(np.zeros((10, 10, 5))))
== "<Image of shape: (10, 10, 5), dtype: float64, slice"
" (0, slice(None, None, None), slice(None, None, None))>"
assert PartSegCore.utils.safe_repr(napari.layers.Image(np.zeros((10, 10, 5)))).startswith(
"<Image of shape: (10, 10, 5), dtype: float64, slice"
)


Expand Down
12 changes: 11 additions & 1 deletion package/tests/test_PartSeg/test_napari_image_view.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# pylint: disable=no-self-use
import gc
from functools import partial
from importlib.metadata import version
from unittest.mock import MagicMock

import numpy as np
import pytest
from napari.layers import Image as NapariImage
from napari.qt import QtViewer
from packaging.version import parse as parse_version
from qtpy.QtCore import QPoint
from vispy.geometry import Rect

Expand All @@ -25,6 +27,14 @@
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Image

NAPARI_GE_5_0 = parse_version(version("napari")) >= parse_version("0.5.0a1")


if NAPARI_GE_5_0:
EXPECTED_RANGE = (0, 0, 1)
else:
EXPECTED_RANGE = (0, 1, 1)


def test_image_info():
image_info = ImageInfo(Image(np.zeros((10, 10)), image_spacing=(1, 1), axes_order="XY"), [])
Expand Down Expand Up @@ -246,7 +256,7 @@ def test_marking_component_flash(self, base_settings, image_view, tmp_path, qtbo
assert "timer" in image_view.image_info[str(tmp_path / "test2.tiff")].highlight.metadata
timer = image_view.image_info[str(tmp_path / "test2.tiff")].highlight.metadata["timer"]
assert timer.isActive()
assert image_view.viewer.dims.range[0] == (0, 1, 1)
assert image_view.viewer.dims.range[0] == EXPECTED_RANGE
qtbot.wait(800)
image_view.component_unmark(0)
assert not image_view.image_info[str(tmp_path / "test2.tiff")].highlight.visible
Expand Down
28 changes: 26 additions & 2 deletions package/tests/test_PartSeg/test_napari_widgets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import gc
import json
from importlib.metadata import version
from unittest.mock import patch

import numpy as np
Expand All @@ -10,6 +11,7 @@
from napari.layers import Image as NapariImage
from napari.layers import Labels
from napari.utils import Colormap
from packaging.version import parse as parse_version
from qtpy.QtCore import QObject, QTimer, Signal

from PartSeg._roi_analysis.partseg_settings import PartSettings
Expand Down Expand Up @@ -55,6 +57,28 @@
from PartSegCore.segmentation.threshold import DoubleThresholdSelection, ThresholdSelection
from PartSegCore.segmentation.watershed import WatershedSelection

NAPARI_GE_5_0 = parse_version(version("napari")) >= parse_version("0.5.0a1")

if NAPARI_GE_5_0:

def check_auto_mode(layer):
from napari.utils.colormaps import CyclicLabelColormap

assert isinstance(layer.colormap, CyclicLabelColormap)

def check_direct_mode(layer):
from napari.utils.colormaps import DirectLabelColormap

assert isinstance(layer.colormap, DirectLabelColormap)

else:

def check_auto_mode(layer):
assert layer.color_mode == "auto"

def check_direct_mode(layer):
assert layer.color_mode == "direct"


@pytest.fixture(autouse=True)
def _clean_settings(tmp_path):
Expand Down Expand Up @@ -351,9 +375,9 @@ def test_napari_label_show(viewer_with_data, qtbot):
assert not widget.apply_label_btn.isEnabled()
viewer_with_data.layers.selection.remove(viewer_with_data.layers["image"])
assert widget.apply_label_btn.isEnabled()
assert viewer_with_data.layers["label"].color_mode == "auto"
check_auto_mode(viewer_with_data.layers["label"])
widget.apply_label_btn.click()
assert viewer_with_data.layers["label"].color_mode == "direct"
check_direct_mode(viewer_with_data.layers["label"])


def test_napari_colormap_control(viewer_with_data, qtbot):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ dependencies = [
"mahotas>=1.4.10",
"napari>=0.4.14",
"nme>=0.1.7",
"numpy>=1.18.5",
"numpy>=1.18.5,<2", # mahotas requires rebuild for numpy 2.
"oiffile>=2020.1.18",
"openpyxl>=2.5.7",
"packaging>=20.0",
Expand Down

0 comments on commit 1eb2297

Please sign in to comment.