From a11a4106a55c0150c8904a004603b55ca3fa0a0d Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Wed, 24 Jan 2024 17:25:43 -0600 Subject: [PATCH] Remove support for callabes in constructor --- mne/epochs.py | 15 +++-- mne/tests/test_epochs.py | 66 ++++++++++++++----- mne/utils/docs.py | 56 +++++++++++----- .../preprocessing/20_rejecting_bad_data.py | 11 ++-- 4 files changed, 104 insertions(+), 44 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index c787ae48635..3e1c11650fd 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -787,7 +787,7 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): self.baseline = baseline return self - def _reject_setup(self, reject, flat): + def _reject_setup(self, reject, flat, *, allow_callable=False): """Set self._reject_time and self._channel_type_idx.""" idx = channel_indices_by_type(self.info) reject = deepcopy(reject) if reject is not None else dict() @@ -818,9 +818,12 @@ def _reject_setup(self, reject, flat): for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): name = f"{kind} dict value for {key}" - if callable(val): + if callable(val) and allow_callable: continue - _validate_type(val, "numeric", name, extra="or callable") + extra_str = "" + if allow_callable: + extra_str = "or callable" + _validate_type(val, "numeric", name, extra=extra_str) if val is None or val < 0: raise ValueError( f"If using numerical {name} criteria, the value " @@ -844,7 +847,7 @@ def _reject_setup(self, reject, flat): # make sure new thresholds are at least as stringent as the old ones for key in reject: # Skip this check if old_reject and reject are callables - if callable(reject[key]): + if callable(reject[key]) and allow_callable: continue if key in old_reject and reject[key] > old_reject[key]: raise ValueError( @@ -861,7 +864,7 @@ def _reject_setup(self, reject, flat): for key in set(old_flat) - set(flat): flat[key] = old_flat[key] for key in flat: - if callable(flat[key]): + if callable(flat[key]) and allow_callable: continue if key in old_flat and flat[key] < old_flat[key]: raise ValueError( @@ -1416,7 +1419,7 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): flat = self.flat if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)): raise ValueError('reject and flat, if strings, must be "existing"') - self._reject_setup(reject, flat) + self._reject_setup(reject, flat, allow_callable=True) self._get_data(out=False, verbose=verbose) return self diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 2dce010f01e..34218546ffc 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -549,6 +549,19 @@ def test_reject(): preload=False, reject=dict(eeg=np.inf), ) + + # Good function + def my_reject_1(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = "a" * len(bad_idxs[0]) + return len(bad_idxs) > 0, reasons + + # Bad function + def my_reject_2(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = "a" * len(bad_idxs[0]) + return len(bad_idxs), reasons + for val in (-1, -2): # protect against older MNE-C types for kwarg in ("reject", "flat"): pytest.raises( @@ -564,23 +577,33 @@ def test_reject(): **{kwarg: dict(grad=val)}, ) - def my_reject_1(epoch_data): - bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) - return len(bad_idxs) > 0 - - def my_reject_2(epoch_data): - bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) - reasons = "a" * len(bad_idxs[0]) - return len(bad_idxs), reasons + # Check that reject and flat in constructor are not callables + val = my_reject_1 + for kwarg in ("reject", "flat"): + with pytest.raises( + TypeError, + match=r".* must be an instance of numeric, got instead." + ): + Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=False, + **{kwarg: dict(grad=val)}, + ) - bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1), None] + # Check if callable returns a tuple with reasons + bad_types = [my_reject_2, ("Hi" "Hi"), (1, 1), None] for val in bad_types: # protect against bad types for kwarg in ("reject", "flat"): with pytest.raises( TypeError, match=r".* must be an instance of .* got instead.", ): - Epochs( + epochs = Epochs( raw, events, event_id, @@ -588,8 +611,8 @@ def my_reject_2(epoch_data): tmax, picks=picks_meg, preload=True, - **{kwarg: dict(grad=val)}, ) + epochs.drop_bad(**{kwarg: dict(grad=val)}) pytest.raises( KeyError, @@ -2202,9 +2225,10 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")), preload=True, ) + epochs.drop_bad(reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median"))) + assert epochs.drop_log[2] == ("eeg median",) epochs = mne.Epochs( @@ -2213,9 +2237,10 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",))), preload=True, ) + epochs.drop_bad(reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",)))) + assert epochs.drop_log[0] == ("eeg max",) def reject_criteria(x): @@ -2229,25 +2254,31 @@ def reject_criteria(x): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=reject_criteria), preload=True, ) + epochs.drop_bad(reject=dict(eeg=reject_criteria)) + assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == ( "eeg max or median", ) # Test reasons must be str or tuple of str - with pytest.raises(TypeError): + with pytest.raises( + TypeError, + match=r".* must be an instance of str, got instead.", + ): epochs = mne.Epochs( edit_raw, events, tmin=0, tmax=1, baseline=None, + preload=True, + ) + epochs.drop_bad( reject=dict( eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2)) - ), - preload=True, + ) ) @@ -3323,7 +3354,6 @@ def test_drop_epochs(): ("a", "b"), ] - @pytest.mark.parametrize("preload", (True, False)) def test_drop_epochs_mult(preload): """Test that subselecting epochs or making fewer epochs is similar.""" diff --git a/mne/utils/docs.py b/mne/utils/docs.py index efc1046ac44..c3005427ead 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1442,14 +1442,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ _flat_common = """\ - Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP) - or a custom function. Valid **keys** can be any channel type present - in the object. If using PTP, **values** are floats that set the minimum - acceptable PTP. If the PTP is smaller than this threshold, the epoch - will be dropped. If ``None`` then no rejection is performed based on - flatness of the signal. If a custom function is used than ``flat`` can be - used to reject epochs based on any criteria (including maxima and - minima).""" + Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP). + Valid **keys** can be any channel type present in the object. The + **values** are floats that set the minimum acceptable PTP. If the PTP + is smaller than this threshold, the epoch will be dropped. If ``None`` + then no rejection is performed based on flatness of the signal.""" docdict["flat"] = f""" flat : dict | None @@ -1459,9 +1456,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): quality, pass the ``reject_tmin`` and ``reject_tmax`` parameters. """ -docdict["flat_drop_bad"] = f""" +docdict["flat_drop_bad"] = """ flat : dict | str | None -{_flat_common} + Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP) + or a custom function. Valid **keys** can be any channel type present + in the object. If using PTP, **values** are floats that set the minimum + acceptable PTP. If the PTP is smaller than this threshold, the epoch + will be dropped. If ``None`` then no rejection is performed based on + flatness of the signal. If a custom function is used than ``flat`` can be + used to reject epochs based on any criteria (including maxima and + minima). If ``'existing'``, then the flat parameters set during epoch creation are used. """ @@ -3271,6 +3275,31 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ) _reject_common = """\ + Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP), + i.e. the absolute difference between the lowest and the highest signal + value. In each individual epoch, the PTP is calculated for every channel. + If the PTP of any one channel exceeds the rejection threshold, the + respective epoch will be dropped. + + The dictionary keys correspond to the different channel types; valid + **keys** can be any channel type present in the object. + + Example:: + + reject = dict(grad=4000e-13, # unit: T / m (gradiometers) + mag=4e-12, # unit: T (magnetometers) + eeg=40e-6, # unit: V (EEG channels) + eog=250e-6 # unit: V (EOG channels) + ) + + .. note:: Since rejection is based on a signal **difference** + calculated for each channel separately, applying baseline + correction does not affect the rejection procedure, as the + difference will be preserved. +""" + +docdict["reject_drop_bad"] = """ +reject : dict | str | None Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP) or custom functions. Peak-to-peak signal amplitude is defined as the absolute difference between the lowest and the highest signal @@ -3303,14 +3332,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. note:: If ``reject`` is a callable, than **any** criteria can be used to reject epochs (including maxima and minima). -""" # noqa: E501 - -docdict["reject_drop_bad"] = f""" -reject : dict | str | None -{_reject_common} If ``reject`` is ``None``, no rejection is performed. If ``'existing'`` (default), then the rejection parameters set at instantiation are used. -""" +""" # noqa: E501 docdict["reject_epochs"] = f""" reject : dict | None diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index 51f8fa012f8..4883c6bce4c 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -338,7 +338,7 @@ # Sometimes it is useful to reject epochs based criteria other than # peak-to-peak amplitudes. For example, we might want to reject epochs # based on the maximum or minimum amplitude of a channel. -# In this case, the :class:`mne.Epochs` class constructor also accepts +# In this case, the `mne.Epochs.drop_bad` function also accepts # callables (functions) in the ``reject`` and ``flat`` parameters. This # allows us to define functions to reject epochs based on our desired criteria. # @@ -376,9 +376,10 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")), preload=True, ) + +epochs.drop_bad(reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp"))) epochs.plot(scalings=dict(eeg=50e-5)) # %% @@ -397,9 +398,10 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")), preload=True, ) + +epochs.drop_bad(reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp"))) epochs.plot(scalings=dict(eeg=50e-5)) # %% @@ -420,9 +422,10 @@ def reject_criteria(x): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=reject_criteria), preload=True, ) + +epochs.drop_bad(reject=dict(eeg=reject_criteria)) epochs.plot(events=True) # %%