Skip to content

Commit

Permalink
Remove support for callabes in constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
withmywoessner committed Jan 24, 2024
1 parent b5829b0 commit a11a410
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 44 deletions.
15 changes: 9 additions & 6 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None):
self.baseline = baseline
return self

def _reject_setup(self, reject, flat):
def _reject_setup(self, reject, flat, *, allow_callable=False):
"""Set self._reject_time and self._channel_type_idx."""
idx = channel_indices_by_type(self.info)
reject = deepcopy(reject) if reject is not None else dict()
Expand Down Expand Up @@ -818,9 +818,12 @@ def _reject_setup(self, reject, flat):
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
name = f"{kind} dict value for {key}"
if callable(val):
if callable(val) and allow_callable:
continue
_validate_type(val, "numeric", name, extra="or callable")
extra_str = ""
if allow_callable:
extra_str = "or callable"
_validate_type(val, "numeric", name, extra=extra_str)
if val is None or val < 0:
raise ValueError(
f"If using numerical {name} criteria, the value "
Expand All @@ -844,7 +847,7 @@ def _reject_setup(self, reject, flat):
# make sure new thresholds are at least as stringent as the old ones
for key in reject:
# Skip this check if old_reject and reject are callables
if callable(reject[key]):
if callable(reject[key]) and allow_callable:
continue
if key in old_reject and reject[key] > old_reject[key]:
raise ValueError(
Expand All @@ -861,7 +864,7 @@ def _reject_setup(self, reject, flat):
for key in set(old_flat) - set(flat):
flat[key] = old_flat[key]
for key in flat:
if callable(flat[key]):
if callable(flat[key]) and allow_callable:
continue
if key in old_flat and flat[key] < old_flat[key]:
raise ValueError(
Expand Down Expand Up @@ -1416,7 +1419,7 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None):
flat = self.flat
if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)):
raise ValueError('reject and flat, if strings, must be "existing"')
self._reject_setup(reject, flat)
self._reject_setup(reject, flat, allow_callable=True)
self._get_data(out=False, verbose=verbose)
return self

Expand Down
66 changes: 48 additions & 18 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,19 @@ def test_reject():
preload=False,
reject=dict(eeg=np.inf),
)

# Good function
def my_reject_1(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35)
reasons = "a" * len(bad_idxs[0])
return len(bad_idxs) > 0, reasons

# Bad function
def my_reject_2(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35)
reasons = "a" * len(bad_idxs[0])
return len(bad_idxs), reasons

for val in (-1, -2): # protect against older MNE-C types
for kwarg in ("reject", "flat"):
pytest.raises(
Expand All @@ -564,32 +577,42 @@ def test_reject():
**{kwarg: dict(grad=val)},
)

def my_reject_1(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35)
return len(bad_idxs) > 0

def my_reject_2(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35)
reasons = "a" * len(bad_idxs[0])
return len(bad_idxs), reasons
# Check that reject and flat in constructor are not callables
val = my_reject_1
for kwarg in ("reject", "flat"):
with pytest.raises(
TypeError,
match=r".* must be an instance of numeric, got <class 'function'> instead."
):
Epochs(
raw,
events,
event_id,
tmin,
tmax,
picks=picks_meg,
preload=False,
**{kwarg: dict(grad=val)},
)

bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1), None]
# Check if callable returns a tuple with reasons
bad_types = [my_reject_2, ("Hi" "Hi"), (1, 1), None]
for val in bad_types: # protect against bad types
for kwarg in ("reject", "flat"):
with pytest.raises(
TypeError,
match=r".* must be an instance of .* got <class '.*'> instead.",
):
Epochs(
epochs = Epochs(
raw,
events,
event_id,
tmin,
tmax,
picks=picks_meg,
preload=True,
**{kwarg: dict(grad=val)},
)
epochs.drop_bad(**{kwarg: dict(grad=val)})

pytest.raises(
KeyError,
Expand Down Expand Up @@ -2202,9 +2225,10 @@ def test_callable_reject():
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")),
preload=True,
)
epochs.drop_bad(reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")))

assert epochs.drop_log[2] == ("eeg median",)

epochs = mne.Epochs(
Expand All @@ -2213,9 +2237,10 @@ def test_callable_reject():
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",))),
preload=True,
)
epochs.drop_bad(reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",))))

assert epochs.drop_log[0] == ("eeg max",)

def reject_criteria(x):
Expand All @@ -2229,25 +2254,31 @@ def reject_criteria(x):
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=reject_criteria),
preload=True,
)
epochs.drop_bad(reject=dict(eeg=reject_criteria))

assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == (
"eeg max or median",
)

# Test reasons must be str or tuple of str
with pytest.raises(TypeError):
with pytest.raises(
TypeError,
match=r".* must be an instance of str, got <class 'int'> instead.",
):
epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(
reject=dict(
eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2))
),
preload=True,
)
)


