Skip to content

Commit

Permalink
Merge pull request #954 from mantidproject/release-next
Browse files Browse the repository at this point in the history
Merge release next into main
  • Loading branch information
SilkeSchomann authored Oct 31, 2023
2 parents c5046bc + cd889aa commit ba4b8cd
Show file tree
Hide file tree
Showing 14 changed files with 136 additions and 79 deletions.
10 changes: 5 additions & 5 deletions src/mslice/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def pcolormesh(self, *args, **kwargs):
else:
return Axes.pcolormesh(self, *args, **kwargs)

def recoil(self, workspace, element=None, rmm=None):
def recoil(self, workspace, element=None, rmm=None, **kwargs):
from mslice.app.presenters import get_slice_plotter_presenter
_check_workspace_name(workspace)
workspace = get_workspace_handle(workspace)
Expand All @@ -48,12 +48,12 @@ def recoil(self, workspace, element=None, rmm=None):
plot_handler = GlobalFigureManager.get_active_figure().plot_handler
plot_handler._arb_nuclei_rmm = rmm

get_slice_plotter_presenter().add_overplot_line(workspace.name, key, recoil=True, cif=None)
get_slice_plotter_presenter().add_overplot_line(workspace.name, key, recoil=True, cif=None, **kwargs)

_update_overplot_checklist(key)
_update_legend()

def bragg(self, workspace, element=None, cif=None):
def bragg(self, workspace, element=None, cif=None, **kwargs):
from mslice.app.presenters import get_cut_plotter_presenter, get_slice_plotter_presenter
_check_workspace_name(workspace)
workspace = get_workspace_handle(workspace)
Expand All @@ -62,9 +62,9 @@ def bragg(self, workspace, element=None, cif=None):

ws_type = _get_workspace_type(workspace)
if ws_type == 'HistogramWorkspace':
get_cut_plotter_presenter().add_overplot_line(workspace.name, key, recoil=True, cif=None)
get_cut_plotter_presenter().add_overplot_line(workspace.name, key, recoil=True, cif=None, **kwargs)
elif ws_type == 'MatrixWorkspace':
get_slice_plotter_presenter().add_overplot_line(workspace.name, key, recoil=False, cif=cif)
get_slice_plotter_presenter().add_overplot_line(workspace.name, key, recoil=False, cif=cif, **kwargs)

_update_overplot_checklist(key)
_update_legend()
Expand Down
2 changes: 1 addition & 1 deletion src/mslice/cli/_mslice_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def PlotCut(InputWorkspace, IntensityStart=0, IntensityEnd=0, PlotOver=False):
raise RuntimeError("Incorrect workspace type.")

