-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #89 from wwu-mmll/develop
Develop
- Loading branch information
Showing
28 changed files
with
503 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
examples/advanced/connectome_based_predictive_modeling_example.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
Connectome-based predictive modeling | ||
CPM is a method described in the following Nature Protocols article: https://www.nature.com/articles/nprot.2016.178 | ||
It has been used in a number of publications to predict behavior from connectivity data. | ||
CPM works similar to a feature selection method. First, relevant edges (connectivity values) are identified through | ||
correlation analysis. Every edge is correlated with the predictive target. Only significant edges will be used in the | ||
subsequent steps. Next, the edge values for all significant positive and for all significant negative correlations are | ||
summed to create two new features. Lastly, these two features are used as input to another classifier. | ||
In this example, no connectivity data is used, but the method will still work. | ||
This example is just supposed to show how to use CPM as feature selection and integration tool in PHOTONAI. | ||
""" | ||
|
||
from sklearn.datasets import load_breast_cancer | ||
from sklearn.model_selection import KFold | ||
|
||
from photonai import Hyperpipe, PipelineElement | ||
|
||
|
||
X, y = load_breast_cancer(return_X_y=True) | ||
|
||
pipe = Hyperpipe("cpm_feature_selection_pipe", | ||
outer_cv=KFold(n_splits=5, shuffle=True, random_state=15), | ||
inner_cv=KFold(n_splits=5, shuffle=True, random_state=15), | ||
metrics=["balanced_accuracy"], best_config_metric="balanced_accuracy", | ||
project_folder='./tmp') | ||
|
||
pipe += PipelineElement('CPMFeatureSelection', hyperparameters={'corr_method': ['pearson', 'spearman'], | ||
'p_threshold': [0.01, 0.05]}) | ||
|
||
pipe += PipelineElement('LogisticRegression') | ||
|
||
pipe.fit(X, y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# pip install gpboost -U | ||
from sklearn.base import BaseEstimator, ClassifierMixin | ||
from sklearn.model_selection import GroupKFold, KFold | ||
from photonai.base import Hyperpipe, PipelineElement | ||
import numpy as np | ||
import pandas as pd | ||
import gpboost as gpb | ||
# from gpboost import GPBoostRegressor | ||
|
||
|
||
class GPBoostDataWrapper(BaseEstimator, ClassifierMixin): | ||
|
||
def __init__(self): | ||
self.needs_covariates = True | ||
# self.gpmodel = gpb.GPModel(likelihood="gaussian") | ||
self.gpboost = None | ||
|
||
|
||
def fit(self, X, y, **kwargs): | ||
self.gpboost = gpb.GPBoostRegressor() | ||
if "clusters" in kwargs: | ||
clst = pd.Series(kwargs["clusters"]) | ||
gpmodel = gpb.GPModel(likelihood="gaussian", group_data=clst) | ||
self.gpboost.fit(X, y, gp_model=gpmodel) | ||
else: | ||
raise NotImplementedError("GPBoost needs clusters") | ||
return self | ||
|
||
def predict(self, X, **kwargs): | ||
clst = pd.Series(kwargs["clusters"]) | ||
preds = self.gpboost.predict(X, group_data_pred=clst) | ||
preds = preds["response_mean"] | ||
return preds | ||
|
||
def save(self): | ||
return None | ||
|
||
|
||
def get_gpboost_pipe(pipe_name, project_folder, split="group"): | ||
|
||
if split == "group": | ||
outercv = GroupKFold(n_splits=10) | ||
else: | ||
outercv = KFold(n_splits=10) | ||
|
||
my_pipe = Hyperpipe(pipe_name, | ||
optimizer='grid_search', | ||
metrics=['mean_absolute_error', 'mean_squared_error', | ||
'spearman_correlation', 'pearson_correlation'], | ||
best_config_metric='mean_absolute_error', | ||
outer_cv=outercv, | ||
inner_cv=KFold(n_splits=10), | ||
calculate_metrics_across_folds=True, | ||
use_test_set=True, | ||
verbosity=1, | ||
project_folder=project_folder) | ||
|
||
# Add transformer elements | ||
my_pipe += PipelineElement("StandardScaler", hyperparameters={}, | ||
test_disabled=True, with_mean=True, with_std=True) | ||
|
||
my_pipe += PipelineElement.create("GPBoost", GPBoostDataWrapper(), hyperparameters={}) | ||
|
||
return my_pipe | ||
|
||
|
||
def get_mock_data(): | ||
|
||
X = np.random.randint(10, size=(200, 9)) | ||
y = np.sum(X, axis=1) | ||
clst = np.random.randint(10, size=200) | ||
|
||
return X, y, clst | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
|
||
X, y, clst = get_mock_data() | ||
|
||
# define project folder | ||
project_folder = "./tmp/gpboost_debug" | ||
|
||
my_pipe = get_gpboost_pipe("Test_gpboost", project_folder, split="random") | ||
my_pipe.fit(X, y, clusters=clst) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.