Expand Down Expand Up @@ -3323,7 +3354,6 @@ def test_drop_epochs():
("a", "b"),
]


@pytest.mark.parametrize("preload", (True, False))
def test_drop_epochs_mult(preload):
"""Test that subselecting epochs or making fewer epochs is similar."""
Expand Down
56 changes: 40 additions & 16 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,14 +1442,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

_flat_common = """\
Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP)
or a custom function. Valid **keys** can be any channel type present
in the object. If using PTP, **values** are floats that set the minimum
acceptable PTP. If the PTP is smaller than this threshold, the epoch
will be dropped. If ``None`` then no rejection is performed based on
flatness of the signal. If a custom function is used than ``flat`` can be
used to reject epochs based on any criteria (including maxima and
minima)."""
Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP).
Valid **keys** can be any channel type present in the object. The
**values** are floats that set the minimum acceptable PTP. If the PTP
is smaller than this threshold, the epoch will be dropped. If ``None``
then no rejection is performed based on flatness of the signal."""

docdict["flat"] = f"""
flat : dict | None
Expand All @@ -1459,9 +1456,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
quality, pass the ``reject_tmin`` and ``reject_tmax`` parameters.
"""

docdict["flat_drop_bad"] = f"""
docdict["flat_drop_bad"] = """
flat : dict | str | None
{_flat_common}
Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP)
or a custom function. Valid **keys** can be any channel type present
in the object. If using PTP, **values** are floats that set the minimum
acceptable PTP. If the PTP is smaller than this threshold, the epoch
will be dropped. If ``None`` then no rejection is performed based on
flatness of the signal. If a custom function is used than ``flat`` can be
used to reject epochs based on any criteria (including maxima and
minima).
If ``'existing'``, then the flat parameters set during epoch creation are
used.
"""
Expand Down Expand Up @@ -3271,6 +3275,31 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
)

_reject_common = """\
Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP),
i.e. the absolute difference between the lowest and the highest signal
value. In each individual epoch, the PTP is calculated for every channel.
If the PTP of any one channel exceeds the rejection threshold, the
respective epoch will be dropped.
The dictionary keys correspond to the different channel types; valid
**keys** can be any channel type present in the object.
Example::
reject = dict(grad=4000e-13, # unit: T / m (gradiometers)
mag=4e-12, # unit: T (magnetometers)
eeg=40e-6, # unit: V (EEG channels)
eog=250e-6 # unit: V (EOG channels)
)
.. note:: Since rejection is based on a signal **difference**
calculated for each channel separately, applying baseline
correction does not affect the rejection procedure, as the
difference will be preserved.
"""

docdict["reject_drop_bad"] = """
reject : dict | str | None
Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP)
or custom functions. Peak-to-peak signal amplitude is defined as
the absolute difference between the lowest and the highest signal
Expand Down Expand Up @@ -3303,14 +3332,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
.. note:: If ``reject`` is a callable, than **any** criteria can be
used to reject epochs (including maxima and minima).
""" # noqa: E501

docdict["reject_drop_bad"] = f"""
reject : dict | str | None
{_reject_common}
If ``reject`` is ``None``, no rejection is performed. If ``'existing'``
(default), then the rejection parameters set at instantiation are used.
"""
""" # noqa: E501

docdict["reject_epochs"] = f"""
reject : dict | None
Expand Down
11 changes: 7 additions & 4 deletions tutorials/preprocessing/20_rejecting_bad_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@
# Sometimes it is useful to reject epochs based criteria other than
# peak-to-peak amplitudes. For example, we might want to reject epochs
# based on the maximum or minimum amplitude of a channel.
# In this case, the :class:`mne.Epochs` class constructor also accepts
# In this case, the `mne.Epochs.drop_bad` function also accepts
# callables (functions) in the ``reject`` and ``flat`` parameters. This
# allows us to define functions to reject epochs based on our desired criteria.
#
Expand Down Expand Up @@ -376,9 +376,10 @@
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")),
preload=True,
)

epochs.drop_bad(reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")))
epochs.plot(scalings=dict(eeg=50e-5))

# %%
Expand All @@ -397,9 +398,10 @@
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")),
preload=True,
)

epochs.drop_bad(reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")))
epochs.plot(scalings=dict(eeg=50e-5))

# %%
Expand All @@ -420,9 +422,10 @@ def reject_criteria(x):
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=reject_criteria),
preload=True,
)

epochs.drop_bad(reject=dict(eeg=reject_criteria))
epochs.plot(events=True)

# %%
Expand Down

0 comments on commit a11a410

Please sign in to comment.