Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating K-Subjects out evaluations #470

Merged
merged 14 commits into from
Jan 5, 2024
2 changes: 2 additions & 0 deletions moabb/evaluations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
return_epochs=False,
return_raws=False,
mne_labels=False,
n_splits=None,
save_model=False,
cache_config=None,
):
Expand All @@ -77,6 +78,7 @@ def __init__(
self.return_epochs = return_epochs
self.return_raws = return_raws
self.mne_labels = mne_labels
self.n_splits = n_splits
self.save_model = save_model
self.cache_config = cache_config
# check paradigm
Expand Down
11 changes: 10 additions & 1 deletion moabb/evaluations/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sklearn.base import clone
from sklearn.metrics import get_scorer
from sklearn.model_selection import (
GroupKFold,
LeaveOneGroupOut,
StratifiedKFold,
StratifiedShuffleSplit,
Expand Down Expand Up @@ -632,6 +633,9 @@ class CrossSubjectEvaluation(BaseEvaluation):
use MNE raw to train pipelines.
mne_labels: bool, default=False
if returning MNE epoch, use original dataset label if True
n_splits: int, default=None
Number of splits for cross-validation. If None, the number of splits
is equal to the number of subjects.
"""

# flake8: noqa: C901
Expand Down Expand Up @@ -675,7 +679,12 @@ def evaluate(
scorer = get_scorer(self.paradigm.scoring)

# perform leave one subject out CV
cv = LeaveOneGroupOut()
if self.n_splits is None:
cv = LeaveOneGroupOut()
else:
cv = GroupKFold(n_splits=self.n_splits)
n_subjects = self.n_splits

inner_cv = StratifiedKFold(3, shuffle=True, random_state=self.random_state)

# Implement Grid Search
Expand Down
Loading