From 33861c5a06538f97ca033770ba780a918f124c8c Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 25 Aug 2023 15:26:12 +0200 Subject: [PATCH 1/7] Creating K-Subjects out --- moabb/evaluations/base.py | 2 ++ moabb/evaluations/evaluations.py | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index c78db516a..2e3c822fa 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -64,6 +64,7 @@ def __init__( return_epochs=False, return_raws=False, mne_labels=False, + n_splits=None, ): self.random_state = random_state self.n_jobs = n_jobs @@ -73,6 +74,7 @@ def __init__( self.return_epochs = return_epochs self.return_raws = return_raws self.mne_labels = mne_labels + self.n_splits = n_splits # check paradigm if not isinstance(paradigm, BaseParadigm): raise (ValueError("paradigm must be an Paradigm instance")) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index ed13be16b..e6a20de1d 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -12,6 +12,7 @@ from sklearn.metrics import get_scorer from sklearn.model_selection import ( GridSearchCV, + GroupKFold, LeaveOneGroupOut, StratifiedKFold, StratifiedShuffleSplit, @@ -715,6 +716,9 @@ class CrossSubjectEvaluation(BaseEvaluation): use MNE raw to train pipelines. mne_labels: bool, default=False if returning MNE epoch, use original dataset label if True + n_splits: int, default=None + Number of splits for cross-validation. If None, the number of splits + is equal to the number of subjects. """ def _grid_search(self, param_grid, name_grid, name, clf, pipelines, X, y, cv, groups): @@ -786,7 +790,11 @@ def evaluate(self, dataset, pipelines, param_grid, process_pipeline): scorer = get_scorer(self.paradigm.scoring) # perform leave one subject out CV - cv = LeaveOneGroupOut() + if self.n_splits is None: + cv = LeaveOneGroupOut() + # cv = GroupKFold(n_splits=n_subjects) + else: + cv = GroupKFold(n_splits=self.n_splits) # Implement Grid Search emissions_grid = {} From ba893b1cba7d4ef63e84c6f7107fb2b179370eeb Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 28 Aug 2023 18:07:39 +0200 Subject: [PATCH 2/7] Fixing the tqdm --- moabb/evaluations/evaluations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index e6a20de1d..e5d79d711 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -795,7 +795,7 @@ def evaluate(self, dataset, pipelines, param_grid, process_pipeline): # cv = GroupKFold(n_splits=n_subjects) else: cv = GroupKFold(n_splits=self.n_splits) - + n_subjects = self.n_splits # Implement Grid Search emissions_grid = {} From a63dca29cff992e39f1c64e059b34c2961355b81 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 12:52:31 +0000 Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks --- moabb/evaluations/evaluations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index 064d4125e..c09f54e73 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -8,7 +8,6 @@ from sklearn.base import clone from sklearn.metrics import get_scorer from sklearn.model_selection import ( - GridSearchCV, GroupKFold, LeaveOneGroupOut, StratifiedKFold, From 21be5abd5e8a1b0dd5e14555f3b29431499583cf Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 27 Oct 2023 17:36:32 +0200 Subject: [PATCH 4/7] Removing code not use --- moabb/evaluations/evaluations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index 08fc9973e..26a6dedbe 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -678,10 +678,10 @@ def evaluate( # perform leave one subject out CV if self.n_splits is None: cv = LeaveOneGroupOut() - # cv = GroupKFold(n_splits=n_subjects) else: cv = GroupKFold(n_splits=self.n_splits) n_subjects = self.n_splits + inner_cv = StratifiedKFold(3, shuffle=True, random_state=self.random_state) # Implement Grid Search From 04c990b6a8a8319eab1aabe1d5d7f38ea9c4cf56 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 5 Jan 2024 10:48:17 +0100 Subject: [PATCH 5/7] Updating whats_new.rst --- docs/source/whats_new.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index b361f18c5..fd6430de9 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -20,6 +20,7 @@ Enhancements - Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_) - Option to interpolate channel in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_) +- Adding leave k-Subjects out evaluations (gh:`470` by `Bruno Aristimunha`_) Bugs ~~~~ @@ -244,7 +245,7 @@ Version - 0.4.2 Enhancements ~~~~~~~~~~~~ -- None +- Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_) Bugs ~~~~ From 9f780c6ea21489305248c9ed989fee67ffc90489 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 5 Jan 2024 10:50:08 +0100 Subject: [PATCH 6/7] Fixing whats new --- docs/source/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index fd6430de9..1f2b10240 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -245,7 +245,7 @@ Version - 0.4.2 Enhancements ~~~~~~~~~~~~ -- Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_) +- None Bugs ~~~~ From d6249584d3abb6d360e5f0860040dda5b2b046e4 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 5 Jan 2024 10:50:38 +0100 Subject: [PATCH 7/7] fixing again --- docs/source/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 1f2b10240..6a322032b 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -20,7 +20,7 @@ Enhancements - Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_) - Option to interpolate channel in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_) -- Adding leave k-Subjects out evaluations (gh:`470` by `Bruno Aristimunha`_) +- Adding leave k-Subjects out evaluations (:gh:`470` by `Bruno Aristimunha`_) Bugs ~~~~