From 86341af8eeea3d44624dfa50b83e9e59da8e2085 Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Sun, 5 Sep 2021 23:55:27 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8Compatibility=20to=20pyglotaran=20(0.5?= =?UTF-8?q?.0-dev)=20staging=20branch=20(#30)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Backwards compatible to v0.4.1 and forwards compatible to v0.5.0. * 馃憣 Adjusted future usage to changes in pyglotaran PR 786 * 馃憣 Changed 'get_shifted_traces' to use main_irf_nr instead of mean This allows users to specify the index of the IRF's main components for the shift, rather than assuming that all components of the irf have the same center. --- pyglotaran_extras/plotting/plot_overview.py | 60 +++++++++++++- pyglotaran_extras/plotting/plot_traces.py | 87 +++++++++++++++++++-- 2 files changed, 139 insertions(+), 8 deletions(-) diff --git a/pyglotaran_extras/plotting/plot_overview.py b/pyglotaran_extras/plotting/plot_overview.py index 715e70ab..e9933a47 100644 --- a/pyglotaran_extras/plotting/plot_overview.py +++ b/pyglotaran_extras/plotting/plot_overview.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from pathlib import Path +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import xarray as xr @@ -10,8 +13,51 @@ from pyglotaran_extras.plotting.plot_traces import plot_traces from pyglotaran_extras.plotting.style import PlotStyle - -def plot_overview(result, center_位=None, linlog=True, linthresh=1, show_data=False): +if TYPE_CHECKING: + from glotaran.project import Result + from matplotlib.figure import Figure + + +def plot_overview( + result: xr.Dataset | Path | Result, + center_位: float | None = None, + linlog: bool = True, + linthresh: float = 1, + linscale: float = 1, + show_data: bool = False, + main_irf_nr: int = 0, +) -> Figure: + """Plot overview of the optimization result. + + Parameters + ---------- + result : xr.Dataset | Path | Result + Result from a pyglotaran optimization as dataset, Path or Result object. + center_位: float | None + Center wavelength (位 in nm) + linlog: bool + Whether to use 'symlog' scale or not, by default False + linthresh: int + A single float which defines the range (-x, x), within which the plot is linear. + This avoids having the plot go to infinity around zero., by default 1 + linscale: int + This allows the linear range (-linthresh to linthresh) to be stretched + relative to the logarithmic range. + Its value is the number of decades to use for each half of the linear range. + For example, when linscale == 1.0 (the default), the space used for the + positive and negative halves of the linear range will be equal to one + decade in the logarithmic range., by default 1 + show_data : bool + Whether to show the input data or residual, by default False + main_irf_nr: int + Index of the main ``irf`` component when using an ``irf`` + parametrized with multiple peaks , by default 0 + + Returns + ------- + Figure + Figure object which contains the plots. + """ res = load_data(result) @@ -27,7 +73,15 @@ def plot_overview(result, center_位=None, linlog=True, linthresh=1, show_data=Fa center_位 = min(res.dims["spectral"], round(res.dims["spectral"] / 2)) # First and second row: concentrations - SAS/EAS - DAS - plot_traces(res, ax[0, 0], center_位, linlog=linlog, linthresh=linthresh) + plot_traces( + res, + ax[0, 0], + center_位, + linlog=linlog, + linthresh=linthresh, + linscale=linscale, + main_irf_nr=main_irf_nr, + ) plot_spectra(res, ax[0:2, 1:3]) plot_svd(res, ax[2:4, 0:3], linlog=linlog, linthresh=linthresh) plot_residual(res, ax[1, 0], linlog=linlog, linthresh=linthresh, show_data=show_data) diff --git a/pyglotaran_extras/plotting/plot_traces.py b/pyglotaran_extras/plotting/plot_traces.py index a454945e..72ffdcdd 100644 --- a/pyglotaran_extras/plotting/plot_traces.py +++ b/pyglotaran_extras/plotting/plot_traces.py @@ -1,9 +1,40 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt from pyglotaran_extras.plotting.style import PlotStyle +if TYPE_CHECKING: + import xarray as xr + + +def get_shifted_traces( + res: xr.Dataset, center_位: float | None = None, main_irf_nr: int = 0 +) -> xr.DataArray: + """Shift traces by the position of the main ``irf``. + + Parameters + ---------- + res: xr.Dataset + Result dataset from a pyglotaran optimization. + center_位: float|None + Center wavelength (位 in nm), by default None + main_irf_nr: int + Index of the main ``irf`` component when using an ``irf`` + parametrized with multiple peaks , by default 0 -def get_shifted_traces(res, center_位=None): + Returns + ------- + xr.DataArray + Traces shifted by the ``irf``s location, to align the at 0. + + Raises + ------ + ValueError + If no known concentration was found in the result. + """ if "species_concentration" in res: traces = res.species_concentration elif "species_associated_concentrations" in res: @@ -13,14 +44,22 @@ def get_shifted_traces(res, center_位=None): times = traces.coords["time"] if center_位 is None: # center wavelength (位 in nm) center_位 = min(res.dims["spectral"], round(res.dims["spectral"] / 2)) - if "center_dispersion_1" in res: - center_dispersion = res.center_dispersion_1 # TODO: clarify against pyglotaran API why _1? + + if "irf_center_location" in res: + irf_center_location = res.irf_center_location + irf_loc = irf_center_location.sel(spectral=center_位, method="nearest").item() + elif "center_dispersion_1" in res: + # legacy compatibility pyglotaran<0.5.0 + center_dispersion = res.center_dispersion_1 irf_loc = center_dispersion.sel(spectral=center_位, method="nearest").item() elif "irf_center" in res: irf_loc = res.irf_center else: irf_loc = min(times) + if hasattr(irf_loc, "shape") and len(irf_loc.shape) > 0: + irf_loc = irf_loc[main_irf_nr] + times_shifted = times - irf_loc return traces.assign_coords(time=times_shifted) @@ -32,8 +71,46 @@ def calculate_x_ranges(res, linrange): pass -def plot_traces(res, ax, center_位, linlog=False, linthresh=1, linscale=1): - traces = get_shifted_traces(res, center_位) +def plot_traces( + res: xr.Dataset, + ax: plt.Axes, + center_位: float | None, + linlog: bool = False, + linthresh: float = 1, + linscale: float = 1, + main_irf_nr: int = 0, +) -> None: + """Plot traces on the given axis ``ax`` + + Parameters + ---------- + res: xr.Dataset + Result dataset from a pyglotaran optimization. + ax: plt.Axes + Axes to plot the traces on + center_位: float | None + Center wavelength (位 in nm) + linlog: bool + Whether to use 'symlog' scale or not, by default False + linthresh: int + A single float which defines the range (-x, x), within which the plot is linear. + This avoids having the plot go to infinity around zero., by default 1 + linscale: int + This allows the linear range (-linthresh to linthresh) to be stretched + relative to the logarithmic range. + Its value is the number of decades to use for each half of the linear range. + For example, when linscale == 1.0 (the default), the space used for the + positive and negative halves of the linear range will be equal to one + decade in the logarithmic range., by default 1 + main_irf_nr: int + Index of the main ``irf`` component when using an ``irf`` + parametrized with multiple peaks , by default 0 + + See Also + -------- + get_shifted_traces + """ + traces = get_shifted_traces(res, center_位, main_irf_nr) plot_style = PlotStyle() plt.rc("axes", prop_cycle=plot_style.cycler)