Skip to content

Commit

Permalink
Fix slice save for intensity types #947. Add test.
Browse files Browse the repository at this point in the history
  • Loading branch information
mducle committed Oct 9, 2023
1 parent 6100698 commit 198379e
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/mslice/models/workspacemanager/workspace_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def export_workspace_to_ads(workspace):
def _save_single_ws(workspace, save_name, save_method, path, extension, slice_nonpsd):
save_as = save_name if save_name is not None else str(workspace) + extension
full_path = os.path.join(str(path), save_as)
if isinstance(workspace, str):
if isinstance(workspace, string_types):
workspace = get_workspace_handle(workspace)
save_method(workspace, full_path)

Expand Down
2 changes: 1 addition & 1 deletion src/mslice/plotting/plot_window/plot_figure_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def save_plot(self):
workspaces = self.plot_handler.ws_list
else:
if isinstance(self.plot_handler, SlicePlot):
workspaces = [self.plot_handler.get_slice_cache().scattering_function]
workspaces = [self.plot_handler.get_cached_workspace()]
else:
workspaces = [self.plot_handler.ws_name]
try:
Expand Down
4 changes: 4 additions & 0 deletions src/mslice/plotting/plot_window/slice_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,10 @@ def flip_icut(self):
def get_slice_cache(self):
return self._slice_plotter_presenter.get_slice_cache(self.ws_name)

def get_cached_workspace(self):
cached_slice = self.get_slice_cache()
return getattr(cached_slice, IntensityCache.get_slice_type(self.intensity_type))

def update_workspaces(self):
self._slice_plotter_presenter.update_displayed_workspaces()

Expand Down
8 changes: 8 additions & 0 deletions src/mslice/presenters/slice_plotter_presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SlicePlotterPresenter(PresenterUtility):
def __init__(self):
self._main_presenter = None
self._slice_cache = {}
self._cache_intensity_correction_methods()

def plot_slice(self, selected_ws, x_axis, y_axis, intensity_start, intensity_end, norm_to_one, colourmap):
workspace = get_workspace_handle(selected_ws)
Expand Down Expand Up @@ -123,3 +124,10 @@ def _cache_intensity_correction_methods(self):
self.show_dynamical_susceptibility_magnetic)
IntensityCache.cache_method(cat, IntensityType.D2SIGMA, self.show_d2sigma)
IntensityCache.cache_method(cat, IntensityType.SYMMETRISED, self.show_symmetrised)
IntensityCache.cache_method(cat, IntensityType.GDOS, self.show_gdos)
IntensityCache.cache_slice_type(IntensityType.SCATTERING_FUNCTION, "scattering_function")
IntensityCache.cache_slice_type(IntensityType.CHI, "chi")
IntensityCache.cache_slice_type(IntensityType.CHI_MAGNETIC, "chi_magnetic")
IntensityCache.cache_slice_type(IntensityType.D2SIGMA, "d2sigma")
IntensityCache.cache_slice_type(IntensityType.SYMMETRISED, "symmetrised")
IntensityCache.cache_slice_type(IntensityType.GDOS, "gdos")
12 changes: 12 additions & 0 deletions src/mslice/util/intensity_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class IntensityCache:
__action_dict = {}
__method_dict_cut = {}
__method_dict_slice = {}
__slice_cache_type = {}
__description_dict = {IntensityType.SCATTERING_FUNCTION: "scattering_function",
IntensityType.CHI: "dynamical_susceptibility",
IntensityType.CHI_MAGNETIC: "dynamical_susceptibility_magnetic",
Expand Down Expand Up @@ -62,6 +63,17 @@ def get_method(cls, category, intensity_correction_type):
else:
raise KeyError("method related to the intensity correction type not found")

@classmethod
def cache_slice_type(cls, intensity_correction_type, name):
if intensity_correction_type not in cls.__slice_cache_type:
cls.__slice_cache_type[intensity_correction_type] = name

@classmethod
def get_slice_type(cls, intensity_correction_type):
if intensity_correction_type not in cls.__slice_cache_type:
raise KeyError("intensity correction cached type not found")
return cls.__slice_cache_type[intensity_correction_type]

@classmethod
def get_intensity_type_from_desc(cls, description):
if description in cls._IntensityCache__description_dict.values():
Expand Down
67 changes: 67 additions & 0 deletions tests/plot_figure_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import (absolute_import, division, print_function)
import unittest
from unittest import mock
from unittest.mock import patch

from mslice.presenters.slice_plotter_presenter import SlicePlotterPresenter
from mslice.presenters.cut_plotter_presenter import CutPlotterPresenter
from mslice.util.intensity_correction import IntensityType

FORCE_METHOD_CALLS_TO_QAPP_THREAD = 'mslice.plotting.plot_window.plot_figure_manager.force_method_calls_to_qapp_thread'


class PlotFigureTest(unittest.TestCase):

def setUp(self):
self.mock_force_qapp = mock.patch(
FORCE_METHOD_CALLS_TO_QAPP_THREAD).start()
# make it a noop
self.mock_force_qapp.side_effect = lambda arg: arg
from mslice.plotting.plot_window.plot_figure_manager import new_plot_figure_manager
self.new_plot_figure_manager = new_plot_figure_manager
self.slice_plotter_presenter = SlicePlotterPresenter()
self.cut_plotter_presenter = CutPlotterPresenter()

def tearDown(self):
self.mock_force_qapp.stop()

def test_save_slice_nexus_sofqe(self):
gman = mock.Mock()
workspace = 'testworkspace'
file_name = ('', 'test.nxs', '.nxs')
fg = self.new_plot_figure_manager(num=1, global_manager=gman)
fg.add_slice_plot(self.slice_plotter_presenter, workspace=workspace)

with patch('mslice.plotting.plot_window.plot_figure_manager.get_save_directory') as get_save_dir, \
patch('mslice.models.workspacemanager.workspace_algorithms.save_nexus') as save_nexus, \
patch('mslice.models.workspacemanager.workspace_algorithms.get_workspace_handle') as get_handle, \
patch.object(SlicePlotterPresenter, 'get_slice_cache') as get_slice_cache:
get_save_dir.return_value = file_name
slice_cache = mock.Mock()
slice_cache.scattering_function = workspace
get_slice_cache.return_value = slice_cache
get_handle.return_value = workspace
fg.save_plot()
save_nexus.assert_called_once_with(workspace, file_name[1])
get_slice_cache.assert_called_once()

def test_save_slice_matlab_gdos(self):
gman = mock.Mock()
workspace = 'testworkspace'
file_name = ('', 'test.mat', '.mat')
fg = self.new_plot_figure_manager(num=1, global_manager=gman)
fg.add_slice_plot(self.slice_plotter_presenter, workspace=workspace)

with patch('mslice.plotting.plot_window.plot_figure_manager.get_save_directory') as get_save_dir, \
patch('mslice.models.workspacemanager.workspace_algorithms.save_matlab') as save_matlab, \
patch('mslice.models.workspacemanager.workspace_algorithms.get_workspace_handle') as get_handle, \
patch.object(SlicePlotterPresenter, 'get_slice_cache') as get_slice_cache:
get_save_dir.return_value = file_name
slice_cache = mock.Mock()
slice_cache.gdos = workspace
get_slice_cache.return_value = slice_cache
get_handle.return_value = workspace
fg.plot_handler.intensity_type = IntensityType.GDOS
fg.save_plot()
save_matlab.assert_called_once_with(workspace, file_name[1])
get_slice_cache.assert_called_once()

0 comments on commit 198379e

Please sign in to comment.