Commit
agramfort committed Apr 4, 2024
1 parent 7743497 commit 5798065
Showing 200 changed files with 4,024 additions and 3,118 deletions.
2 changes: 1 addition & 1 deletion dev/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 88bb6b736447eca80ee243b79569c33d
config: 4d327f39961be823859eba10cb8d21d1
tags: 645f666f9bcd5a90fca523b33c5a78b7
Binary file not shown.
@@ -33,7 +33,7 @@
},
"outputs": [],
"source": [
"from numbers import Integral\nfrom braindecode.datasets import SleepPhysionet\n\nsubject_ids = [0, 1]\ndataset = SleepPhysionet(\n subject_ids=subject_ids, recording_ids=[2], crop_wake_mins=30)"
"from numbers import Integral\nfrom braindecode.datasets import SleepPhysionet\n\nsubject_ids = [0, 1]\ndataset = SleepPhysionet(subject_ids=subject_ids, recording_ids=[2], crop_wake_mins=30)"
]
},
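
As a side note on the cell above: a minimal sketch for inspecting the loaded recordings, assuming the SleepPhysionet download succeeded. Each entry of `.datasets` wraps an MNE `Raw` object, the same attribute the motor-imagery script further down relies on via `dataset.datasets[0].raw`.

for ds in dataset.datasets:
    raw = ds.raw  # MNE Raw object for one recording
    print(raw.info["sfreq"], raw.n_times, len(raw.ch_names))
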
{
@@ -51,7 +51,7 @@
},
"outputs": [],
"source": [
"from braindecode.preprocessing import preprocess, Preprocessor\nfrom numpy import multiply\n\nhigh_cut_hz = 30\nfactor = 1e6\n\npreprocessors = [\n Preprocessor(lambda data: multiply(data, factor), apply_on_array=True), # Convert from V to uV\n Preprocessor('filter', l_freq=None, h_freq=high_cut_hz)\n]\n\n# Transform the data\npreprocess(dataset, preprocessors)"
"from braindecode.preprocessing import preprocess, Preprocessor\nfrom numpy import multiply\n\nhigh_cut_hz = 30\nfactor = 1e6\n\npreprocessors = [\n Preprocessor(\n lambda data: multiply(data, factor), apply_on_array=True\n ), # Convert from V to uV\n Preprocessor(\"filter\", l_freq=None, h_freq=high_cut_hz),\n]\n\n# Transform the data\npreprocess(dataset, preprocessors)"
]
},
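
The two preprocessors above scale the signal from volts to microvolts and low-pass filter at 30 Hz. The scaling step on its own, as a toy sketch:

import numpy as np

# EEG amplitudes on the order of 1e-5 V become tens of microvolts
# after the 1e6 factor used in the cell above.
data_v = np.array([1e-5, -2e-5])
print(np.multiply(data_v, 1e6))  # [ 10. -20.]
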
{
@@ -69,7 +69,7 @@
},
"outputs": [],
"source": [
"from braindecode.preprocessing import create_windows_from_events\n\nmapping = { # We merge stages 3 and 4 following AASM standards.\n 'Sleep stage W': 0,\n 'Sleep stage 1': 1,\n 'Sleep stage 2': 2,\n 'Sleep stage 3': 3,\n 'Sleep stage 4': 3,\n 'Sleep stage R': 4\n}\n\nwindow_size_s = 30\nsfreq = 100\nwindow_size_samples = window_size_s * sfreq\n\nwindows_dataset = create_windows_from_events(\n dataset,\n trial_start_offset_samples=0,\n trial_stop_offset_samples=0,\n window_size_samples=window_size_samples,\n window_stride_samples=window_size_samples,\n preload=True,\n mapping=mapping\n)"
"from braindecode.preprocessing import create_windows_from_events\n\nmapping = { # We merge stages 3 and 4 following AASM standards.\n \"Sleep stage W\": 0,\n \"Sleep stage 1\": 1,\n \"Sleep stage 2\": 2,\n \"Sleep stage 3\": 3,\n \"Sleep stage 4\": 3,\n \"Sleep stage R\": 4,\n}\n\nwindow_size_s = 30\nsfreq = 100\nwindow_size_samples = window_size_s * sfreq\n\nwindows_dataset = create_windows_from_events(\n dataset,\n trial_start_offset_samples=0,\n trial_stop_offset_samples=0,\n window_size_samples=window_size_samples,\n window_stride_samples=window_size_samples,\n preload=True,\n mapping=mapping,\n)"
]
},
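
At 100 Hz, the 30 s epochs become 3000-sample windows, and a stride equal to the window size makes them non-overlapping. Recovering stage names from the integer labels is a one-liner, keeping in mind the mapping is many-to-one (stages 3 and 4 share class 3):

window_size_samples = 30 * 100  # 3000 samples per 30 s window at 100 Hz
inv_mapping = {v: k for k, v in mapping.items()}
print(inv_mapping[3])  # 'Sleep stage 4' (last name wins for the merged stages)
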
{
@@ -123,7 +123,7 @@
},
"outputs": [],
"source": [
"import numpy as np\nfrom braindecode.samplers import SequenceSampler\n\nn_windows = 3 # Sequences of 3 consecutive windows\nn_windows_stride = 3 # Maximally overlapping sequences\n\ntrain_sampler = SequenceSampler(\n train_set.get_metadata(), n_windows, n_windows_stride, randomize=True\n)\nvalid_sampler = SequenceSampler(valid_set.get_metadata(), n_windows, n_windows_stride)\n\n# Print number of examples per class\nprint('Training examples: ', len(train_sampler))\nprint('Validation examples: ', len(valid_sampler))"
"import numpy as np\nfrom braindecode.samplers import SequenceSampler\n\nn_windows = 3 # Sequences of 3 consecutive windows\nn_windows_stride = 3 # Maximally overlapping sequences\n\ntrain_sampler = SequenceSampler(\n train_set.get_metadata(), n_windows, n_windows_stride, randomize=True\n)\nvalid_sampler = SequenceSampler(valid_set.get_metadata(), n_windows, n_windows_stride)\n\n# Print number of examples per class\nprint(\"Training examples: \", len(train_sampler))\nprint(\"Validation examples: \", len(valid_sampler))"
]
},
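
Note that with `n_windows_stride` equal to `n_windows`, consecutive sequences do not overlap; a stride of 1 would give maximally overlapping sequences, despite what the inline comment above says. The indexing pattern, sketched without braindecode:

n_windows, stride, n_total = 3, 3, 12  # one recording of 12 windows
starts = range(0, n_total - n_windows + 1, stride)
print([list(range(s, s + n_windows)) for s in starts])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
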
{
@@ -159,7 +159,7 @@
},
"outputs": [],
"source": [
"from sklearn.utils import compute_class_weight\n\ny_train = [train_set[idx][1] for idx in train_sampler]\nclass_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)"
"from sklearn.utils import compute_class_weight\n\ny_train = [train_set[idx][1] for idx in train_sampler]\nclass_weights = compute_class_weight(\"balanced\", classes=np.unique(y_train), y=y_train)"
]
},
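
sklearn's "balanced" heuristic weighs each class by n_samples / (n_classes * bincount(y)), so under-represented sleep stages (typically N1) contribute proportionally more to the loss. A toy check:

import numpy as np
from sklearn.utils import compute_class_weight

y = np.array([0, 0, 0, 1])  # class 1 is three times rarer
print(compute_class_weight("balanced", classes=np.unique(y), y=y))
# [0.66666667 2.        ]
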
{
@@ -177,7 +177,7 @@
},
"outputs": [],
"source": [
"import torch\nfrom torch import nn\nfrom braindecode.util import set_random_seeds\nfrom braindecode.models import SleepStagerChambon2018, TimeDistributed\n\ncuda = torch.cuda.is_available() # check if GPU is available\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\nif cuda:\n torch.backends.cudnn.benchmark = True\n# Set random seed to be able to roughly reproduce results\n# Note that with cudnn benchmark set to True, GPU indeterminism\n# may still make results substantially different between runs.\n# To obtain more consistent results at the cost of increased computation time,\n# you can set `cudnn_benchmark=False` in `set_random_seeds`\n# or remove `torch.backends.cudnn.benchmark = True`\nset_random_seeds(seed=31, cuda=cuda)\n\nn_classes = 5\n# Extract number of channels and time steps from dataset\nn_channels, input_size_samples = train_set[0][0].shape\n\nfeat_extractor = SleepStagerChambon2018(\n n_channels,\n sfreq,\n n_outputs=n_classes,\n n_times=input_size_samples,\n return_feats=True\n)\n\nmodel = nn.Sequential(\n TimeDistributed(feat_extractor), # apply model on each 30-s window\n nn.Sequential( # apply linear layer on concatenated feature vectors\n nn.Flatten(start_dim=1),\n nn.Dropout(0.5),\n nn.Linear(feat_extractor.len_last_layer * n_windows, n_classes)\n )\n)\n\n# Send model to GPU\nif cuda:\n model.cuda()"
"import torch\nfrom torch import nn\nfrom braindecode.util import set_random_seeds\nfrom braindecode.models import SleepStagerChambon2018, TimeDistributed\n\ncuda = torch.cuda.is_available() # check if GPU is available\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nif cuda:\n torch.backends.cudnn.benchmark = True\n# Set random seed to be able to roughly reproduce results\n# Note that with cudnn benchmark set to True, GPU indeterminism\n# may still make results substantially different between runs.\n# To obtain more consistent results at the cost of increased computation time,\n# you can set `cudnn_benchmark=False` in `set_random_seeds`\n# or remove `torch.backends.cudnn.benchmark = True`\nset_random_seeds(seed=31, cuda=cuda)\n\nn_classes = 5\n# Extract number of channels and time steps from dataset\nn_channels, input_size_samples = train_set[0][0].shape\n\nfeat_extractor = SleepStagerChambon2018(\n n_channels,\n sfreq,\n n_outputs=n_classes,\n n_times=input_size_samples,\n return_feats=True,\n)\n\nmodel = nn.Sequential(\n TimeDistributed(feat_extractor), # apply model on each 30-s window\n nn.Sequential( # apply linear layer on concatenated feature vectors\n nn.Flatten(start_dim=1),\n nn.Dropout(0.5),\n nn.Linear(feat_extractor.len_last_layer * n_windows, n_classes),\n ),\n)\n\n# Send model to GPU\nif cuda:\n model.cuda()"
]
},
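
A quick shape check of the assembled sequence model, assuming (as in braindecode's TimeDistributed) that inputs are shaped (batch, n_windows, n_channels, n_times):

import torch

x = torch.randn(4, n_windows, n_channels, input_size_samples).to(device)
with torch.no_grad():
    print(model(x).shape)  # expected: torch.Size([4, 5]), one logit per stage
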
{
@@ -195,7 +195,7 @@
},
"outputs": [],
"source": [
"from skorch.helper import predefined_split\nfrom skorch.callbacks import EpochScoring\nfrom braindecode import EEGClassifier\n\nlr = 1e-3\nbatch_size = 32\nn_epochs = 10\n\ntrain_bal_acc = EpochScoring(\n scoring='balanced_accuracy', on_train=True, name='train_bal_acc',\n lower_is_better=False)\nvalid_bal_acc = EpochScoring(\n scoring='balanced_accuracy', on_train=False, name='valid_bal_acc',\n lower_is_better=False)\ncallbacks = [\n ('train_bal_acc', train_bal_acc),\n ('valid_bal_acc', valid_bal_acc)\n]\n\nclf = EEGClassifier(\n model,\n criterion=torch.nn.CrossEntropyLoss,\n criterion__weight=torch.Tensor(class_weights).to(device),\n optimizer=torch.optim.Adam,\n iterator_train__shuffle=False,\n iterator_train__sampler=train_sampler,\n iterator_valid__sampler=valid_sampler,\n train_split=predefined_split(valid_set), # using valid_set for validation\n optimizer__lr=lr,\n batch_size=batch_size,\n callbacks=callbacks,\n device=device,\n classes=np.unique(y_train),\n)\n# Model training for a specified number of epochs. `y` is None as it is already\n# supplied in the dataset.\nclf.fit(train_set, y=None, epochs=n_epochs)"
"from skorch.helper import predefined_split\nfrom skorch.callbacks import EpochScoring\nfrom braindecode import EEGClassifier\n\nlr = 1e-3\nbatch_size = 32\nn_epochs = 10\n\ntrain_bal_acc = EpochScoring(\n scoring=\"balanced_accuracy\",\n on_train=True,\n name=\"train_bal_acc\",\n lower_is_better=False,\n)\nvalid_bal_acc = EpochScoring(\n scoring=\"balanced_accuracy\",\n on_train=False,\n name=\"valid_bal_acc\",\n lower_is_better=False,\n)\ncallbacks = [(\"train_bal_acc\", train_bal_acc), (\"valid_bal_acc\", valid_bal_acc)]\n\nclf = EEGClassifier(\n model,\n criterion=torch.nn.CrossEntropyLoss,\n criterion__weight=torch.Tensor(class_weights).to(device),\n optimizer=torch.optim.Adam,\n iterator_train__shuffle=False,\n iterator_train__sampler=train_sampler,\n iterator_valid__sampler=valid_sampler,\n train_split=predefined_split(valid_set), # using valid_set for validation\n optimizer__lr=lr,\n batch_size=batch_size,\n callbacks=callbacks,\n device=device,\n classes=np.unique(y_train),\n)\n# Model training for a specified number of epochs. `y` is None as it is already\n# supplied in the dataset.\nclf.fit(train_set, y=None, epochs=n_epochs)"
]
},
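
After fitting, the skorch-style history keeps one record per epoch, including the two balanced-accuracy callbacks registered above. A sketch of reading off the final scores:

last = clf.history[-1]
print(last["epoch"], last["train_bal_acc"], last["valid_bal_acc"])
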
{
@@ -213,7 +213,7 @@
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\nimport pandas as pd\n\n# Extract loss and balanced accuracy values for plotting from history object\ndf = pd.DataFrame(clf.history.to_list())\ndf.index.name = \"Epoch\"\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 7), sharex=True)\ndf[['train_loss', 'valid_loss']].plot(color=['r', 'b'], ax=ax1)\ndf[['train_bal_acc', 'valid_bal_acc']].plot(color=['r', 'b'], ax=ax2)\nax1.set_ylabel('Loss')\nax2.set_ylabel('Balanced accuracy')\nax1.legend(['Train', 'Valid'])\nax2.legend(['Train', 'Valid'])\nfig.tight_layout()\nplt.show()"
"import matplotlib.pyplot as plt\nimport pandas as pd\n\n# Extract loss and balanced accuracy values for plotting from history object\ndf = pd.DataFrame(clf.history.to_list())\ndf.index.name = \"Epoch\"\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 7), sharex=True)\ndf[[\"train_loss\", \"valid_loss\"]].plot(color=[\"r\", \"b\"], ax=ax1)\ndf[[\"train_bal_acc\", \"valid_bal_acc\"]].plot(color=[\"r\", \"b\"], ax=ax2)\nax1.set_ylabel(\"Loss\")\nax2.set_ylabel(\"Balanced accuracy\")\nax1.legend([\"Train\", \"Valid\"])\nax2.legend([\"Train\", \"Valid\"])\nfig.tight_layout()\nplt.show()"
]
},
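
Because `clf.history.to_list()` yields one dict per epoch, every recorded quantity (losses, the callback scores, durations) becomes a DataFrame column, and the frame can be kept for cross-run comparison (file name illustrative):

print(df.columns.tolist())
df.to_csv("sleep_staging_history.csv")
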
{
@@ -231,7 +231,7 @@
},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix, classification_report\nfrom braindecode.visualization import plot_confusion_matrix\n\ny_true = [valid_set[[i]][1][0] for i in range(len(valid_sampler))]\ny_pred = clf.predict(valid_set)\n\nconfusion_mat = confusion_matrix(y_true, y_pred)\n\nplot_confusion_matrix(confusion_mat=confusion_mat,\n class_names=['Wake', 'N1', 'N2', 'N3', 'REM'])\n\nprint(classification_report(y_true, y_pred))"
"from sklearn.metrics import confusion_matrix, classification_report\nfrom braindecode.visualization import plot_confusion_matrix\n\ny_true = [valid_set[[i]][1][0] for i in range(len(valid_sampler))]\ny_pred = clf.predict(valid_set)\n\nconfusion_mat = confusion_matrix(y_true, y_pred)\n\nplot_confusion_matrix(\n confusion_mat=confusion_mat, class_names=[\"Wake\", \"N1\", \"N2\", \"N3\", \"REM\"]\n)\n\nprint(classification_report(y_true, y_pred))"
]
},
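
Balanced accuracy can be read directly off this confusion matrix as the mean of per-class recalls, which ties this cell back to the callbacks used during training:

import numpy as np

recalls = np.diag(confusion_mat) / confusion_mat.sum(axis=1)
print("balanced accuracy:", recalls.mean())
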
{
@@ -249,7 +249,7 @@
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n\nfig, ax = plt.subplots(figsize=(15, 5))\nax.plot(y_true, color='b', label='Expert annotations')\nax.plot(y_pred.flatten(), color='r', label='Predict annotations', alpha=0.5)\nax.set_xlabel('Time (epochs)')\nax.set_ylabel('Sleep stage')"
"import matplotlib.pyplot as plt\n\nfig, ax = plt.subplots(figsize=(15, 5))\nax.plot(y_true, color=\"b\", label=\"Expert annotations\")\nax.plot(y_pred.flatten(), color=\"r\", label=\"Predict annotations\", alpha=0.5)\nax.set_xlabel(\"Time (epochs)\")\nax.set_ylabel(\"Sleep stage\")"
]
},
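
The `label=` arguments in the cell above only show up once a legend is drawn; a minimal completion of the plotting cell:

ax.legend()
plt.show()
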
{
@@ -65,24 +65,29 @@

from numpy import multiply

from braindecode.preprocessing import (Preprocessor,
exponential_moving_standardize,
preprocess)
from braindecode.preprocessing import (
Preprocessor,
exponential_moving_standardize,
preprocess,
)

low_cut_hz = 4. # low cut frequency for filtering
high_cut_hz = 38. # high cut frequency for filtering
low_cut_hz = 4.0 # low cut frequency for filtering
high_cut_hz = 38.0 # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

preprocessors = [
Preprocessor('pick_types', eeg=True, meg=False, stim=False), # Keep EEG sensors
Preprocessor("pick_types", eeg=True, meg=False, stim=False), # Keep EEG sensors
Preprocessor(lambda data: multiply(data, factor)), # Convert from V to uV
Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz), # Bandpass filter
Preprocessor(exponential_moving_standardize, # Exponential moving standardization
factor_new=factor_new, init_block_size=init_block_size)
Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz), # Bandpass filter
Preprocessor(
exponential_moving_standardize, # Exponential moving standardization
factor_new=factor_new,
init_block_size=init_block_size,
),
]

# Transform the data
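
Exponential moving standardization standardizes each sample with running exponential estimates of mean and variance. A rough sketch of the recursion; braindecode's implementation additionally uses `init_block_size` to standardize the first block conventionally:

import numpy as np

factor_new = 1e-3
x = np.random.randn(1000)
mean, var, out = 0.0, 1.0, np.empty_like(x)
for t, sample in enumerate(x):
    mean = (1 - factor_new) * mean + factor_new * sample
    demeaned = sample - mean
    var = (1 - factor_new) * var + factor_new * demeaned**2
    out[t] = demeaned / np.sqrt(max(var, 1e-8))
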
@@ -107,8 +112,8 @@

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
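
With this dataset's nominal 250 Hz sampling rate (an assumption here; the script reads the actual value from the data), the -0.5 s offset becomes -125 samples, so each window starts half a second before the cue:

print(int(-0.5 * 250.0))  # -125
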

@@ -134,9 +139,9 @@
# ``T`` for training and ``test`` for validation.
#

splitted = windows_dataset.split('session')
train_set = splitted['0train'] # Session train
valid_set = splitted['1test'] # Session evaluation
splitted = windows_dataset.split("session")
train_set = splitted["0train"] # Session train
valid_set = splitted["1test"] # Session evaluation
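
`split("session")` returns a plain dict keyed by session name, so the available keys can be checked before indexing:

print(sorted(splitted))  # expected here: ['0train', '1test']
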


######################################################################
@@ -160,7 +165,7 @@
from braindecode.util import set_random_seeds

cuda = torch.cuda.is_available() # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'
device = "cuda" if cuda else "cpu"
if cuda:
torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
@@ -182,7 +187,7 @@
n_chans,
n_classes,
input_window_samples=input_window_samples,
final_conv_length='auto',
final_conv_length="auto",
)

# Display torchinfo table describing the model
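
A sketch of the torchinfo call the comment refers to, assuming torchinfo is installed and using a batch dimension of 1:

from torchinfo import summary

summary(model, input_size=(1, n_chans, input_window_samples))
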
@@ -240,7 +245,8 @@
optimizer__weight_decay=weight_decay,
batch_size=batch_size,
callbacks=[
"accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
"accuracy",
("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
],
device=device,
classes=classes,
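
CosineAnnealingLR with T_max = n_epochs - 1 decays the learning rate from its initial value to roughly zero over the run, following lr_t = lr * (1 + cos(pi * t / T_max)) / 2. A sketch with illustrative values:

import math

lr, n_epochs = 1e-3, 4  # illustrative, not the script's settings
for t in range(n_epochs):
    print(t, lr * (1 + math.cos(math.pi * t / (n_epochs - 1))) / 2)
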
@@ -266,34 +272,45 @@
from matplotlib.lines import Line2D

# Extract loss and accuracy values for plotting from history object
results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy']
df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns,
index=clf.history[:, 'epoch'])
results_columns = ["train_loss", "valid_loss", "train_accuracy", "valid_accuracy"]
df = pd.DataFrame(
clf.history[:, results_columns],
columns=results_columns,
index=clf.history[:, "epoch"],
)

# get percent of misclass for better visual comparison to loss
df = df.assign(train_misclass=100 - 100 * df.train_accuracy,
valid_misclass=100 - 100 * df.valid_accuracy)
df = df.assign(
train_misclass=100 - 100 * df.train_accuracy,
valid_misclass=100 - 100 * df.valid_accuracy,
)

fig, ax1 = plt.subplots(figsize=(8, 3))
df.loc[:, ['train_loss', 'valid_loss']].plot(
ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False, fontsize=14)
df.loc[:, ["train_loss", "valid_loss"]].plot(
ax=ax1, style=["-", ":"], marker="o", color="tab:blue", legend=False, fontsize=14
)

ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)
ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=14)
ax1.set_ylabel("Loss", color="tab:blue", fontsize=14)

ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis

df.loc[:, ['train_misclass', 'valid_misclass']].plot(
ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False)
ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
df.loc[:, ["train_misclass", "valid_misclass"]].plot(
ax=ax2, style=["-", ":"], marker="o", color="tab:red", legend=False
)
ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=14)
ax2.set_ylabel("Misclassification Rate [%]", color="tab:red", fontsize=14)
ax2.set_ylim(ax2.get_ylim()[0], 85) # make some room for legend
ax1.set_xlabel("Epoch", fontsize=14)

# where some data has already been plotted to ax
handles = []
handles.append(Line2D([0], [0], color='black', linewidth=1, linestyle='-', label='Train'))
handles.append(Line2D([0], [0], color='black', linewidth=1, linestyle=':', label='Valid'))
handles.append(
Line2D([0], [0], color="black", linewidth=1, linestyle="-", label="Train")
)
handles.append(
Line2D([0], [0], color="black", linewidth=1, linestyle=":", label="Valid")
)
plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
plt.tight_layout()
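
Persisting the figure alongside the run is a common follow-up (path illustrative):

fig.savefig("training_curves.png", dpi=150, bbox_inches="tight")
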

@@ -323,7 +340,7 @@

# add class labels
# label_dict is class_name : str -> i_class : int
label_dict = windows_dataset.datasets[0].window_kwargs[0][1]['mapping']
label_dict = windows_dataset.datasets[0].window_kwargs[0][1]["mapping"]
# sort the labels by values (values are integer class labels)
labels = [k for k, v in sorted(label_dict.items(), key=lambda kv: kv[1])]
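
Sorting the name-to-index mapping by its integer values yields display labels in class order. With the four motor-imagery classes of BCIC IV 2a, the result would look like this (mapping contents illustrative):

label_dict = {"left_hand": 0, "right_hand": 1, "feet": 2, "tongue": 3}
labels = [k for k, v in sorted(label_dict.items(), key=lambda kv: kv[1])]
print(labels)  # ['left_hand', 'right_hand', 'feet', 'tongue']
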

