Skip to content

Commit

Permalink
✨Compatibility to pyglotaran (0.5.0-dev) staging branch (#30)
Browse files Browse the repository at this point in the history
Backwards compatible to v0.4.1 and forwards compatible to v0.5.0.

* 👌 Adjusted future usage to changes in pyglotaran PR 786

* 👌 Changed 'get_shifted_traces' to use main_irf_nr instead of mean

This allows users to specify the index of the IRF's main components for the shift, rather than assuming that all components of the irf have the same center.
  • Loading branch information
s-weigand authored Sep 5, 2021
1 parent c118df7 commit 86341af
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 8 deletions.
60 changes: 57 additions & 3 deletions pyglotaran_extras/plotting/plot_overview.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import xarray as xr
Expand All @@ -10,8 +13,51 @@
from pyglotaran_extras.plotting.plot_traces import plot_traces
from pyglotaran_extras.plotting.style import PlotStyle


def plot_overview(result, center_λ=None, linlog=True, linthresh=1, show_data=False):
if TYPE_CHECKING:
from glotaran.project import Result
from matplotlib.figure import Figure


def plot_overview(
result: xr.Dataset | Path | Result,
center_λ: float | None = None,
linlog: bool = True,
linthresh: float = 1,
linscale: float = 1,
show_data: bool = False,
main_irf_nr: int = 0,
) -> Figure:
"""Plot overview of the optimization result.
Parameters
----------
result : xr.Dataset | Path | Result
Result from a pyglotaran optimization as dataset, Path or Result object.
center_λ: float | None
Center wavelength (λ in nm)
linlog: bool
Whether to use 'symlog' scale or not, by default False
linthresh: int
A single float which defines the range (-x, x), within which the plot is linear.
This avoids having the plot go to infinity around zero., by default 1
linscale: int
This allows the linear range (-linthresh to linthresh) to be stretched
relative to the logarithmic range.
Its value is the number of decades to use for each half of the linear range.
For example, when linscale == 1.0 (the default), the space used for the
positive and negative halves of the linear range will be equal to one
decade in the logarithmic range., by default 1
show_data : bool
Whether to show the input data or residual, by default False
main_irf_nr: int
Index of the main ``irf`` component when using an ``irf``
parametrized with multiple peaks , by default 0
Returns
-------
Figure
Figure object which contains the plots.
"""

res = load_data(result)

Expand All @@ -27,7 +73,15 @@ def plot_overview(result, center_λ=None, linlog=True, linthresh=1, show_data=Fa
center_λ = min(res.dims["spectral"], round(res.dims["spectral"] / 2))

# First and second row: concentrations - SAS/EAS - DAS
plot_traces(res, ax[0, 0], center_λ, linlog=linlog, linthresh=linthresh)
plot_traces(
res,
ax[0, 0],
center_λ,
linlog=linlog,
linthresh=linthresh,
linscale=linscale,
main_irf_nr=main_irf_nr,
)
plot_spectra(res, ax[0:2, 1:3])
plot_svd(res, ax[2:4, 0:3], linlog=linlog, linthresh=linthresh)
plot_residual(res, ax[1, 0], linlog=linlog, linthresh=linthresh, show_data=show_data)
Expand Down
87 changes: 82 additions & 5 deletions pyglotaran_extras/plotting/plot_traces.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,40 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

from pyglotaran_extras.plotting.style import PlotStyle

if TYPE_CHECKING:
import xarray as xr


def get_shifted_traces(
res: xr.Dataset, center_λ: float | None = None, main_irf_nr: int = 0
) -> xr.DataArray:
"""Shift traces by the position of the main ``irf``.
Parameters
----------
res: xr.Dataset
Result dataset from a pyglotaran optimization.
center_λ: float|None
Center wavelength (λ in nm), by default None
main_irf_nr: int
Index of the main ``irf`` component when using an ``irf``
parametrized with multiple peaks , by default 0
def get_shifted_traces(res, center_λ=None):
Returns
-------
xr.DataArray
Traces shifted by the ``irf``s location, to align the at 0.
Raises
------
ValueError
If no known concentration was found in the result.
"""
if "species_concentration" in res:
traces = res.species_concentration
elif "species_associated_concentrations" in res:
Expand All @@ -13,14 +44,22 @@ def get_shifted_traces(res, center_λ=None):
times = traces.coords["time"]
if center_λ is None: # center wavelength (λ in nm)
center_λ = min(res.dims["spectral"], round(res.dims["spectral"] / 2))
if "center_dispersion_1" in res:
center_dispersion = res.center_dispersion_1 # TODO: clarify against pyglotaran API why _1?

if "irf_center_location" in res:
irf_center_location = res.irf_center_location
irf_loc = irf_center_location.sel(spectral=center_λ, method="nearest").item()
elif "center_dispersion_1" in res:
# legacy compatibility pyglotaran<0.5.0
center_dispersion = res.center_dispersion_1
irf_loc = center_dispersion.sel(spectral=center_λ, method="nearest").item()
elif "irf_center" in res:
irf_loc = res.irf_center
else:
irf_loc = min(times)

if hasattr(irf_loc, "shape") and len(irf_loc.shape) > 0:
irf_loc = irf_loc[main_irf_nr]

times_shifted = times - irf_loc
return traces.assign_coords(time=times_shifted)

Expand All @@ -32,8 +71,46 @@ def calculate_x_ranges(res, linrange):
pass


def plot_traces(res, ax, center_λ, linlog=False, linthresh=1, linscale=1):
traces = get_shifted_traces(res, center_λ)
def plot_traces(
res: xr.Dataset,
ax: plt.Axes,
center_λ: float | None,
linlog: bool = False,
linthresh: float = 1,
linscale: float = 1,
main_irf_nr: int = 0,
) -> None:
"""Plot traces on the given axis ``ax``
Parameters
----------
res: xr.Dataset
Result dataset from a pyglotaran optimization.
ax: plt.Axes
Axes to plot the traces on
center_λ: float | None
Center wavelength (λ in nm)
linlog: bool
Whether to use 'symlog' scale or not, by default False
linthresh: int
A single float which defines the range (-x, x), within which the plot is linear.
This avoids having the plot go to infinity around zero., by default 1
linscale: int
This allows the linear range (-linthresh to linthresh) to be stretched
relative to the logarithmic range.
Its value is the number of decades to use for each half of the linear range.
For example, when linscale == 1.0 (the default), the space used for the
positive and negative halves of the linear range will be equal to one
decade in the logarithmic range., by default 1
main_irf_nr: int
Index of the main ``irf`` component when using an ``irf``
parametrized with multiple peaks , by default 0
See Also
--------
get_shifted_traces
"""
traces = get_shifted_traces(res, center_λ, main_irf_nr)
plot_style = PlotStyle()
plt.rc("axes", prop_cycle=plot_style.cycler)

Expand Down

0 comments on commit 86341af

Please sign in to comment.