Skip to content

Commit

Permalink
actually use GFP for EEG channels in plot_compare_evokeds (mne-tools#…
Browse files Browse the repository at this point in the history
…12410)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
drammock and autofix-ci[bot] authored Feb 2, 2024
1 parent 3a42bb9 commit d8ea2f5
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 69 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12410.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
In :func:`~mne.viz.plot_compare_evokeds`, actually plot GFP (not RMS amplitude) for EEG channels when global field power is requested by `Daniel McCloy`_.
6 changes: 1 addition & 5 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
_is_numeric,
check_fname,
)
from ..utils.misc import _pl
from ..utils.misc import _identity_function, _pl
from ..utils.spectrum import _split_psd_kwargs
from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo
from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap
Expand All @@ -60,10 +60,6 @@
from .psd import _check_nfft, psd_array_welch


def _identity_function(x):
return x


class SpectrumMixin:
"""Mixin providing spectral plotting methods to sensor-space containers."""

Expand Down
46 changes: 40 additions & 6 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,12 +718,46 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
0 and 255.
"""

docdict["combine"] = """
combine : None | str | callable
How to combine information across channels. If a :class:`str`, must be
one of 'mean', 'median', 'std' (standard deviation) or 'gfp' (global
field power).
"""
_combine_template = """
combine : 'mean' | {literals} | callable | None
How to aggregate across channels. If ``None``, {none}. If a string,
``"mean"`` uses :func:`numpy.mean`, {other_string}.
If :func:`callable`, it must operate on an :class:`array <numpy.ndarray>`
of shape ``({shape})`` and return an array of shape
``({return_shape})``. {example}
{notes}Defaults to ``None``.
"""
_example = """For example::
combine = lambda data: np.median(data, axis=1)
"""
_median_std_gfp = """``"median"`` computes the `marginal median
<https://en.wikipedia.org/wiki/Median#Marginal_median>`__, ``"std"``
uses :func:`numpy.std`, and ``"gfp"`` computes global field power
for EEG channels and RMS amplitude for MEG channels"""
docdict["combine_plot_compare_evokeds"] = _combine_template.format(
literals="'median' | 'std' | 'gfp'",
none="""channels are combined by
computing GFP/RMS, unless ``picks`` is a single channel (not channel type)
or ``axes="topo"``, in which cases no combining is performed""",
other_string=_median_std_gfp,
shape="n_evokeds, n_channels, n_times",
return_shape="n_evokeds, n_times",
example=_example,
notes="",
)
docdict["combine_plot_epochs_image"] = _combine_template.format(
literals="'median' | 'std' | 'gfp'",
none="""channels are combined by
computing GFP/RMS, unless ``group_by`` is also ``None`` and ``picks`` is a
list of specific channels (not channel types), in which case no combining
is performed and each channel gets its own figure""",
other_string=_median_std_gfp,
shape="n_epochs, n_channels, n_times",
return_shape="n_epochs, n_times",
example=_example,
notes="See Notes for further details. ",
)

docdict["compute_proj_ecg"] = """This function will:
Expand Down
4 changes: 4 additions & 0 deletions mne/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from .check import _check_option, _validate_type


def _identity_function(x):
return x