if IntensityStart == 0 and IntensityEnd == 0:
intensity_range = None
intensity_range = (None, None)
else:
intensity_range = (IntensityStart, IntensityEnd)
from mslice.app.presenters import cli_cut_plotter_presenter
Expand Down
7 changes: 3 additions & 4 deletions src/mslice/models/workspacemanager/workspace_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,12 @@ def scale_workspaces(workspaces, scale_factor=None, from_temp=None, to_temp=None
propagate_properties(ws, result)


def save_workspaces(workspaces, path, save_name, extension, slice_nonpsd=False):
def save_workspaces(workspaces, path, save_name, extension):
"""
:param workspaces: list of workspaces to save
:param path: directory to save to
:param save_name: name to save the file as (plus file extension). Pass none to use workspace name
:param extension: file extension (such as .txt)
:param slice_nonpsd: whether the selection is in non_psd mode
"""
if extension == '.nxs':
save_method = save_nexus
Expand All @@ -265,7 +264,7 @@ def save_workspaces(workspaces, path, save_name, extension, slice_nonpsd=False):
if len(save_names) != len(workspaces):
save_names = [None] * len(workspaces)
for workspace, save_name_single in zip(workspaces, save_names):
_save_single_ws(workspace, save_name_single, save_method, path, extension, slice_nonpsd)
_save_single_ws(workspace, save_name_single, save_method, path, extension)


def export_workspace_to_ads(workspace):
Expand All @@ -279,7 +278,7 @@ def export_workspace_to_ads(workspace):
add_to_ads(workspace)


def _save_single_ws(workspace, save_name, save_method, path, extension, slice_nonpsd):
def _save_single_ws(workspace, save_name, save_method, path, extension):
save_as = save_name if save_name is not None else str(workspace) + extension
full_path = os.path.join(str(path), save_as)
if isinstance(workspace, string_types):
Expand Down
1 change: 1 addition & 0 deletions src/mslice/plotting/plot_window/cut_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ def on_newplot(self, plot_over, ws_name):
self.plot_window.toggle_waterfall_edit()
if not plot_over:
self._reset_plot_window_options()
self.ws_name = ws_name

all_lines = [line for container in line_containers for line in container.get_children()]
for cached_lines in list(self._waterfall_cache.keys()):
Expand Down
4 changes: 2 additions & 2 deletions src/mslice/plotting/plot_window/overplot_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def remove_line(line):
plt.gca().lines.remove(line)


def plot_overplot_line(x, y, key, recoil, cache):
color = OVERPLOT_COLORS[key] if key in OVERPLOT_COLORS else 'c'
def plot_overplot_line(x, y, key, recoil, cache, **kwargs):
color = kwargs.get('color', OVERPLOT_COLORS.get(key, 'c'))
if recoil:
return overplot_line(x, y, color, get_recoil_label(key), cache.rotated)
else:
Expand Down
3 changes: 1 addition & 2 deletions src/mslice/plotting/plot_window/plot_figure_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ def save_plot(self):
save_workspaces(workspaces,
file_path,
save_name,
ext,
slice_nonpsd=True)
ext)
except RuntimeError as e:
if str(e) == "unrecognised file extension":
supported_image_types = list(
Expand Down
4 changes: 2 additions & 2 deletions src/mslice/presenters/cut_plotter_presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _get_log_bragg_y_coords(size, portion_of_axes, datum):
return np.resize(np.array([10 ** adj_factor, 10 ** (-adj_factor), np.nan]), size) * datum

def add_overplot_line(self, workspace_name, key, recoil, cif=None, e_is_logarithmic=None, datum=0,
intensity_correction=IntensityType.SCATTERING_FUNCTION):
intensity_correction=IntensityType.SCATTERING_FUNCTION, **kwargs):
cache = self._cut_cache_dict[plt.gca()][0]
if cache.rotated:
warnings.warn("No Bragg peak found as cut has no |Q| dimension.")
Expand All @@ -209,7 +209,7 @@ def add_overplot_line(self, workspace_name, key, recoil, cif=None, e_is_logarith
else:
y = self._get_log_bragg_y_coords(len(y), BRAGG_SIZE_ON_AXES, datum)

self._overplot_cache[key] = plot_overplot_line(x, y, key, recoil, cache)
self._overplot_cache[key] = plot_overplot_line(x, y, key, recoil, cache, **kwargs)
except (ValueError, IndexError):
warnings.warn("No Bragg peak found.")

Expand Down
4 changes: 2 additions & 2 deletions src/mslice/presenters/slice_plotter_presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ def hide_overplot_line(self, workspace, key):
remove_line(line)

def add_overplot_line(self, workspace_name, key, recoil, cif=None, y_has_logarithmic=None, datum=None,
intensity_correction=None):
intensity_correction=None, **kwargs):
cache = self._slice_cache[workspace_name]
if recoil:
x, y = compute_recoil_line(workspace_name, cache.momentum_axis, key)
else:
x, y = compute_powder_line(workspace_name, cache.momentum_axis, key, cif_file=cif)
y = convert_energy_to_meV(y, cache.energy_axis.e_unit)
cache.overplot_lines[key] = plot_overplot_line(x, y, key, recoil, cache)
cache.overplot_lines[key] = plot_overplot_line(x, y, key, recoil, cache, **kwargs)

