diff --git a/src/mslice/plotting/plot_window/cut_plot.py b/src/mslice/plotting/plot_window/cut_plot.py index bc0dd5d0..87099e6f 100644 --- a/src/mslice/plotting/plot_window/cut_plot.py +++ b/src/mslice/plotting/plot_window/cut_plot.py @@ -178,21 +178,20 @@ def update_legend(self, line_data=None): handles_to_show = [] handles, labels = axes.get_legend_handles_labels() if line_data is None: - i = 0 - for handle, label in zip(handles, labels): + for i, (handle, label) in enumerate(zip(handles, labels)): if self.legend_visible(i): labels_to_show.append(label) handles_to_show.append(handle) - i += 1 else: - containers = axes.containers - for i in range(len(containers)): - if line_data[i]['legend']: - handles_to_show.append(handles[i]) - labels_to_show.append(line_data[i]['label']) - self._legends_visible[i] = line_data[i]['legend'] - legend = axes.legend(handles_to_show, labels_to_show, fontsize='medium') # add new legends - legend_set_draggable(legend, True) + for line, handle in zip(line_data, handles): + if line['legend']: + handles_to_show.append(handle) + labels_to_show.append(line['label']) + self._legends_visible = [line['legend'] for line in line_data] + + if self._legends_shown: + legend = axes.legend(handles_to_show, labels_to_show, fontsize='medium') # add new legends + legend_set_draggable(legend, True) def change_axis_scale(self, xy_config): current_axis = self._canvas.figure.gca() diff --git a/tests/cut_plot_test.py b/tests/cut_plot_test.py index 55ac74dd..7a092561 100644 --- a/tests/cut_plot_test.py +++ b/tests/cut_plot_test.py @@ -85,13 +85,39 @@ def test_object_clicked(self, quick_options_mock): self.cut_plot.update_legend.assert_called_once() self.cut_plot._canvas.draw.assert_called_once() - def test_update_legend(self): + def test_update_legend_legends_not_shown(self): line = Line2D([], []) self.axes.get_legend_handles_labels = MagicMock(return_value=([line], ['some_label'])) + self.cut_plot._legends_shown = False + self.cut_plot.update_legend() + self.assertTrue(self.cut_plot._legends_visible[0]) + self.axes.legend.assert_not_called() + + def test_update_legend_legends_shown(self): + line = Line2D([], []) + self.axes.get_legend_handles_labels = MagicMock(return_value=([line], ['some_label'])) + self.cut_plot._legends_shown = True self.cut_plot.update_legend() self.assertTrue(self.cut_plot._legends_visible[0]) self.axes.legend.assert_called_with([line], ['some_label'], fontsize=ANY) + def test_update_legend_with_line_data(self): + line_data = [ + {'shown': True, 'legend': 2, 'label': 'visible_line_data_label'}, + {'shown': True, 'legend': 0, 'label': 'non_visible_line_data_label'} + ] + mock_line = Line2D([], []) + another_mock_line = Line2D([], []) + + self.axes.get_legend_handles_labels = MagicMock(return_value=( + [mock_line, another_mock_line], ['mock_label', 'another_mock_label'] + )) + self.cut_plot._legends_shown = True + + self.cut_plot.update_legend(line_data) + self.assertEqual(self.cut_plot._legends_visible, [2, 0]) + self.axes.legend.assert_called_with([mock_line], ['visible_line_data_label'], fontsize=ANY) + def test_waterfall(self): self.cut_plot._apply_offset = MagicMock() self.cut_plot.update_bragg_peaks = MagicMock()