From 3a42bb913fcbfdfed7ae9e23b5649c51b372eb9c Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Fri, 2 Feb 2024 09:13:11 -0600 Subject: [PATCH 1/2] ENH: Add ability to reject epochs using callables (#12195) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- doc/changes/devel/12195.newfeature.rst | 1 + mne/epochs.py | 102 ++++++++--- mne/tests/test_epochs.py | 165 +++++++++++++++++- mne/utils/docs.py | 48 ++++- mne/utils/mixin.py | 13 +- .../preprocessing/20_rejecting_bad_data.py | 108 +++++++++++- 6 files changed, 399 insertions(+), 38 deletions(-) create mode 100644 doc/changes/devel/12195.newfeature.rst diff --git a/doc/changes/devel/12195.newfeature.rst b/doc/changes/devel/12195.newfeature.rst new file mode 100644 index 00000000000..0c7e044abce --- /dev/null +++ b/doc/changes/devel/12195.newfeature.rst @@ -0,0 +1 @@ +Add ability reject :class:`mne.Epochs` using callables, by `Jacob Woessner`_. \ No newline at end of file diff --git a/mne/epochs.py b/mne/epochs.py index 2b437dca6b3..1e86c6c96b0 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -787,7 +787,7 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): self.baseline = baseline return self - def _reject_setup(self, reject, flat): + def _reject_setup(self, reject, flat, *, allow_callable=False): """Set self._reject_time and self._channel_type_idx.""" idx = channel_indices_by_type(self.info) reject = deepcopy(reject) if reject is not None else dict() @@ -814,11 +814,21 @@ def _reject_setup(self, reject, flat): f"{key.upper()}." ) - # check for invalid values - for rej, kind in zip((reject, flat), ("Rejection", "Flat")): - for key, val in rej.items(): - if val is None or val < 0: - raise ValueError(f'{kind} value must be a number >= 0, not "{val}"') + # check for invalid values + for rej, kind in zip((reject, flat), ("Rejection", "Flat")): + for key, val in rej.items(): + name = f"{kind} dict value for {key}" + if callable(val) and allow_callable: + continue + extra_str = "" + if allow_callable: + extra_str = "or callable" + _validate_type(val, "numeric", name, extra=extra_str) + if val is None or val < 0: + raise ValueError( + f"If using numerical {name} criteria, the value " + f"must be >= 0, not {repr(val)}" + ) # now check to see if our rejection and flat are getting more # restrictive @@ -836,6 +846,9 @@ def _reject_setup(self, reject, flat): reject[key] = old_reject[key] # make sure new thresholds are at least as stringent as the old ones for key in reject: + # Skip this check if old_reject and reject are callables + if callable(reject[key]) and allow_callable: + continue if key in old_reject and reject[key] > old_reject[key]: raise ValueError( bad_msg.format( @@ -851,6 +864,8 @@ def _reject_setup(self, reject, flat): for key in set(old_flat) - set(flat): flat[key] = old_flat[key] for key in flat: + if callable(flat[key]) and allow_callable: + continue if key in old_flat and flat[key] < old_flat[key]: raise ValueError( bad_msg.format( @@ -1404,7 +1419,7 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): flat = self.flat if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)): raise ValueError('reject and flat, if strings, must be "existing"') - self._reject_setup(reject, flat) + self._reject_setup(reject, flat, allow_callable=True) self._get_data(out=False, verbose=verbose) return self @@ -1520,8 +1535,9 @@ def drop(self, indices, reason="USER", verbose=None): Set epochs to remove by specifying indices to remove or a boolean mask to apply (where True values get removed). Events are correspondingly modified. - reason : str - Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc). + reason : list | tuple | str + Reason(s) for dropping the epochs ('ECG', 'timeout', 'blink' etc). + Reason(s) are applied to all indices specified. Default: 'USER'. %(verbose)s @@ -1533,7 +1549,9 @@ def drop(self, indices, reason="USER", verbose=None): indices = np.atleast_1d(indices) if indices.ndim > 1: - raise ValueError("indices must be a scalar or a 1-d array") + raise TypeError("indices must be a scalar or a 1-d array") + # Check if indices and reasons are of the same length + # if using collection to drop epochs if indices.dtype == np.dtype(bool): indices = np.where(indices)[0] @@ -3199,6 +3217,10 @@ class Epochs(BaseEpochs): See :meth:`~mne.Epochs.equalize_event_counts` - 'USER' For user-defined reasons (see :meth:`~mne.Epochs.drop`). + + When dropping based on flat or reject parameters the tuple of + reasons contains a tuple of channels that satisfied the rejection + criteria. filename : str The filename of the object. times : ndarray @@ -3667,6 +3689,8 @@ def _is_good( ): """Test if data segment e is good according to reject and flat. + The reject and flat parameters can accept functions as values. + If full_report=True, it will give True/False as well as a list of all offending channels. """ @@ -3674,30 +3698,60 @@ def _is_good( has_printed = False checkable = np.ones(len(ch_names), dtype=bool) checkable[np.array([c in ignore_chs for c in ch_names], dtype=bool)] = False + for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: - for key, thresh in refl.items(): + for key, refl in refl.items(): + criterion = refl idx = channel_type_idx[key] name = key.upper() if len(idx) > 0: e_idx = e[idx] - deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) checkable_idx = checkable[idx] - idx_deltas = np.where( - np.logical_and(f(deltas, thresh), checkable_idx) - )[0] + # Check if criterion is a function and apply it + if callable(criterion): + result = criterion(e_idx) + _validate_type(result, tuple, "reject/flat output") + if len(result) != 2: + raise TypeError( + "Function criterion must return a tuple of length 2" + ) + cri_truth, reasons = result + _validate_type(cri_truth, (bool, np.bool_), cri_truth, "bool") + _validate_type( + reasons, (str, list, tuple), reasons, "str, list, or tuple" + ) + idx_deltas = np.where(np.logical_and(cri_truth, checkable_idx))[ + 0 + ] + else: + deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) + idx_deltas = np.where( + np.logical_and(f(deltas, criterion), checkable_idx) + )[0] if len(idx_deltas) > 0: - bad_names = [ch_names[idx[i]] for i in idx_deltas] - if not has_printed: - logger.info( - f" Rejecting {t} epoch based on {name} : {bad_names}" - ) - has_printed = True - if not full_report: - return False + # Check to verify that refl is a callable that returns + # (bool, reason). Reason must be a str/list/tuple. + # If using tuple + if callable(refl): + if isinstance(reasons, str): + reasons = (reasons,) + for idx, reason in enumerate(reasons): + _validate_type(reason, str, reason) + bad_tuple += tuple(reasons) else: - bad_tuple += tuple(bad_names) + bad_names = [ch_names[idx[i]] for i in idx_deltas] + if not has_printed: + logger.info( + " Rejecting %s epoch based on %s : " + "%s" % (t, name, bad_names) + ) + has_printed = True + if not full_report: + return False + else: + bad_tuple += tuple(bad_names) if not full_report: return True diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 2b67dd9dbd6..96d90414e07 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -550,7 +550,20 @@ def test_reject(): preload=False, reject=dict(eeg=np.inf), ) - for val in (None, -1): # protect against older MNE-C types + + # Good function + def my_reject_1(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = "a" * len(bad_idxs[0]) + return len(bad_idxs) > 0, reasons + + # Bad function + def my_reject_2(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = "a" * len(bad_idxs[0]) + return len(bad_idxs), reasons + + for val in (-1, -2): # protect against older MNE-C types for kwarg in ("reject", "flat"): pytest.raises( ValueError, @@ -564,6 +577,44 @@ def test_reject(): preload=False, **{kwarg: dict(grad=val)}, ) + + # Check that reject and flat in constructor are not callables + val = my_reject_1 + for kwarg in ("reject", "flat"): + with pytest.raises( + TypeError, + match=r".* must be an instance of numeric, got instead.", + ): + Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=False, + **{kwarg: dict(grad=val)}, + ) + + # Check if callable returns a tuple with reasons + bad_types = [my_reject_2, ("Hi" "Hi"), (1, 1), None] + for val in bad_types: # protect against bad types + for kwarg in ("reject", "flat"): + with pytest.raises( + TypeError, + match=r".* must be an instance of .* got instead.", + ): + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=True, + ) + epochs.drop_bad(**{kwarg: dict(grad=val)}) + pytest.raises( KeyError, Epochs, @@ -2149,6 +2200,93 @@ def test_reject_epochs(tmp_path): assert epochs_cleaned.flat == dict(grad=new_flat["grad"], mag=flat["mag"]) +@testing.requires_testing_data +def test_callable_reject(): + """Test using a callable for rejection.""" + raw = read_raw_fif(fname_raw_testing, preload=True) + raw.crop(0, 5) + raw.del_proj() + chans = raw.info["ch_names"][-6:-1] + raw.pick(chans) + data = raw.get_data() + + # Add some artifacts + new_data = data + new_data[0, 180:200] *= 1e7 + new_data[0, 610:880] += 1e-3 + edit_raw = mne.io.RawArray(new_data, raw.info) + + events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) + epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True) + assert len(epochs) == 5 + + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, + ) + epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")) + ) + + assert epochs.drop_log[2] == ("eeg median",) + + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, + ) + epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",))) + ) + + assert epochs.drop_log[0] == ("eeg max",) + + def reject_criteria(x): + max_condition = np.max(x, axis=1) > 1e-2 + median_condition = np.median(x, axis=1) > 1e-4 + return (max_condition.any() or median_condition.any()), "eeg max or median" + + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, + ) + epochs.drop_bad(reject=dict(eeg=reject_criteria)) + + assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == ( + "eeg max or median", + ) + + # Test reasons must be str or tuple of str + with pytest.raises( + TypeError, + match=r".* must be an instance of str, got instead.", + ): + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, + ) + epochs.drop_bad( + reject=dict( + eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2)) + ) + ) + + def test_preload_epochs(): """Test preload of epochs.""" raw, events, picks = _get_data() @@ -3180,9 +3318,16 @@ def test_drop_epochs(): events1 = events[events[:, 2] == event_id] # Bound checks - pytest.raises(IndexError, epochs.drop, [len(epochs.events)]) - pytest.raises(IndexError, epochs.drop, [-len(epochs.events) - 1]) - pytest.raises(ValueError, epochs.drop, [[1, 2], [3, 4]]) + with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"): + epochs.drop([len(epochs.events)]) + with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"): + epochs.drop([-len(epochs.events) - 1]) + with pytest.raises(TypeError, match="indices must be a scalar or a 1-d array"): + epochs.drop([[1, 2], [3, 4]]) + with pytest.raises( + TypeError, match=r".* must be an instance of .* got instead." + ): + epochs.drop([1], reason=("a", "b", 2)) # Test selection attribute assert_array_equal(epochs.selection, np.where(events[:, 2] == event_id)[0]) @@ -3202,6 +3347,18 @@ def test_drop_epochs(): assert_array_equal(events[epochs[3:].selection], events1[[5, 6]]) assert_array_equal(events[epochs["1"].selection], events1[[0, 1, 3, 5, 6]]) + # Test using tuple to drop epochs + raw, events, picks = _get_data() + epochs_tuple = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True) + selection_tuple = epochs_tuple.selection.copy() + epochs_tuple.drop((2, 3, 4), reason=("a", "b")) + n_events = len(epochs.events) + assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [ + ("a", "b"), + ("a", "b"), + ("a", "b"), + ] + @pytest.mark.parametrize("preload", (True, False)) def test_drop_epochs_mult(preload): diff --git a/mne/utils/docs.py b/mne/utils/docs.py index ec9fe66bae0..c3005427ead 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1456,9 +1456,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): quality, pass the ``reject_tmin`` and ``reject_tmax`` parameters. """ -docdict["flat_drop_bad"] = f""" +docdict["flat_drop_bad"] = """ flat : dict | str | None -{_flat_common} + Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP) + or a custom function. Valid **keys** can be any channel type present + in the object. If using PTP, **values** are floats that set the minimum + acceptable PTP. If the PTP is smaller than this threshold, the epoch + will be dropped. If ``None`` then no rejection is performed based on + flatness of the signal. If a custom function is used than ``flat`` can be + used to reject epochs based on any criteria (including maxima and + minima). If ``'existing'``, then the flat parameters set during epoch creation are used. """ @@ -3291,12 +3298,43 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): difference will be preserved. """ -docdict["reject_drop_bad"] = f""" +docdict["reject_drop_bad"] = """ reject : dict | str | None -{_reject_common} + Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP) + or custom functions. Peak-to-peak signal amplitude is defined as + the absolute difference between the lowest and the highest signal + value. In each individual epoch, the PTP is calculated for every channel. + If the PTP of any one channel exceeds the rejection threshold, the + respective epoch will be dropped. + + The dictionary keys correspond to the different channel types; valid + **keys** can be any channel type present in the object. + + Example:: + + reject = dict(grad=4000e-13, # unit: T / m (gradiometers) + mag=4e-12, # unit: T (magnetometers) + eeg=40e-6, # unit: V (EEG channels) + eog=250e-6 # unit: V (EOG channels) + ) + + Custom rejection criteria can be also be used by passing a callable, + e.g., to check for 99th percentile of absolute values of any channel + across time being bigger than 1mV. The callable must return a good, reason tuple. + Where good must be bool and reason must be str, list, or tuple where each entry is a str.:: + + reject = dict(eeg=lambda x: ((np.percentile(np.abs(x), 99, axis=1) > 1e-3).any(), "> 1mV somewhere")) + + .. note:: If rejection is based on a signal **difference** + calculated for each channel separately, applying baseline + correction does not affect the rejection procedure, as the + difference will be preserved. + + .. note:: If ``reject`` is a callable, than **any** criteria can be + used to reject epochs (including maxima and minima). If ``reject`` is ``None``, no rejection is performed. If ``'existing'`` (default), then the rejection parameters set at instantiation are used. -""" +""" # noqa: E501 docdict["reject_epochs"] = f""" reject : dict | None diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index c90121fdfbb..87e86aaa315 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -178,7 +178,7 @@ def _getitem( ---------- item: slice, array-like, str, or list see `__getitem__` for details. - reason: str + reason: str, list/tuple of str entry in `drop_log` for unselected epochs copy: bool return a copy of the current object @@ -209,8 +209,15 @@ def _getitem( key_selection = inst.selection[select] drop_log = list(inst.drop_log) if reason is not None: - for k in np.setdiff1d(inst.selection, key_selection): - drop_log[k] = (reason,) + _validate_type(reason, (list, tuple, str), "reason") + if isinstance(reason, (list, tuple)): + for r in reason: + _validate_type(r, str, r) + if isinstance(reason, str): + reason = (reason,) + reason = tuple(reason) + for idx in np.setdiff1d(inst.selection, key_selection): + drop_log[idx] = reason inst.drop_log = tuple(drop_log) inst.selection = key_selection del drop_log diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index d478255b048..a04005f3532 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -23,6 +23,8 @@ import os +import numpy as np + import mne sample_data_folder = mne.datasets.sample.data_path() @@ -205,8 +207,8 @@ # %% # .. _`tut-reject-epochs-section`: # -# Rejecting Epochs based on channel amplitude -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Rejecting Epochs based on peak-to-peak channel amplitude +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # # Besides "bad" annotations, the :class:`mne.Epochs` class constructor has # another means of rejecting epochs, based on signal amplitude thresholds for @@ -328,6 +330,108 @@ epochs.drop_bad(reject=stronger_reject_criteria) print(epochs.drop_log) +# %% +# .. _`tut-reject-epochs-func-section`: +# +# Rejecting Epochs using callables (functions) +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Sometimes it is useful to reject epochs based criteria other than +# peak-to-peak amplitudes. For example, we might want to reject epochs +# based on the maximum or minimum amplitude of a channel. +# In this case, the `mne.Epochs.drop_bad` function also accepts +# callables (functions) in the ``reject`` and ``flat`` parameters. This +# allows us to define functions to reject epochs based on our desired criteria. +# +# Let's begin by generating Epoch data with large artifacts in one eeg channel +# in order to demonstrate the versatility of this approach. + +raw.crop(0, 5) +raw.del_proj() +chans = raw.info["ch_names"][-5:-1] +raw.pick(chans) +data = raw.get_data() + +new_data = data +new_data[0, 180:200] *= 1e3 +new_data[0, 460:580] += 1e-3 +edit_raw = mne.io.RawArray(new_data, raw.info) + +# Create fixed length epochs of 1 second +events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) +epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None) +epochs.plot(scalings=dict(eeg=50e-5)) + +# %% +# As you can see, we have two large artifacts in the first channel. One large +# spike in amplitude and one large increase in amplitude. + +# Let's try to reject the epoch containing the spike in amplitude based on the +# maximum amplitude of the first channel. Please note that the callable in +# ``reject`` must return a (good, reason) tuple. Where the good must be bool +# and reason must be a str, list, or tuple where each entry is a str. + +epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, +) + +epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")) +) +epochs.plot(scalings=dict(eeg=50e-5)) + +# %% +# Here, the epoch containing the spike in amplitude was rejected for having a +# maximum amplitude greater than 1e-2 Volts. Notice the use of the ``any()`` +# function to check if any of the channels exceeded the threshold. We could +# have also used the ``all()`` function to check if all channels exceeded the +# threshold. + +# Next, let's try to reject the epoch containing the increase in amplitude +# using the median. + +epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, +) + +epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")) +) +epochs.plot(scalings=dict(eeg=50e-5)) + +# %% +# Finally, let's try to reject both epochs using a combination of the maximum +# and median. We'll define a custom function and use boolean operators to +# combine the two criteria. + + +def reject_criteria(x): + max_condition = np.max(x, axis=1) > 1e-2 + median_condition = np.median(x, axis=1) > 1e-4 + return ((max_condition.any() or median_condition.any()), ["max amp", "median amp"]) + + +epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True, +) + +epochs.drop_bad(reject=dict(eeg=reject_criteria)) +epochs.plot(events=True) + # %% # Note that a complementary Python module, the `autoreject package`_, uses # machine learning to find optimal rejection criteria, and is designed to From d8ea2f5e60174d61301dbefbed9c76c9adc01ec9 Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Fri, 2 Feb 2024 09:56:10 -0600 Subject: [PATCH 2/2] actually use GFP for EEG channels in plot_compare_evokeds (#12410) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- doc/changes/devel/12410.bugfix.rst | 1 + mne/time_frequency/spectrum.py | 6 +-- mne/utils/docs.py | 46 +++++++++++++++++++--- mne/utils/misc.py | 4 ++ mne/viz/epochs.py | 14 +------ mne/viz/evoked.py | 61 ++++++++++++++++++------------ mne/viz/tests/test_evoked.py | 33 +++++++++++----- mne/viz/utils.py | 57 ++++++++++++++++++++++------ 8 files changed, 153 insertions(+), 69 deletions(-) create mode 100644 doc/changes/devel/12410.bugfix.rst diff --git a/doc/changes/devel/12410.bugfix.rst b/doc/changes/devel/12410.bugfix.rst new file mode 100644 index 00000000000..c5d939845b0 --- /dev/null +++ b/doc/changes/devel/12410.bugfix.rst @@ -0,0 +1 @@ +In :func:`~mne.viz.plot_compare_evokeds`, actually plot GFP (not RMS amplitude) for EEG channels when global field power is requested by `Daniel McCloy`_. \ No newline at end of file diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index a7a2a753932..e46be389695 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -45,7 +45,7 @@ _is_numeric, check_fname, ) -from ..utils.misc import _pl +from ..utils.misc import _identity_function, _pl from ..utils.spectrum import _split_psd_kwargs from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap @@ -60,10 +60,6 @@ from .psd import _check_nfft, psd_array_welch -def _identity_function(x): - return x - - class SpectrumMixin: """Mixin providing spectral plotting methods to sensor-space containers.""" diff --git a/mne/utils/docs.py b/mne/utils/docs.py index c3005427ead..87f457a982b 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -718,12 +718,46 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): 0 and 255. """ -docdict["combine"] = """ -combine : None | str | callable - How to combine information across channels. If a :class:`str`, must be - one of 'mean', 'median', 'std' (standard deviation) or 'gfp' (global - field power). -""" +_combine_template = """ +combine : 'mean' | {literals} | callable | None + How to aggregate across channels. If ``None``, {none}. If a string, + ``"mean"`` uses :func:`numpy.mean`, {other_string}. + If :func:`callable`, it must operate on an :class:`array ` + of shape ``({shape})`` and return an array of shape + ``({return_shape})``. {example} + {notes}Defaults to ``None``. +""" +_example = """For example:: + + combine = lambda data: np.median(data, axis=1) +""" +_median_std_gfp = """``"median"`` computes the `marginal median + `__, ``"std"`` + uses :func:`numpy.std`, and ``"gfp"`` computes global field power + for EEG channels and RMS amplitude for MEG channels""" +docdict["combine_plot_compare_evokeds"] = _combine_template.format( + literals="'median' | 'std' | 'gfp'", + none="""channels are combined by + computing GFP/RMS, unless ``picks`` is a single channel (not channel type) + or ``axes="topo"``, in which cases no combining is performed""", + other_string=_median_std_gfp, + shape="n_evokeds, n_channels, n_times", + return_shape="n_evokeds, n_times", + example=_example, + notes="", +) +docdict["combine_plot_epochs_image"] = _combine_template.format( + literals="'median' | 'std' | 'gfp'", + none="""channels are combined by + computing GFP/RMS, unless ``group_by`` is also ``None`` and ``picks`` is a + list of specific channels (not channel types), in which case no combining + is performed and each channel gets its own figure""", + other_string=_median_std_gfp, + shape="n_epochs, n_channels, n_times", + return_shape="n_epochs, n_times", + example=_example, + notes="See Notes for further details. ", +) docdict["compute_proj_ecg"] = """This function will: diff --git a/mne/utils/misc.py b/mne/utils/misc.py index 3f342c80570..2cebf8e5450 100644 --- a/mne/utils/misc.py +++ b/mne/utils/misc.py @@ -28,6 +28,10 @@ from .check import _check_option, _validate_type +def _identity_function(x): + return x + + # TODO: no longer needed when py3.9 is minimum supported version def _empty_hash(kind="md5"): func = getattr(hashlib, kind) diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 95989637523..9871a0c2647 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -145,19 +145,7 @@ def plot_epochs_image( ``overlay_times`` should be ordered to correspond with the :class:`~mne.Epochs` object (i.e., ``overlay_times[0]`` corresponds to ``epochs[0]``, etc). - %(combine)s - If callable, the callable must accept one positional input (data of - shape ``(n_epochs, n_channels, n_times)``) and return an - :class:`array ` of shape ``(n_epochs, n_times)``. For - example:: - - combine = lambda data: np.median(data, axis=1) - - If ``combine`` is ``None``, channels are combined by computing GFP, - unless ``group_by`` is also ``None`` and ``picks`` is a list of - specific channels (not channel types), in which case no combining is - performed and each channel gets its own figure. See Notes for further - details. Defaults to ``None``. + %(combine_plot_epochs_image)s group_by : None | dict Specifies which channels are aggregated into a single figure, with aggregation method determined by the ``combine`` parameter. If not diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 96976532767..f2a47fbe4d0 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -2484,14 +2484,22 @@ def _draw_axes_pce( ) -def _get_data_and_ci(evoked, combine, combine_func, picks, scaling=1, ci_fun=None): +def _get_data_and_ci( + evoked, combine, combine_func, ch_type, picks, scaling=1, ci_fun=None +): """Compute (sensor-aggregated, scaled) time series and possibly CI.""" picks = np.array(picks).flatten() # apply scalings data = np.array([evk.data[picks] * scaling for evk in evoked]) # combine across sensors if combine is not None: - logger.info(f'combining channels using "{combine}"') + if combine == "gfp" and ch_type == "eeg": + msg = f"GFP ({ch_type} channels)" + elif combine == "gfp" and ch_type in ("mag", "grad"): + msg = f"RMS ({ch_type} channels)" + else: + msg = f'"{combine}"' + logger.info(f"combining channels using {msg}") data = combine_func(data) # get confidence band if ci_fun is not None: @@ -2551,7 +2559,7 @@ def _plot_compare_evokeds( ax.set_title(title) -def _title_helper_pce(title, picked_types, picks, ch_names, combine): +def _title_helper_pce(title, picked_types, picks, ch_names, ch_type, combine): """Format title for plot_compare_evokeds.""" if title is None: title = ( @@ -2562,8 +2570,12 @@ def _title_helper_pce(title, picked_types, picks, ch_names, combine): # add the `combine` modifier do_combine = picked_types or len(ch_names) > 1 if title is not None and len(title) and isinstance(combine, str) and do_combine: - _comb = combine.upper() if combine == "gfp" else combine - _comb = "std. dev." if _comb == "std" else _comb + if combine == "gfp": + _comb = "RMS" if ch_type in ("mag", "grad") else "GFP" + elif combine == "std": + _comb = "std. dev." + else: + _comb = combine title += f" ({_comb})" return title @@ -2744,18 +2756,7 @@ def plot_compare_evokeds( value of the ``combine`` parameter. Defaults to ``None``. show : bool Whether to show the figure. Defaults to ``True``. - %(combine)s - If callable, the callable must accept one positional input (data of - shape ``(n_evokeds, n_channels, n_times)``) and return an - :class:`array ` of shape ``(n_epochs, n_times)``. For - example:: - - combine = lambda data: np.median(data, axis=1) - - If ``combine`` is ``None``, channels are combined by computing GFP, - unless ``picks`` is a single channel (not channel type) or - ``axes='topo'``, in which cases no combining is performed. Defaults to - ``None``. + %(combine_plot_compare_evokeds)s %(sphere_topomap_auto)s %(time_unit)s @@ -2914,11 +2915,19 @@ def plot_compare_evokeds( if combine is None and len(picks) > 1 and not do_topo: combine = "gfp" # convert `combine` into callable (if None or str) - combine_func = _make_combine_callable(combine) + combine_funcs = { + ch_type: _make_combine_callable(combine, ch_type=ch_type) + for ch_type in ch_types + } # title title = _title_helper_pce( - title, picked_types, picks=orig_picks, ch_names=ch_names, combine=combine + title, + picked_types, + picks=orig_picks, + ch_names=ch_names, + ch_type=ch_types[0] if len(ch_types) == 1 else None, + combine=combine, ) topo_disp_title = False # setup axes @@ -2943,9 +2952,7 @@ def plot_compare_evokeds( _validate_if_list_of_axes(axes, obligatory_len=len(ch_types)) if len(ch_types) > 1: - logger.info( - "Multiple channel types selected, returning one figure " "per type." - ) + logger.info("Multiple channel types selected, returning one figure per type.") figs = list() for ch_type, ax in zip(ch_types, axes): _picks = picks_by_type[ch_type] @@ -2954,7 +2961,12 @@ def plot_compare_evokeds( # don't pass `combine` here; title will run through this helper # function a second time & it will get added then _title = _title_helper_pce( - title, picked_types, picks=_picks, ch_names=_ch_names, combine=None + title, + picked_types, + picks=_picks, + ch_names=_ch_names, + ch_type=ch_type, + combine=None, ) figs.extend( plot_compare_evokeds( @@ -3003,7 +3015,7 @@ def plot_compare_evokeds( # some things that depend on ch_type: units = _handle_default("units")[ch_type] scalings = _handle_default("scalings")[ch_type] - + combine_func = combine_funcs[ch_type] # prep for topo pos_picks = picks # need this version of picks for sensor location inset info = pick_info(info, sel=picks, copy=True) @@ -3136,6 +3148,7 @@ def click_func( this_evokeds, combine, c_func, + ch_type=ch_type, picks=_picks, scaling=scalings, ci_fun=ci_fun, diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index 66609839df8..e177df6a9b8 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -402,21 +402,34 @@ def test_plot_white(): evoked_sss.plot_white(cov, time_unit="s") +@pytest.mark.parametrize( + "combine,vlines,title,picks", + ( + pytest.param(None, [0.1, 0.2], "MEG 0113", "MEG 0113", id="singlepick"), + pytest.param("mean", [], "(mean)", "mag", id="mag-mean"), + pytest.param("gfp", "auto", "(GFP)", "eeg", id="eeg-gfp"), + pytest.param(None, "auto", "(RMS)", ["MEG 0113", "MEG 0112"], id="meg-rms"), + pytest.param( + "std", "auto", "(std. dev.)", ["MEG 0113", "MEG 0112"], id="meg-std" + ), + pytest.param( + lambda x: np.min(x, axis=1), "auto", "MEG 0112", [0, 1], id="intpicks" + ), + ), +) +def test_plot_compare_evokeds_title(evoked, picks, vlines, combine, title): + """Test title generation by plot_compare_evokeds().""" + # test picks, combine, and vlines (1-channel pick also shows sensor inset) + fig = plot_compare_evokeds(evoked, picks=picks, vlines=vlines, combine=combine) + assert fig[0].axes[0].get_title().endswith(title) + + @pytest.mark.slowtest # slow on Azure -def test_plot_compare_evokeds(): +def test_plot_compare_evokeds(evoked): """Test plot_compare_evokeds.""" - evoked = _get_epochs().average() # test defaults figs = plot_compare_evokeds(evoked) assert len(figs) == 3 - # test picks, combine, and vlines (1-channel pick also shows sensor inset) - picks = ["MEG 0113", "mag"] + 2 * [["MEG 0113", "MEG 0112"]] + [[0, 1]] - vlines = [[0.1, 0.2], []] + 3 * ["auto"] - combine = [None, "mean", "std", None, lambda x: np.min(x, axis=1)] - title = ["MEG 0113", "(mean)", "(std. dev.)", "(GFP)", "MEG 0112"] - for _p, _v, _c, _t in zip(picks, vlines, combine, title): - fig = plot_compare_evokeds(evoked, picks=_p, vlines=_v, combine=_c) - assert fig[0].axes[0].get_title().endswith(_t) # test passing more than one evoked red, blue = evoked.copy(), evoked.copy() red.comment = red.comment + "*" * 100 diff --git a/mne/viz/utils.py b/mne/viz/utils.py index eeaf3d1098e..d325c474a16 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -46,6 +46,7 @@ ) from .._fiff.proj import Projection, setup_proj from ..defaults import _handle_default +from ..fixes import _median_complex from ..rank import compute_rank from ..transforms import apply_trans from ..utils import ( @@ -65,6 +66,7 @@ verbose, warn, ) +from ..utils.misc import _identity_function from .ui_events import ColormapRange, publish, subscribe _channel_type_prettyprint = { @@ -2328,30 +2330,63 @@ def _plot_masked_image( @fill_doc -def _make_combine_callable(combine): +def _make_combine_callable( + combine, + *, + axis=1, + valid=("mean", "median", "std", "gfp"), + ch_type=None, + keepdims=False, +): """Convert None or string values of ``combine`` into callables. Params ------ - %(combine)s - If callable, the callable must accept one positional input (data of - shape ``(n_epochs, n_channels, n_times)`` or ``(n_evokeds, n_channels, - n_times)``) and return an :class:`array ` of shape - ``(n_epochs, n_times)`` or ``(n_evokeds, n_times)``. + combine : None | str | callable + If callable, the callable must accept one positional input (a numpy array) and + return an array with one fewer dimensions (the missing dimension's position is + given by ``axis``). + axis : int + Axis of data array across which to combine. May vary depending on data + context; e.g., if data are time-domain sensor traces or TFRs, continuous + or epoched, etc. + valid : tuple + Valid string values for built-in combine methods + (may vary for, e.g., combining TFRs versus time-domain signals). + ch_type : str + Channel type. Affects whether "gfp" is allowed as a synonym for "rms". + keepdims : bool + Whether to retain the singleton dimension after collapsing across it. """ + kwargs = dict(axis=axis, keepdims=keepdims) if combine is None: - combine = partial(np.squeeze, axis=1) + combine = _identity_function if keepdims else partial(np.squeeze, axis=axis) elif isinstance(combine, str): combine_dict = { - key: partial(getattr(np, key), axis=1) for key in ("mean", "median", "std") + key: partial(getattr(np, key), **kwargs) + for key in valid + if getattr(np, key, None) is not None } - combine_dict["gfp"] = lambda data: np.sqrt((data**2).mean(axis=1)) + # marginal median that is safe for complex values: + if "median" in valid: + combine_dict["median"] = partial(_median_complex, axis=axis) + + # RMS and GFP; if GFP requested for MEG channels, will use RMS anyway + def _rms(data): + return np.sqrt((data**2).mean(**kwargs)) + + if "rms" in valid: + combine_dict["rms"] = _rms + if "gfp" in valid and ch_type == "eeg": + combine_dict["gfp"] = lambda data: data.std(axis=axis, ddof=0) + elif "gfp" in valid: + combine_dict["gfp"] = _rms try: combine = combine_dict[combine] except KeyError: raise ValueError( - '"combine" must be None, a callable, or one of "mean", "median", "std",' - f' or "gfp"; got {combine}' + f'"combine" must be None, a callable, or one of "{", ".join(valid)}"; ' + f'got {combine}' ) return combine