def validate_intensity(self, intensity_start, intensity_end):
intensity_start = self._to_float(intensity_start)
Expand Down
2 changes: 1 addition & 1 deletion src/mslice/scripting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_algorithm_kwargs(algorithm, existing_ws_refs):
continue
if algorithm.name() == "Load":
if prop.name() == "Filename":
arguments += [f"{prop.name()}='{pval}'"]
arguments += [f"{prop.name()}=r'{pval}'"]
continue
elif prop.name() == "LoaderName" or prop.name() == "LoaderVersion":
continue
Expand Down
79 changes: 41 additions & 38 deletions src/mslice/scripting/helperfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def header(plot_handler):
"""Creates a list of import statements to be used in the generated script header"""
from mslice.plotting.plot_window.cut_plot import CutPlot
from mslice.plotting.plot_window.slice_plot import SlicePlot
statements = ["# Python Script Generated by Mslice on {}\n".format(datetime.now().replace(microsecond=0))]
statements = [f"# Python Script Generated by Mslice on {datetime.now().replace(microsecond=0)}\n"]

statements.append("\n".join(COMMON_PACKAGES))
if isinstance(plot_handler, SlicePlot) and plot_handler.colorbar_log is True:
Expand Down Expand Up @@ -47,8 +47,8 @@ def add_plot_statements(script_lines, plot_handler, ax):
add_slice_plot_statements(script_lines, plot_handler)
add_overplot_statements(script_lines, plot_handler)
elif isinstance(plot_handler, CutPlot):
add_cut_plot_statements(script_lines, plot_handler, ax)
add_overplot_statements(script_lines, plot_handler)
return_ws_vars = add_cut_plot_statements(script_lines, plot_handler, ax)
add_overplot_statements(script_lines, plot_handler, return_ws_vars)

script_lines.append("mc.Show()\n")

Expand All @@ -63,25 +63,25 @@ def add_slice_plot_statements(script_lines, plot_handler):
energy_axis = str(slice.energy_axis)
norm = slice.norm_to_one

script_lines.append('slice_ws = mc.Slice(ws_{}, Axis1="{}", Axis2="{}", NormToOne={})\n\n'.format(
plot_handler.ws_name.replace(".", "_"), momentum_axis, energy_axis, norm))
script_lines.append(f'slice_ws = mc.Slice(ws_{plot_handler.ws_name.replace(".", "_")}, Axis1="{momentum_axis}", Axis2="{energy_axis}", '
f'NormToOne={norm})\n\n')

if plot_handler.intensity is True:
intensity = IntensityCache.get_desc_from_type(plot_handler.intensity_type)
if plot_handler.temp_dependent:
script_lines.append('mesh = ax.pcolormesh(slice_ws, cmap="{}", intensity="{}", temperature={})\n'.format(
cache[plot_handler.ws_name].colourmap, intensity, plot_handler.temp))
script_lines.append(f'mesh = ax.pcolormesh(slice_ws, cmap="{cache[plot_handler.ws_name].colourmap}", intensity="{intensity}", '
f'temperature={plot_handler.temp})\n')
else:
script_lines.append('mesh = ax.pcolormesh(slice_ws, cmap="{}", intensity="{}")\n'.format(
cache[plot_handler.ws_name].colourmap, intensity))
script_lines.append(f'mesh = ax.pcolormesh(slice_ws, cmap="{cache[plot_handler.ws_name].colourmap}", '
f'intensity="{intensity}")\n')
else:
script_lines.append('mesh = ax.pcolormesh(slice_ws, cmap="{}")\n'.format(cache[plot_handler.ws_name].colourmap))
script_lines.append(f'mesh = ax.pcolormesh(slice_ws, cmap="{cache[plot_handler.ws_name].colourmap}")\n')

script_lines.append("mesh.set_clim({}, {})\n".format(*plot_handler.colorbar_range))
script_lines.append(f"mesh.set_clim({plot_handler.colorbar_range[0]}, {plot_handler.colorbar_range[1]})\n")
if plot_handler.colorbar_log:
min, maximum = plot_handler.colorbar_range[0], plot_handler.colorbar_range[1]
min = max(min, LOG_SCALE_MIN)
script_lines.append("mesh.set_norm(colors.LogNorm({}, {}))\n".format(min, maximum))
script_lines.append(f"mesh.set_norm(colors.LogNorm({min}, {maximum}))\n")