# TODO: no longer needed when py3.9 is minimum supported version
def _empty_hash(kind="md5"):
func = getattr(hashlib, kind)
Expand Down
14 changes: 1 addition & 13 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,19 +145,7 @@ def plot_epochs_image(
``overlay_times`` should be ordered to correspond with the
:class:`~mne.Epochs` object (i.e., ``overlay_times[0]`` corresponds to
``epochs[0]``, etc).
%(combine)s
If callable, the callable must accept one positional input (data of
shape ``(n_epochs, n_channels, n_times)``) and return an
:class:`array <numpy.ndarray>` of shape ``(n_epochs, n_times)``. For
example::
combine = lambda data: np.median(data, axis=1)
If ``combine`` is ``None``, channels are combined by computing GFP,
unless ``group_by`` is also ``None`` and ``picks`` is a list of
specific channels (not channel types), in which case no combining is
performed and each channel gets its own figure. See Notes for further
details. Defaults to ``None``.
%(combine_plot_epochs_image)s
group_by : None | dict
Specifies which channels are aggregated into a single figure, with
aggregation method determined by the ``combine`` parameter. If not
Expand Down
61 changes: 37 additions & 24 deletions mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -2484,14 +2484,22 @@ def _draw_axes_pce(
)


def _get_data_and_ci(evoked, combine, combine_func, picks, scaling=1, ci_fun=None):
def _get_data_and_ci(
evoked, combine, combine_func, ch_type, picks, scaling=1, ci_fun=None
):
"""Compute (sensor-aggregated, scaled) time series and possibly CI."""
picks = np.array(picks).flatten()
# apply scalings
data = np.array([evk.data[picks] * scaling for evk in evoked])
# combine across sensors
if combine is not None:
logger.info(f'combining channels using "{combine}"')
if combine == "gfp" and ch_type == "eeg":
msg = f"GFP ({ch_type} channels)"
elif combine == "gfp" and ch_type in ("mag", "grad"):
msg = f"RMS ({ch_type} channels)"
else:
msg = f'"{combine}"'
logger.info(f"combining channels using {msg}")
data = combine_func(data)
# get confidence band
if ci_fun is not None:
Expand Down Expand Up @@ -2551,7 +2559,7 @@ def _plot_compare_evokeds(
ax.set_title(title)


def _title_helper_pce(title, picked_types, picks, ch_names, combine):
def _title_helper_pce(title, picked_types, picks, ch_names, ch_type, combine):
"""Format title for plot_compare_evokeds."""
if title is None:
title = (
Expand All @@ -2562,8 +2570,12 @@ def _title_helper_pce(title, picked_types, picks, ch_names, combine):
# add the `combine` modifier
do_combine = picked_types or len(ch_names) > 1
if title is not None and len(title) and isinstance(combine, str) and do_combine:
_comb = combine.upper() if combine == "gfp" else combine
_comb = "std. dev." if _comb == "std" else _comb
if combine == "gfp":
_comb = "RMS" if ch_type in ("mag", "grad") else "GFP"
elif combine == "std":
_comb = "std. dev."
else:
_comb = combine
title += f" ({_comb})"
return title

Expand Down Expand Up @@ -2744,18 +2756,7 @@ def plot_compare_evokeds(
value of the ``combine`` parameter. Defaults to ``None``.
show : bool
Whether to show the figure. Defaults to ``True``.
%(combine)s
If callable, the callable must accept one positional input (data of
shape ``(n_evokeds, n_channels, n_times)``) and return an
:class:`array <numpy.ndarray>` of shape ``(n_epochs, n_times)``. For
example::
combine = lambda data: np.median(data, axis=1)
If ``combine`` is ``None``, channels are combined by computing GFP,
unless ``picks`` is a single channel (not channel type) or
``axes='topo'``, in which cases no combining is performed. Defaults to
``None``.
%(combine_plot_compare_evokeds)s
%(sphere_topomap_auto)s
%(time_unit)s
Expand Down Expand Up @@ -2914,11 +2915,19 @@ def plot_compare_evokeds(
if combine is None and len(picks) > 1 and not do_topo:
combine = "gfp"
# convert `combine` into callable (if None or str)
combine_func = _make_combine_callable(combine)
combine_funcs = {
ch_type: _make_combine_callable(combine, ch_type=ch_type)
for ch_type in ch_types
}

# title
title = _title_helper_pce(
title, picked_types, picks=orig_picks, ch_names=ch_names, combine=combine
title,
picked_types,
picks=orig_picks,
ch_names=ch_names,
ch_type=ch_types[0] if len(ch_types) == 1 else None,
combine=combine,
)
topo_disp_title = False
# setup axes
Expand All @@ -2943,9 +2952,7 @@ def plot_compare_evokeds(
_validate_if_list_of_axes(axes, obligatory_len=len(ch_types))

if len(ch_types) > 1:
logger.info(
"Multiple channel types selected, returning one figure " "per type."
)
logger.info("Multiple channel types selected, returning one figure per type.")
figs = list()
for ch_type, ax in zip(ch_types, axes):
_picks = picks_by_type[ch_type]
Expand All @@ -2954,7 +2961,12 @@ def plot_compare_evokeds(
# don't pass `combine` here; title will run through this helper
# function a second time & it will get added then
_title = _title_helper_pce(
title, picked_types, picks=_picks, ch_names=_ch_names, combine=None
title,
picked_types,
picks=_picks,
ch_names=_ch_names,
ch_type=ch_type,
combine=None,
)
figs.extend(
plot_compare_evokeds(
Expand Down Expand Up @@ -3003,7 +3015,7 @@ def plot_compare_evokeds(
# some things that depend on ch_type:
units = _handle_default("units")[ch_type]
scalings = _handle_default("scalings")[ch_type]

combine_func = combine_funcs[ch_type]
# prep for topo
pos_picks = picks # need this version of picks for sensor location inset
info = pick_info(info, sel=picks, copy=True)
Expand Down Expand Up @@ -3136,6 +3148,7 @@ def click_func(
this_evokeds,
combine,
c_func,
ch_type=ch_type,
picks=_picks,
scaling=scalings,
ci_fun=ci_fun,
Expand Down
33 changes: 23 additions & 10 deletions mne/viz/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,21 +402,34 @@ def test_plot_white():
evoked_sss.plot_white(cov, time_unit="s")


@pytest.mark.parametrize(
"combine,vlines,title,picks",
(
pytest.param(None, [0.1, 0.2], "MEG 0113", "MEG 0113", id="singlepick"),
pytest.param("mean", [], "(mean)", "mag", id="mag-mean"),
pytest.param("gfp", "auto", "(GFP)", "eeg", id="eeg-gfp"),
pytest.param(None, "auto", "(RMS)", ["MEG 0113", "MEG 0112"], id="meg-rms"),
pytest.param(
"std", "auto", "(std. dev.)", ["MEG 0113", "MEG 0112"], id="meg-std"
),
pytest.param(
lambda x: np.min(x, axis=1), "auto", "MEG 0112", [0, 1], id="intpicks"
),
),
)
def test_plot_compare_evokeds_title(evoked, picks, vlines, combine, title):
"""Test title generation by plot_compare_evokeds()."""
# test picks, combine, and vlines (1-channel pick also shows sensor inset)
fig = plot_compare_evokeds(evoked, picks=picks, vlines=vlines, combine=combine)
assert fig[0].axes[0].get_title().endswith(title)


@pytest.mark.slowtest # slow on Azure
def test_plot_compare_evokeds():
def test_plot_compare_evokeds(evoked):
"""Test plot_compare_evokeds."""
evoked = _get_epochs().average()
# test defaults
figs = plot_compare_evokeds(evoked)
assert len(figs) == 3
# test picks, combine, and vlines (1-channel pick also shows sensor inset)
picks = ["MEG 0113", "mag"] + 2 * [["MEG 0113", "MEG 0112"]] + [[0, 1]]
vlines = [[0.1, 0.2], []] + 3 * ["auto"]
combine = [None, "mean", "std", None, lambda x: np.min(x, axis=1)]
title = ["MEG 0113", "(mean)", "(std. dev.)", "(GFP)", "MEG 0112"]
for _p, _v, _c, _t in zip(picks, vlines, combine, title):
fig = plot_compare_evokeds(evoked, picks=_p, vlines=_v, combine=_c)
assert fig[0].axes[0].get_title().endswith(_t)
# test passing more than one evoked
red, blue = evoked.copy(), evoked.copy()
red.comment = red.comment + "*" * 100
Expand Down
57 changes: 46 additions & 11 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from .._fiff.proj import Projection, setup_proj
from ..defaults import _handle_default
from ..fixes import _median_complex
from ..rank import compute_rank
from ..transforms import apply_trans
from ..utils import (
Expand All @@ -65,6 +66,7 @@
verbose,
warn,
)
from ..utils.misc import _identity_function
from .ui_events import ColormapRange, publish, subscribe

_channel_type_prettyprint = {
Expand Down Expand Up @@ -2328,30 +2330,63 @@ def _plot_masked_image(


@fill_doc
def _make_combine_callable(combine):
def _make_combine_callable(
combine,
*,
axis=1,
valid=("mean", "median", "std", "gfp"),
ch_type=None,
keepdims=False,
):
"""Convert None or string values of ``combine`` into callables.
Params
------
%(combine)s
If callable, the callable must accept one positional input (data of
shape ``(n_epochs, n_channels, n_times)`` or ``(n_evokeds, n_channels,
n_times)``) and return an :class:`array <numpy.ndarray>` of shape
``(n_epochs, n_times)`` or ``(n_evokeds, n_times)``.
combine : None | str | callable
If callable, the callable must accept one positional input (a numpy array) and
return an array with one fewer dimensions (the missing dimension's position is
given by ``axis``).
axis : int
Axis of data array across which to combine. May vary depending on data
context; e.g., if data are time-domain sensor traces or TFRs, continuous
or epoched, etc.
valid : tuple
Valid string values for built-in combine methods
(may vary for, e.g., combining TFRs versus time-domain signals).
ch_type : str
Channel type. Affects whether "gfp" is allowed as a synonym for "rms".
keepdims : bool
Whether to retain the singleton dimension after collapsing across it.
"""
kwargs = dict(axis=axis, keepdims=keepdims)
if combine is None:
combine = partial(np.squeeze, axis=1)
combine = _identity_function if keepdims else partial(np.squeeze, axis=axis)
elif isinstance(combine, str):
combine_dict = {
key: partial(getattr(np, key), axis=1) for key in ("mean", "median", "std")
key: partial(getattr(np, key), **kwargs)
for key in valid
if getattr(np, key, None) is not None
}
combine_dict["gfp"] = lambda data: np.sqrt((data**2).mean(axis=1))
# marginal median that is safe for complex values:
if "median" in valid:
combine_dict["median"] = partial(_median_complex, axis=axis)

# RMS and GFP; if GFP requested for MEG channels, will use RMS anyway
def _rms(data):
return np.sqrt((data**2).mean(**kwargs))

if "rms" in valid:
combine_dict["rms"] = _rms
if "gfp" in valid and ch_type == "eeg":
combine_dict["gfp"] = lambda data: data.std(axis=axis, ddof=0)
elif "gfp" in valid:
combine_dict["gfp"] = _rms
try:
combine = combine_dict[combine]
except KeyError:
raise ValueError(
'"combine" must be None, a callable, or one of "mean", "median", "std",'
f' or "gfp"; got {combine}'
f'"combine" must be None, a callable, or one of "{", ".join(valid)}"; '
f'got {combine}'
)
return combine

Expand Down

0 comments on commit d8ea2f5

Please sign in to comment.