diff --git a/pyglotaran_extras/plotting/plot_overview.py b/pyglotaran_extras/plotting/plot_overview.py index 715e70ab..e9933a47 100644 --- a/pyglotaran_extras/plotting/plot_overview.py +++ b/pyglotaran_extras/plotting/plot_overview.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from pathlib import Path +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import xarray as xr @@ -10,8 +13,51 @@ from pyglotaran_extras.plotting.plot_traces import plot_traces from pyglotaran_extras.plotting.style import PlotStyle - -def plot_overview(result, center_λ=None, linlog=True, linthresh=1, show_data=False): +if TYPE_CHECKING: + from glotaran.project import Result + from matplotlib.figure import Figure + + +def plot_overview( + result: xr.Dataset | Path | Result, + center_λ: float | None = None, + linlog: bool = True, + linthresh: float = 1, + linscale: float = 1, + show_data: bool = False, + main_irf_nr: int = 0, +) -> Figure: + """Plot overview of the optimization result. + + Parameters + ---------- + result : xr.Dataset | Path | Result + Result from a pyglotaran optimization as dataset, Path or Result object. + center_λ: float | None + Center wavelength (λ in nm) + linlog: bool + Whether to use 'symlog' scale or not, by default False + linthresh: int + A single float which defines the range (-x, x), within which the plot is linear. + This avoids having the plot go to infinity around zero., by default 1 + linscale: int + This allows the linear range (-linthresh to linthresh) to be stretched + relative to the logarithmic range. + Its value is the number of decades to use for each half of the linear range. + For example, when linscale == 1.0 (the default), the space used for the + positive and negative halves of the linear range will be equal to one + decade in the logarithmic range., by default 1 + show_data : bool + Whether to show the input data or residual, by default False + main_irf_nr: int + Index of the main ``irf`` component when using an ``irf`` + parametrized with multiple peaks , by default 0 + + Returns + ------- + Figure + Figure object which contains the plots. + """ res = load_data(result) @@ -27,7 +73,15 @@ def plot_overview(result, center_λ=None, linlog=True, linthresh=1, show_data=Fa center_λ = min(res.dims["spectral"], round(res.dims["spectral"] / 2)) # First and second row: concentrations - SAS/EAS - DAS - plot_traces(res, ax[0, 0], center_λ, linlog=linlog, linthresh=linthresh) + plot_traces( + res, + ax[0, 0], + center_λ, + linlog=linlog, + linthresh=linthresh, + linscale=linscale, + main_irf_nr=main_irf_nr, + ) plot_spectra(res, ax[0:2, 1:3]) plot_svd(res, ax[2:4, 0:3], linlog=linlog, linthresh=linthresh) plot_residual(res, ax[1, 0], linlog=linlog, linthresh=linthresh, show_data=show_data) diff --git a/pyglotaran_extras/plotting/plot_traces.py b/pyglotaran_extras/plotting/plot_traces.py index a454945e..72ffdcdd 100644 --- a/pyglotaran_extras/plotting/plot_traces.py +++ b/pyglotaran_extras/plotting/plot_traces.py @@ -1,9 +1,40 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt from pyglotaran_extras.plotting.style import PlotStyle +if TYPE_CHECKING: + import xarray as xr + + +def get_shifted_traces( + res: xr.Dataset, center_λ: float | None = None, main_irf_nr: int = 0 +) -> xr.DataArray: + """Shift traces by the position of the main ``irf``. + + Parameters + ---------- + res: xr.Dataset + Result dataset from a pyglotaran optimization. + center_λ: float|None + Center wavelength (λ in nm), by default None + main_irf_nr: int + Index of the main ``irf`` component when using an ``irf`` + parametrized with multiple peaks , by default 0 -def get_shifted_traces(res, center_λ=None): + Returns + ------- + xr.DataArray + Traces shifted by the ``irf``s location, to align the at 0. + + Raises + ------ + ValueError + If no known concentration was found in the result. + """ if "species_concentration" in res: traces = res.species_concentration elif "species_associated_concentrations" in res: @@ -13,14 +44,22 @@ def get_shifted_traces(res, center_λ=None): times = traces.coords["time"] if center_λ is None: # center wavelength (λ in nm) center_λ = min(res.dims["spectral"], round(res.dims["spectral"] / 2)) - if "center_dispersion_1" in res: - center_dispersion = res.center_dispersion_1 # TODO: clarify against pyglotaran API why _1? + + if "irf_center_location" in res: + irf_center_location = res.irf_center_location + irf_loc = irf_center_location.sel(spectral=center_λ, method="nearest").item() + elif "center_dispersion_1" in res: + # legacy compatibility pyglotaran<0.5.0 + center_dispersion = res.center_dispersion_1 irf_loc = center_dispersion.sel(spectral=center_λ, method="nearest").item() elif "irf_center" in res: irf_loc = res.irf_center else: irf_loc = min(times) + if hasattr(irf_loc, "shape") and len(irf_loc.shape) > 0: + irf_loc = irf_loc[main_irf_nr] + times_shifted = times - irf_loc return traces.assign_coords(time=times_shifted) @@ -32,8 +71,46 @@ def calculate_x_ranges(res, linrange): pass -def plot_traces(res, ax, center_λ, linlog=False, linthresh=1, linscale=1): - traces = get_shifted_traces(res, center_λ) +def plot_traces( + res: xr.Dataset, + ax: plt.Axes, + center_λ: float | None, + linlog: bool = False, + linthresh: float = 1, + linscale: float = 1, + main_irf_nr: int = 0, +) -> None: + """Plot traces on the given axis ``ax`` + + Parameters + ---------- + res: xr.Dataset + Result dataset from a pyglotaran optimization. + ax: plt.Axes + Axes to plot the traces on + center_λ: float | None + Center wavelength (λ in nm) + linlog: bool + Whether to use 'symlog' scale or not, by default False + linthresh: int + A single float which defines the range (-x, x), within which the plot is linear. + This avoids having the plot go to infinity around zero., by default 1 + linscale: int + This allows the linear range (-linthresh to linthresh) to be stretched + relative to the logarithmic range. + Its value is the number of decades to use for each half of the linear range. + For example, when linscale == 1.0 (the default), the space used for the + positive and negative halves of the linear range will be equal to one + decade in the logarithmic range., by default 1 + main_irf_nr: int + Index of the main ``irf`` component when using an ``irf`` + parametrized with multiple peaks , by default 0 + + See Also + -------- + get_shifted_traces + """ + traces = get_shifted_traces(res, center_λ, main_irf_nr) plot_style = PlotStyle() plt.rc("axes", prop_cycle=plot_style.cycler)