script_lines.append("cb = plt.colorbar(mesh, ax=ax)\n")
script_lines.append(f"cb.set_label('{plot_handler.colorbar_label}', labelpad=20, rotation=270, picker=5, "
Expand All @@ -91,12 +91,13 @@ def add_slice_plot_statements(script_lines, plot_handler):
add_plot_options(script_lines, plot_handler)


def add_overplot_statements(script_lines, plot_handler):
def add_overplot_statements(script_lines, plot_handler, ws_vars=None):
"""Adds overplot line statements to the script if they were plotted"""
ax = plot_handler._canvas.figure.gca()
line_artists = ax.lines

for line in line_artists:
color = line._color
label = line._label
if "nolegend" in label:
continue
Expand All @@ -106,23 +107,23 @@ def add_overplot_statements(script_lines, plot_handler):
recoil = True if rmm is not None or key in [1, 2, 4] else False
cif = None # Does not yet account for CIF files

ws_var = ws_vars.pop(0) if ws_vars else f"'{plot_handler.ws_name}'"
if recoil:
if element is None:
script_lines.append("ax.recoil(workspace='{}', rmm={})\n".format(plot_handler.ws_name, rmm))
script_lines.append(f"ax.recoil(workspace={ws_var}, rmm={rmm}, color='{color}')\n")
else:
script_lines.append("ax.recoil(workspace='{}', element='{}')\n".format(plot_handler.ws_name, element))
script_lines.append(f"ax.recoil(workspace={ws_var}, element='{element}', color='{color}')\n")
else:
if cif is None:
script_lines.append("ax.bragg(workspace='{}', element='{}')\n".format(plot_handler.ws_name, element))

script_lines.append(f"ax.bragg(workspace={ws_var}, element='{element}', color='{color}')\n")
else:
script_lines.append("ax.bragg(workspace='{}', cif='{}')\n".format(plot_handler.ws_name, cif))
script_lines.append(f"ax.bragg(workspace={ws_var}, cif='{cif}', color='{color}')\n")


def add_cut_plot_statements(script_lines, plot_handler, ax):
"""Adds cut specific statements to the script"""

add_cut_lines(script_lines, plot_handler, ax)
return_ws_vars = add_cut_lines(script_lines, plot_handler, ax)
add_plot_options(script_lines, plot_handler)

if plot_handler.is_changed("x_log"):
Expand All @@ -132,14 +133,16 @@ def add_cut_plot_statements(script_lines, plot_handler, ax):
if plot_handler.is_changed("y_log"):
script_lines.append(f"ax.set_yscale('symlog', "
f"linthresh=pow(10, np.floor(np.log10({plot_handler.y_axis_min}))))\n")
return return_ws_vars


def add_cut_lines(script_lines, plot_handler, ax):
cuts = plot_handler._cut_plotter_presenter._cut_cache_dict[ax]
errorbars = plot_handler._canvas.figure.gca().containers
intensity_correction = plot_handler.intensity_type
add_cut_lines_with_width(errorbars, script_lines, cuts, intensity_correction)
return_ws_vars = add_cut_lines_with_width(errorbars, script_lines, cuts, intensity_correction)
hide_lines(script_lines, plot_handler, ax)
return return_ws_vars


def hide_lines(script_lines, plot_handler, ax):
Expand Down Expand Up @@ -172,13 +175,14 @@ def hide_lines(script_lines, plot_handler, ax):
def add_cut_lines_with_width(errorbars, script_lines, cuts, intensity_correction):
"""Adds the cut statements for each interval of the cuts that were plotted"""
index = 0 # Required as we run through the loop multiple times for each cut
return_ws_vars = []
for cut in cuts:
integration_start = cut.integration_axis.start
integration_end = cut.integration_axis.end
cut_start, cut_end = integration_start, min(integration_start + cut.width, integration_end)
intensity_range = (cut.intensity_start, cut.intensity_end)
norm_to_one = cut.norm_to_one
algo_str = '' if 'Rebin' in cut.algorithm else ', Algorithm="{}"'.format(cut.algorithm)
algo_str = '' if 'Rebin' in cut.algorithm else f', Algorithm="{cut.algorithm}"'

while cut_start != cut_end and index < len(errorbars):
cut.integration_axis.start = cut_start
Expand All @@ -195,24 +199,25 @@ def add_cut_lines_with_width(errorbars, script_lines, cuts, intensity_correction

intensity_correction_arg = f"'{IntensityCache.get_desc_from_type(intensity_correction)}'" \
if not intensity_correction == IntensityType.SCATTERING_FUNCTION else False
script_lines.append('cut_ws_{} = mc.Cut(ws_{}, CutAxis="{}", IntegrationAxis="{}", '
'NormToOne={}{}, IntensityCorrection={}, SampleTemperature={})'
'\n'.format(index, replace_ws_special_chars(cut.parent_ws_name), cut_axis, integration_axis,
norm_to_one, algo_str, intensity_correction_arg, cut.raw_sample_temp))

cut_ws = f'cut_ws_{index}'
script_lines.append(f'{cut_ws} = mc.Cut(ws_{replace_ws_special_chars(cut.parent_ws_name)}, CutAxis="{cut_axis}", '
f'IntegrationAxis="{integration_axis}", NormToOne={norm_to_one}{algo_str}, '
f'IntensityCorrection={intensity_correction_arg}, SampleTemperature={cut.raw_sample_temp})\n')
return_ws_vars.append(cut_ws)
plot_over = False if index == 0 else True
if intensity_range != (None, None):
script_lines.append(
'ax.errorbar(cut_ws_{}, label="{}", color="{}", marker="{}", ls="{}", '
'lw={}, intensity_range={})\n\n'.format(index, label, colour, marker, style, width,
intensity_range))
f'ax.errorbar(cut_ws_{index}, label="{label}", color="{colour}", marker="{marker}", ls="{style}", '
f'lw={width}, intensity_range={intensity_range}, plot_over={plot_over})\n\n')
else:
script_lines.append(
'ax.errorbar(cut_ws_{}, label="{}", color="{}", marker="{}", ls="{}", '
'lw={})\n\n'.format(index, label, colour, marker, style, width))
f'ax.errorbar(cut_ws_{index}, label="{label}", color="{colour}", marker="{marker}", ls="{style}", '
f'lw={width}, plot_over={plot_over})\n\n')

