diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index c853988..f78a850 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -1,8 +1,9 @@ -from typing import Any, Optional +from typing import Any, Optional, cast import napari import numpy as np import numpy.typing as npt +from matplotlib.container import BarContainer from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget from .base import SingleAxesWidget @@ -162,12 +163,39 @@ def on_update_layers(self) -> None: def draw(self) -> None: """Clear the axes and histogram the currently selected layer/slice.""" + # get the colormap from the layer depending on its type + if isinstance(self.layers[0], napari.layers.Points): + colormap = self.layers[0].face_colormap + self.layers[0].face_color = self.x_axis_key + elif isinstance(self.layers[0], napari.layers.Vectors): + colormap = self.layers[0].edge_colormap + self.layers[0].edge_color = self.x_axis_key + else: + colormap = None + + # apply new colors to the layer + self.viewer.layers[self.layers[0].name].refresh_colors(True) + self.viewer.layers[self.layers[0].name].refresh() + + # Draw the histogram data, x_axis_name = self._get_data() if data is None: return - self.axes.hist(data, bins=50, edgecolor="white", linewidth=0.3) + _, bins, patches = self.axes.hist( + data, bins=50, edgecolor="white", linewidth=0.3 + ) + patches = cast(BarContainer, patches) + + # recolor the histogram plot + if colormap is not None: + self.bins_norm = (bins - bins.min()) / (bins.max() - bins.min()) + colors = colormap.map(self.bins_norm) + + # Set histogram style: + for idx, patch in enumerate(patches): + patch.set_facecolor(colors[idx]) # set ax labels self.axes.set_xlabel(x_axis_name) diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png deleted file mode 100644 index c53a890..0000000 Binary files a/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png and /dev/null differ diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram_points.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram_points.png new file mode 100644 index 0000000..b98a017 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_feature_histogram_points.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram_vectors.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram_vectors.png new file mode 100644 index 0000000..3b90586 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_feature_histogram_vectors.png differ diff --git a/src/napari_matplotlib/tests/test_histogram.py b/src/napari_matplotlib/tests/test_histogram.py index 006c042..1ceca51 100644 --- a/src/napari_matplotlib/tests/test_histogram.py +++ b/src/napari_matplotlib/tests/test_histogram.py @@ -38,6 +38,8 @@ def test_histogram_3D(make_napari_viewer, brain_data): def test_feature_histogram(make_napari_viewer): n_points = 1000 random_points = np.random.random((n_points, 3)) * 10 + random_directions = np.random.random((n_points, 3)) * 10 + random_vectors = np.stack([random_points, random_directions], axis=1) feature1 = np.random.random(n_points) feature2 = np.random.normal(size=n_points) @@ -47,10 +49,10 @@ def test_feature_histogram(make_napari_viewer): properties={"feature1": feature1, "feature2": feature2}, name="points1", ) - viewer.add_points( - random_points, + viewer.add_vectors( + random_vectors, properties={"feature1": feature1, "feature2": feature2}, - name="points2", + name="vectors1", ) widget = FeaturesHistogramWidget(viewer) @@ -70,26 +72,42 @@ def test_feature_histogram(make_napari_viewer): @pytest.mark.mpl_image_compare -def test_feature_histogram2(make_napari_viewer): - import numpy as np +def test_feature_histogram_vectors(make_napari_viewer): + n_points = 1000 + np.random.seed(42) + random_points = np.random.random((n_points, 3)) * 10 + random_directions = np.random.random((n_points, 3)) * 10 + random_vectors = np.stack([random_points, random_directions], axis=1) + feature1 = np.random.random(n_points) + + viewer = make_napari_viewer() + viewer.add_vectors( + random_vectors, + properties={"feature1": feature1}, + name="vectors1", + ) + + widget = FeaturesHistogramWidget(viewer) + viewer.window.add_dock_widget(widget) + widget._set_axis_keys("feature1") + fig = FeaturesHistogramWidget(viewer).figure + return deepcopy(fig) + + +@pytest.mark.mpl_image_compare +def test_feature_histogram_points(make_napari_viewer): np.random.seed(0) n_points = 1000 random_points = np.random.random((n_points, 3)) * 10 feature1 = np.random.random(n_points) - feature2 = np.random.normal(size=n_points) viewer = make_napari_viewer() viewer.add_points( random_points, - properties={"feature1": feature1, "feature2": feature2}, + properties={"feature1": feature1}, name="points1", ) - viewer.add_points( - random_points, - properties={"feature1": feature1, "feature2": feature2}, - name="points2", - ) widget = FeaturesHistogramWidget(viewer) viewer.window.add_dock_widget(widget)