cut_start, cut_end = cut_end, min(cut_end + cut.width, integration_end)
index += 1
cut.reset_integration_axis(cut.start, cut.end)
return return_ws_vars


def add_plot_options(script_lines, plot_handler):
Expand All @@ -226,16 +231,16 @@ def add_plot_options(script_lines, plot_handler):
script_lines.append(f"ax.set_xlabel(r'{plot_handler.x_label}', fontsize={plot_handler.x_label_size})\n")

if plot_handler.is_changed("y_grid"):
script_lines.append("ax.grid({}, axis='y')\n".format(plot_handler.y_grid))
script_lines.append(f"ax.grid({plot_handler.y_grid}, axis='y')\n")

if plot_handler.is_changed("x_grid"):
script_lines.append("ax.grid({}, axis='x')\n".format(plot_handler.x_grid))
script_lines.append(f"ax.grid({plot_handler.x_grid}, axis='x')\n")

if plot_handler.is_changed("y_range"):
script_lines.append("ax.set_ylim(bottom={}, top={})\n".format(*plot_handler.y_range))
script_lines.append(f"ax.set_ylim(bottom={plot_handler.y_range[0]}, top={plot_handler.y_range[1]})\n")

if plot_handler.is_changed("x_range"):
script_lines.append("ax.set_xlim(left={}, right={})\n".format(*plot_handler.x_range))
script_lines.append(f"ax.set_xlim(left={plot_handler.x_range[0]}, right={plot_handler.x_range[1]})\n")

if plot_handler.is_changed("y_range_font_size"):
script_lines.append(f"ax.yaxis.set_tick_params(labelsize={plot_handler.y_range_font_size})\n")
Expand All @@ -245,9 +250,7 @@ def add_plot_options(script_lines, plot_handler):

from mslice.plotting.plot_window.cut_plot import CutPlot
if isinstance(plot_handler, CutPlot) and plot_handler.is_changed("waterfall"):
script_lines.append("ax.set_waterfall({}, {}, {})\n".format(plot_handler.waterfall,
plot_handler.waterfall_x,
plot_handler.waterfall_y))
script_lines.append(f"ax.set_waterfall({plot_handler.waterfall}, {plot_handler.waterfall_x},{plot_handler.waterfall_y})\n")


def replace_ws_special_chars(workspace_name):
Expand Down
Loading

0 comments on commit ba4b8cd

Please sign in to comment.