diff --git a/.travis.yml b/.travis.yml
index 43400cb..1f9c0d4 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,6 +2,7 @@ language: python
cache: pip
jobs:
allow_failures:
+ - os: windows
- os: osx
include:
- name: "Python 3.6.0 on Xenial Linux"
diff --git a/CHANGELOG.md b/CHANGELOG.md
index cfbe08d..de57125 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
+## [0.10.0] - 2020-09-28
+### Added
+* Added a new dataset versioning system:
+  multiple versions of the same dataset
+  can be used during creation
+* Added a new in-Python download system:
+  a specific version of the dataset
+  can be downloaded during initialization
+
+### Changed
+* Refactored the dataset indexing system:
+  a dedicated directory is created
+  for the index database during initialization
+* Refactored the preprocessing cache system:
+  a dedicated directory is created
+  and a faster cache coherence check
+  has been implemented
+* Refactored the logging system
+* Updated package requirements
+* Updated the README
+
+### Removed
+* Deprecated Makefile
+
+
## [0.9.3] - 2020-07-05
### Change
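The new versioning and download workflow described in the 0.10.0 entries can be sketched as follows; the `version` keyword is an assumption based on the `Dataset` base-class signature introduced later in this diff, so check the concrete dataset classes for the exact parameter name:

    from pyeeglab import *

    # Hypothetical sketch: pick a dataset version at creation time,
    # then download and index it from Python (version keyword assumed).
    dataset = TUHEEGAbnormalDataset('data/tuh_eeg_abnormal/', version='v2.0.0')
    dataset.download(user='USER', password='PASSWORD')
    dataset.index()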
diff --git a/Makefile b/Makefile
deleted file mode 100755
index e9c93da..0000000
--- a/Makefile
+++ /dev/null
@@ -1,21 +0,0 @@
-.PHONY: tuh_eeg_abnormal tuh_eeg_artifact eegmmidb clean
-
-tuh_eeg_abnormal:
- echo "Request your access password at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php"
- rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_abnormal/ data/tuh_eeg_abnormal
-
-tuh_eeg_artifact:
- echo "Request your access password at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php"
- rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_artifact/ data/tuh_eeg_artifact
-
-tuh_eeg_seizure:
- echo "Request your access password at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php"
- rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_seizure/v1.5.2 data/tuh_eeg_seizure
-
-eegmmidb:
- wget -r -N -c -np https://physionet.org/files/eegmmidb/1.0.0/ -P data
-
-clean:
- find -L data -iname "*.db" -type f -delete
- find -L data -iname "*.log" -type f -delete
- find -L data -iname "*.fif.gz" -type f -delete
diff --git a/README.md b/README.md
index 21d59e3..205ed9f 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ Here is a simple quickstart:
from pyeeglab import *
dataset = TUHEEGAbnormalDataset()
- pipeline = Pipeline([
+ preprocessing = Pipeline([
CommonChannelSet(),
LowestFrequency(),
ToDataframe(),
@@ -35,10 +35,6 @@ PyEEGLab is distributed using the pip repository:
pip install PyEEGLab
-If you use Python 3.6, the dataclasses package must be installed as backport of Python 3.7 dataclasses:
-
- pip install dataclasses
-
If you need a bleeding edge version, you can install it directly from GitHub:
pip install git+https://github.com/AlessioZanga/PyEEGLab@develop
@@ -57,6 +53,8 @@ The following datasets will work upon downloading:
## Class Meaning - From the TUH Seizure docs
+
+
| **Class Code** | **Event Name** | **Description** |
| -------------- | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------ |
| _NULL_ | No Event | An unclassified event |
@@ -91,15 +89,24 @@ The following datasets will work upon downloading:
| _TRIP_ | Triphasic Wave | Large, three-phase waves frequently caused by an underlying metabolic condition. |
| _ELEC_ | Electrode Artifact | Electrode pop, Electrostatic artifacts, Lead artifacts. |
+
+
## How to Get a Dataset
> **WARNING**: Retrieving the TUH EEG datasets requires valid credentials; you can get your own at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php.
-In the root directory of this project there is a Makefile, by typing:
+Given the dataset instance, trigger the download using the "download" method:
+
+ from pyeeglab import *
+ dataset = TUHEEGAbnormalDataset()
+ dataset.download(user='USER', password='PASSWORD')
+ dataset.index()
+
+then index the newly downloaded files.
- make tuh_eeg_abnormal
+Note that the download mechanism works only on Unix-like systems and requires the following packages to be installed:
-you will trigger the dataset download.
+ sudo apt install sshpass rsync wget
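As a hedged illustration (not PyEEGLab's actual implementation), the removed Makefile target for the TUH EEG Abnormal corpus could be reproduced from Python roughly like this:

    import subprocess

    # Mirrors the old "make tuh_eeg_abnormal" target; the remote layout
    # is taken from the removed Makefile, not from the PyEEGLab API.
    subprocess.run([
        "sshpass", "-p", "PASSWORD",
        "rsync", "-auxvL",
        "nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_abnormal/",
        "data/tuh_eeg_abnormal",
    ], check=True)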
## Documentation
diff --git a/examples/tensorboard/example_tensorboard.py b/examples/tensorboard/example_tensorboard.py
index 448ac8e..de8ae71 100755
--- a/examples/tensorboard/example_tensorboard.py
+++ b/examples/tensorboard/example_tensorboard.py
@@ -26,7 +26,6 @@
import numpy as np
from random import shuffle
from itertools import product
-from networkx import to_numpy_matrix
from sklearn.model_selection import train_test_split
from tensorflow.python.keras.utils.np_utils import to_categorical
@@ -38,8 +37,6 @@
from pyeeglab import *
def build_data(dataset):
- dataset.set_cache_manager(PickleCache('../../export'))
-
preprocessing = Pipeline([
CommonChannelSet(),
LowestFrequency(),
@@ -96,7 +93,7 @@ def get_correlation_matrix(x, frame, N, F):
input_0 = tf.keras.Input((frames, N, F + N))
- gans = []
+ layers = []
for frame in range(frames):
feature_matrix = tf.keras.layers.Lambda(
get_feature_matrix,
@@ -110,9 +107,9 @@ def get_correlation_matrix(x, frame, N, F):
x = sp.layers.GraphAttention(hparams['output_shape'])([feature_matrix, correlation_matrix])
x = tf.keras.layers.Flatten()(x)
- gans.append(x)
+ layers.append(x)
- combine = tf.keras.layers.Concatenate()(gans)
+ combine = tf.keras.layers.Concatenate()(layers)
reshape = tf.keras.layers.Reshape((frames, N * hparams['output_shape']))(combine)
lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape)
dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm)
@@ -125,20 +122,25 @@ def get_correlation_matrix(x, frame, N, F):
metrics=[
'accuracy',
Recall(class_id=0, name='recall'),
+ Specificity(class_id=0, name='specificity'),
Precision(class_id=0, name='precision'),
+ F1Score(class_id=0, name='f1score'),
]
)
model.summary()
+ model.save('logs/plot_gat.h5')
return model
def run_trial(path, step, model, hparams, x_train, y_train, x_val, y_val, x_test, y_test, epochs):
with tf.summary.create_file_writer(path).as_default():
hp.hparams(hparams)
model.fit(x_train, y_train, epochs=epochs, batch_size=32, shuffle=True, validation_data=(x_val, y_val))
- loss, accuracy, recall, precision = model.evaluate(x_test, y_test)
+ loss, accuracy, recall, specificity, precision, f1score = model.evaluate(x_test, y_test)
tf.summary.scalar('accuracy', accuracy, step=step)
tf.summary.scalar('recall', recall, step=step)
+ tf.summary.scalar('specificity', specificity, step=step)
tf.summary.scalar('precision', precision, step=step)
+ tf.summary.scalar('f1score', f1score, step=step)
def hparams_combinations(hparams):
hp.hparams_config(
@@ -146,7 +148,9 @@ def hparams_combinations(hparams):
metrics=[
hp.Metric('accuracy', display_name='Accuracy'),
hp.Metric('recall', display_name='Recall'),
+ hp.Metric('specificity', display_name='Specificity'),
hp.Metric('precision', display_name='Precision'),
+ hp.Metric('f1score', display_name='F1Score'),
]
)
hparams_keys = list(hparams.keys())
@@ -162,7 +166,7 @@ def hparams_combinations(hparams):
return hparams
def tune_model(dataset_name, data):
- LOGS_DIR = join('./logs/generic', dataset_name)
+ LOGS_DIR = join('./logs/gat', dataset_name)
os.makedirs(LOGS_DIR, exist_ok=True)
# Prepare the data
x_train, y_train, x_val, y_val, x_test, y_test = adapt_data(data)
@@ -206,18 +210,18 @@ def tune_model(dataset_name, data):
if __name__ == '__main__':
dataset = {}
- dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
+ dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
- dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf')
+ dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/')
dataset['tuh_eeg_artifact'].set_minimum_event_duration(4)
- dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf')
+ dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/')
dataset['tuh_eeg_seizure'].set_minimum_event_duration(4)
- # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0')
+ # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/')
# dataset['eegmmidb'].set_minimum_event_duration(4)
- dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0')
+ dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/')
dataset['chbmit'].set_minimum_event_duration(4)
"""
diff --git a/examples/tensorboard/example_tensorboard_cnn.py b/examples/tensorboard/example_tensorboard_cnn.py
new file mode 100755
index 0000000..b345afc
--- /dev/null
+++ b/examples/tensorboard/example_tensorboard_cnn.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python
+
+# Ignore MNE and TensorFlow warnings
+import warnings
+warnings.simplefilter(action='ignore')
+
+# Import TensorFlow with GPU memory settings
+import tensorflow as tf
+gpus = tf.config.experimental.list_physical_devices('GPU')
+try:
+ for gpu in gpus:
+ tf.config.experimental.set_memory_growth(gpu, True)
+except RuntimeError as e:
+ print(e)
+
+# Import TensorBoard params and metrics
+from tensorboard.plugins.hparams import api as hp
+from tensorflow.keras.metrics import CategoricalAccuracy, Precision, Recall
+
+# Import Spektral for GraphAttention
+import spektral as sp
+
+# Others imports
+import os
+import pickle
+import numpy as np
+from random import shuffle
+from itertools import product
+from scipy.sparse import csc_matrix
+from spektral.layers.ops import sp_matrix_to_sp_tensor
+from sklearn.model_selection import train_test_split
+from tensorflow.python.keras.utils.np_utils import to_categorical
+
+# Relative import PyEEGLab
+import sys
+from os.path import abspath, dirname, join
+
+sys.path.insert(0, abspath(join(dirname(__file__), '../..')))
+from pyeeglab import *
+
+def build_data(dataset):
+ preprocessing = Pipeline([
+ CommonChannelSet(),
+ LowestFrequency(),
+ ToDataframe(),
+ MinMaxCentralizedNormalization(),
+ DynamicWindow(8),
+ ForkedPreprocessor(
+ inputs=[
+ SpearmanCorrelation(),
+ Mean(),
+ Variance(),
+ Skewness(),
+ Kurtosis(),
+ ZeroCrossing(),
+ AbsoluteArea(),
+ PeakToPeak(),
+ Bandpower(['Delta', 'Theta', 'Alpha', 'Beta'])
+ ],
+ output=ToMergedDataframes()
+ ),
+ ToNumpy()
+ ])
+
+ return dataset.set_pipeline(preprocessing).load()
+
+def adapt_data(data, test_size=0.1, shuffle=True):
+ if isinstance(data, str):
+ with open(data, 'rb') as f:
+ data = pickle.load(f)
+ samples, labels = data['data'], data['labels']
+ x_train, x_test, y_train, y_test = train_test_split(samples, labels, test_size=test_size, shuffle=shuffle, stratify=labels)
+ x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=test_size, shuffle=shuffle, stratify=y_train)
+ classes = np.sort(np.unique(labels))
+ y_train = to_categorical(y_train, num_classes=len(classes))
+ y_test = to_categorical(y_test, num_classes=len(classes))
+ y_val = to_categorical(y_val, num_classes=len(classes))
+ return x_train, y_train, x_val, y_val, x_test, y_test
+
+def build_model(shape, classes, hparams):
+ print(hparams)
+ N = shape[2]
+ F = shape[3] - N
+ frames = shape[1]
+
+ def get_frame(x, frame, N, F):
+ x = tf.slice(x, [0, frame, 0, 0], [-1, 1, N, F])
+ x = tf.squeeze(x, axis=[1])
+ return x
+
+ input_0 = tf.keras.Input((frames, N, F + N))
+
+ layers = []
+ for frame in range(frames):
+ frame_matrix = tf.keras.layers.Lambda(
+ get_frame,
+ arguments={'frame': frame, 'N': N, 'F': F}
+ )(input_0)
+
+ x = tf.keras.layers.Conv1D(hparams['filters'], hparams['kernel'], data_format='channels_first')(frame_matrix)
+ x = tf.keras.layers.MaxPooling1D(hparams['pool'])(x)
+ x = tf.keras.layers.Flatten()(x)
+ layers.append(x)
+
+ combine = tf.keras.layers.Concatenate()(layers)
+ reshape = tf.keras.layers.Reshape((-1, frames))(combine)
+ lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape)
+ dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm)
+ out = tf.keras.layers.Dense(classes, activation='softmax')(dropout)
+
+ model = tf.keras.Model(inputs=[input_0], outputs=out)
+ model.compile(
+ optimizer=tf.keras.optimizers.Adam(learning_rate=hparams['learning_rate']),
+ loss='categorical_crossentropy',
+ metrics=[
+ 'accuracy',
+ Recall(class_id=0, name='recall'),
+ Precision(class_id=0, name='precision'),
+ ]
+ )
+ model.summary()
+ model.save('logs/plot_cnn.h5')
+ return model
+
+def run_trial(path, step, model, hparams, x_train, y_train, x_val, y_val, x_test, y_test, epochs):
+ with tf.summary.create_file_writer(path).as_default():
+ hp.hparams(hparams)
+ model.fit(x_train, y_train, epochs=epochs, batch_size=32, shuffle=True, validation_data=(x_val, y_val))
+ loss, accuracy, recall, precision = model.evaluate(x_test, y_test)
+ tf.summary.scalar('accuracy', accuracy, step=step)
+ tf.summary.scalar('recall', recall, step=step)
+ tf.summary.scalar('precision', precision, step=step)
+
+def hparams_combinations(hparams):
+ hp.hparams_config(
+ hparams=list(hparams.values()),
+ metrics=[
+ hp.Metric('accuracy', display_name='Accuracy'),
+ hp.Metric('recall', display_name='Recall'),
+ hp.Metric('precision', display_name='Precision'),
+ ]
+ )
+ hparams_keys = list(hparams.keys())
+ hparams_values = list(product(*[
+ h.domain.values
+ for h in hparams.values()
+ ]))
+ hparams = [
+ dict(zip(hparams_keys, values))
+ for values in hparams_values
+ ]
+ shuffle(hparams)
+ return hparams
+
+def tune_model(dataset_name, data):
+ LOGS_DIR = join('./logs/cnn', dataset_name)
+ os.makedirs(LOGS_DIR, exist_ok=True)
+ # Prepare the data
+ x_train, y_train, x_val, y_val, x_test, y_test = adapt_data(data)
+ # Set tuning session
+ counter = 0
+ # Parameters to be tuned
+ hparams = {
+ 'learning_rate': [1e-4],
+ 'hidden_units': [8, 16, 32, 64],
+ 'filters': [8, 16, 32],
+ 'kernel': [3, 5, 7],
+ 'pool': [2, 3],
+ 'dropout': [0.00, 0.05, 0.10, 0.15, 0.20],
+ }
+ hparams = {
+ key: hp.HParam(key, hp.Discrete(value))
+ for key, value in hparams.items()
+ }
+ hparams = hparams_combinations(hparams)
+ for hparam in hparams:
+ # Build the model
+ model = build_model(data['data'].shape, len(data['labels_encoder']), hparam)
+ # Run session
+ run_name = f'run-{counter}'
+ print(f'--- Starting trial: {run_name}')
+ print(hparam)
+ run_trial(
+ join(LOGS_DIR, run_name),
+ counter,
+ model,
+ hparam,
+ x_train,
+ y_train,
+ x_val,
+ y_val,
+ x_test,
+ y_test,
+ epochs=50
+ )
+ counter += 1
+
+
+if __name__ == '__main__':
+ dataset = {}
+
+ # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
+
+ # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/')
+ # dataset['tuh_eeg_artifact'].set_minimum_event_duration(4)
+
+ dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/')
+ dataset['tuh_eeg_seizure'].set_minimum_event_duration(4)
+
+ # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/')
+ # dataset['eegmmidb'].set_minimum_event_duration(4)
+
+ dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/')
+ dataset['chbmit'].set_minimum_event_duration(4)
+
+ """
+    Note: You can also use file paths as values in the dictionary
+    and comment out the first line of the following for loop ;)
+ """
+
+ for key, value in dataset.items():
+ value = build_data(value)
+ tune_model(key, value)
diff --git a/examples/tensorboard/example_tensorboard_gat.py b/examples/tensorboard/example_tensorboard_gat.py
index cb839a9..f74c1dc 100755
--- a/examples/tensorboard/example_tensorboard_gat.py
+++ b/examples/tensorboard/example_tensorboard_gat.py
@@ -26,7 +26,6 @@
import numpy as np
from random import shuffle
from itertools import product
-from networkx import to_numpy_matrix
from sklearn.model_selection import train_test_split
from tensorflow.python.keras.utils.np_utils import to_categorical
@@ -38,8 +37,6 @@
from pyeeglab import *
def build_data(dataset):
- dataset.set_cache_manager(PickleCache('../../export'))
-
preprocessing = Pipeline([
CommonChannelSet(),
LowestFrequency(),
@@ -96,7 +93,7 @@ def get_correlation_matrix(x, frame, N, F):
input_0 = tf.keras.Input((frames, N, F + N))
- gans = []
+ layers = []
for frame in range(frames):
feature_matrix = tf.keras.layers.Lambda(
get_feature_matrix,
@@ -110,9 +107,9 @@ def get_correlation_matrix(x, frame, N, F):
x = sp.layers.GraphAttention(hparams['output_shape'])([feature_matrix, correlation_matrix])
x = tf.keras.layers.Flatten()(x)
- gans.append(x)
+ layers.append(x)
- combine = tf.keras.layers.Concatenate()(gans)
+ combine = tf.keras.layers.Concatenate()(layers)
reshape = tf.keras.layers.Reshape((frames, N * hparams['output_shape']))(combine)
lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape)
dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm)
@@ -129,6 +126,7 @@ def get_correlation_matrix(x, frame, N, F):
]
)
model.summary()
+ model.save('logs/plot_gat.h5')
return model
def run_trial(path, step, model, hparams, x_train, y_train, x_val, y_val, x_test, y_test, epochs):
@@ -162,7 +160,7 @@ def hparams_combinations(hparams):
return hparams
def tune_model(dataset_name, data):
- LOGS_DIR = join('./logs/gan', dataset_name)
+ LOGS_DIR = join('./logs/gat', dataset_name)
os.makedirs(LOGS_DIR, exist_ok=True)
# Prepare the data
x_train, y_train, x_val, y_val, x_test, y_test = adapt_data(data)
@@ -206,18 +204,18 @@ def tune_model(dataset_name, data):
if __name__ == '__main__':
dataset = {}
- dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
+ dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
- dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf')
+ dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/')
dataset['tuh_eeg_artifact'].set_minimum_event_duration(4)
- dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf')
+ dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/')
dataset['tuh_eeg_seizure'].set_minimum_event_duration(4)
- # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0')
+ # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/')
# dataset['eegmmidb'].set_minimum_event_duration(4)
- dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0')
+ dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/')
dataset['chbmit'].set_minimum_event_duration(4)
"""
diff --git a/examples/tensorboard/example_tensorboard_gnn.py b/examples/tensorboard/example_tensorboard_gnn.py
index 551e69c..40c9977 100755
--- a/examples/tensorboard/example_tensorboard_gnn.py
+++ b/examples/tensorboard/example_tensorboard_gnn.py
@@ -39,8 +39,6 @@
from pyeeglab import *
def build_data(dataset):
- dataset.set_cache_manager(PickleCache('../../export'))
-
preprocessing = Pipeline([
CommonChannelSet(),
LowestFrequency(),
@@ -97,7 +95,7 @@ def get_correlation_matrix(x, frame, N, F):
input_0 = tf.keras.Input((frames, N, F + N))
- gans = []
+ layers = []
for frame in range(frames):
feature_matrix = tf.keras.layers.Lambda(
get_feature_matrix,
@@ -111,9 +109,9 @@ def get_correlation_matrix(x, frame, N, F):
x = sp.layers.GraphConv(hparams['output_shape'])([feature_matrix, correlation_matrix])
x = tf.keras.layers.Flatten()(x)
- gans.append(x)
+ layers.append(x)
- combine = tf.keras.layers.Concatenate()(gans)
+ combine = tf.keras.layers.Concatenate()(layers)
reshape = tf.keras.layers.Reshape((frames, N * hparams['output_shape']))(combine)
lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape)
dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm)
@@ -130,6 +128,7 @@ def get_correlation_matrix(x, frame, N, F):
]
)
model.summary()
+ model.save('logs/plot_gnn.h5')
return model
def run_trial(path, step, model, hparams, x_train, y_train, x_val, y_val, x_test, y_test, epochs):
@@ -207,18 +206,18 @@ def tune_model(dataset_name, data):
if __name__ == '__main__':
dataset = {}
- # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
+ # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
- # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf')
+ # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/')
# dataset['tuh_eeg_artifact'].set_minimum_event_duration(4)
- dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf')
+ dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/')
dataset['tuh_eeg_seizure'].set_minimum_event_duration(4)
- # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0')
+ # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/')
# dataset['eegmmidb'].set_minimum_event_duration(4)
- dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0')
+ dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/')
dataset['chbmit'].set_minimum_event_duration(4)
"""
diff --git a/examples/tensorboard/example_validation.py b/examples/tensorboard/example_validation.py
new file mode 100755
index 0000000..cc906f7
--- /dev/null
+++ b/examples/tensorboard/example_validation.py
@@ -0,0 +1,249 @@
+#!/usr/bin/env python
+
+# Ignore MNE and TensorFlow warnings
+import warnings
+warnings.simplefilter(action='ignore')
+
+# Import TensorFlow with GPU memory settings
+import tensorflow as tf
+gpus = tf.config.experimental.list_physical_devices('GPU')
+try:
+ for gpu in gpus:
+ tf.config.experimental.set_memory_growth(gpu, True)
+except RuntimeError as e:
+ print(e)
+
+# Import TensorBoard params and metrics
+from tensorboard.plugins.hparams import api as hp
+from tensorflow.keras.metrics import Accuracy, Precision, Recall
+
+# Import Spektral for GraphAttention
+import spektral as sp
+
+# Others imports
+import os
+import pickle
+import numpy as np
+import pandas as pd
+from random import shuffle
+from itertools import product
+from sklearn.model_selection import train_test_split, StratifiedKFold
+from tensorflow.python.keras.utils.np_utils import to_categorical
+
+# Relative import PyEEGLab
+import sys
+from os.path import abspath, dirname, join
+
+sys.path.insert(0, abspath(join(dirname(__file__), '../..')))
+from pyeeglab import *
+
+def build_data(dataset):
+ preprocessing = Pipeline([
+ CommonChannelSet(),
+ LowestFrequency(),
+ ToDataframe(),
+ MinMaxCentralizedNormalization(),
+ DynamicWindow(8),
+ ForkedPreprocessor(
+ inputs=[
+ SpearmanCorrelation(),
+ Mean(),
+ Variance(),
+ Skewness(),
+ Kurtosis(),
+ ZeroCrossing(),
+ AbsoluteArea(),
+ PeakToPeak(),
+ Bandpower(['Delta', 'Theta', 'Alpha', 'Beta'])
+ ],
+ output=ToMergedDataframes()
+ ),
+ ToNumpy()
+ ])
+
+ return dataset.set_pipeline(preprocessing).load()
+
+def adapt_data(data, test_size=0.1):
+ if isinstance(data, str):
+ with open(data, 'rb') as f:
+ data = pickle.load(f)
+ samples, labels = data['data'], data['labels']
+ x_train, x_test, y_train, y_test = train_test_split(samples, labels, test_size=test_size, shuffle=True, stratify=labels)
+ x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=test_size, shuffle=True, stratify=y_train)
+ classes = np.sort(np.unique(labels))
+ y_train = to_categorical(y_train, num_classes=len(classes))
+ y_test = to_categorical(y_test, num_classes=len(classes))
+ y_val = to_categorical(y_val, num_classes=len(classes))
+ return x_train, y_train, x_val, y_val, x_test, y_test
+
+def stratified_k_cross_validation_split(data, folds):
+ if isinstance(data, str):
+ with open(data, 'rb') as f:
+ data = pickle.load(f)
+ samples, labels = data['data'], data['labels']
+ skf = StratifiedKFold(n_splits=folds, shuffle=True)
+ return samples, labels, skf.split(samples, labels)
+
+def build_model(shape, classes, hparams, metrics):
+ print(hparams)
+ N = shape[2]
+ F = shape[3] - N
+ frames = shape[1]
+
+ def get_frame(x, frame, N, F):
+ x = tf.slice(x, [0, frame, 0, 0], [-1, 1, N, F])
+ x = tf.squeeze(x, axis=[1])
+ return x
+
+ input_0 = tf.keras.Input((frames, N, F + N))
+
+ layers = []
+ for frame in range(frames):
+ frame_matrix = tf.keras.layers.Lambda(
+ get_frame,
+ arguments={'frame': frame, 'N': N, 'F': F}
+ )(input_0)
+
+ x = tf.keras.layers.Conv1D(hparams['filters'], hparams['kernel'], data_format='channels_first')(frame_matrix)
+ x = tf.keras.layers.MaxPooling1D(hparams['pool'])(x)
+ x = tf.keras.layers.Flatten()(x)
+ layers.append(x)
+
+ combine = tf.keras.layers.Concatenate()(layers)
+ reshape = tf.keras.layers.Reshape((-1, frames))(combine)
+ lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape)
+ dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm)
+ out = tf.keras.layers.Dense(classes, activation='softmax')(dropout)
+
+ model = tf.keras.Model(inputs=[input_0], outputs=out)
+ model.compile(
+ optimizer=tf.keras.optimizers.Adam(learning_rate=hparams['learning_rate']),
+ loss='categorical_crossentropy',
+ metrics=[
+ 'accuracy',
+ Recall(class_id=0, name='recall'),
+ Specificity(class_id=0, name='specificity'),
+ Precision(class_id=0, name='precision'),
+ F1Score(class_id=0, name='f1score'),
+ ]
+ )
+ model.summary()
+ model.save('logs/plot_cnn.h5')
+ return model
+
+def run_trial(path, step, model, hparams, metrics, x_train, y_train, x_test, y_test, epochs):
+ with tf.summary.create_file_writer(path).as_default():
+ hp.hparams(hparams)
+ model.fit(x_train, y_train, epochs=epochs, batch_size=32, shuffle=True)
+ metrics_output = model.evaluate(x_test, y_test)
+ metrics_output = dict(zip(model.metrics_names, metrics_output))
+ for metric_name, metric_value in metrics_output.items():
+ if metric_name in metrics:
+ tf.summary.scalar(metric_name, metric_value, step=step)
+ return metrics_output
+
+def hparams_combinations(hparams, metrics):
+ hp.hparams_config(
+ hparams=list(hparams.values()),
+ metrics=[
+ hp.Metric(metric, display_name=metric.capitalize())
+ for metric in metrics
+ ]
+ )
+ hparams_keys = list(hparams.keys())
+ hparams_values = list(product(*[
+ h.domain.values
+ for h in hparams.values()
+ ]))
+ hparams = [
+ dict(zip(hparams_keys, values))
+ for values in hparams_values
+ ]
+ shuffle(hparams)
+ return hparams
+
+def tune_model(dataset_name, data, metrics, folds, val_size=0.1):
+ LOGS_DIR = join('./logs/cnn', dataset_name)
+ os.makedirs(LOGS_DIR, exist_ok=True)
+ # Prepare the data
+ data, labels, splitter = stratified_k_cross_validation_split(data, folds)
+ # Set tuning session and results
+ counter = 0
+ results = []
+ # Parameters to be tuned
+ hparam = {
+ 'learning_rate': 1e-4,
+ 'kernel': 5,
+ 'filters': 32,
+ 'pool': 5,
+ 'hidden_units': 64,
+ 'dropout': 0.10,
+ }
+ for train_idx, test_idx in splitter:
+ # Generate folds
+ x_train, y_train = data[train_idx], labels[train_idx]
+ x_test, y_test = data[test_idx], labels[test_idx]
+ classes = np.sort(np.unique(labels))
+ y_train = to_categorical(y_train, num_classes=len(classes))
+ y_test = to_categorical(y_test, num_classes=len(classes))
+ # Build the model
+ model = build_model(data.shape, len(classes), hparam, metrics)
+ # Run session
+ run_name = f'run-{counter}'
+ print(f'--- Starting trial: {run_name}')
+ print(hparam)
+ result = run_trial(
+ join(LOGS_DIR, run_name),
+ counter,
+ model,
+ hparam,
+ metrics,
+ x_train,
+ y_train,
+ x_test,
+ y_test,
+ epochs=50
+ )
+ counter += 1
+ results.append(result)
+ results = pd.DataFrame(results)
+ hparam = '-'.join([
+ str(key) + '_' + str(value)
+ for key, value in hparam.items()
+ ])
+ results.to_csv(join(LOGS_DIR, "validation-{}-{}-results.csv".format(dataset_name, hparam)))
+ print(results)
+
+
+if __name__ == '__main__':
+ dataset = {}
+
+ # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
+
+ # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/')
+ # dataset['tuh_eeg_artifact'].set_minimum_event_duration(4)
+
+ # dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/')
+ # dataset['tuh_eeg_seizure'].set_minimum_event_duration(4)
+
+ # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/')
+ # dataset['eegmmidb'].set_minimum_event_duration(4)
+
+ dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/')
+ dataset['chbmit'].set_minimum_event_duration(4)
+
+ """
+    Note: You can also use file paths as values in the dictionary
+    and comment out the first line of the following for loop ;)
+ """
+
+ for key, value in dataset.items():
+ metrics = [
+ 'accuracy',
+ 'recall',
+ 'specificity',
+ 'precision',
+ 'f1score',
+ ]
+ value = build_data(value)
+ tune_model(key, value, metrics, folds=10)
diff --git a/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py b/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py
index 392d04f..f813d4f 100755
--- a/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py
+++ b/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py
@@ -21,12 +21,11 @@
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
-from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \
+from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \
LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \
ToNumpy
-dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
-dataset.set_cache_manager(PickleCache('../../export'))
+dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
preprocessing = Pipeline([
CommonChannelSet(),
diff --git a/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py
index 3455490..59674b2 100755
--- a/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py
+++ b/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py
@@ -21,12 +21,11 @@
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
-from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \
+from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \
LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \
BinarizedSpearmanCorrelation, ToNumpy
-dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
-dataset.set_cache_manager(PickleCache('../../export'))
+dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
preprocessing = Pipeline([
CommonChannelSet(),
diff --git a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py
index 38407f1..e553bd3 100755
--- a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py
+++ b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py
@@ -21,12 +21,11 @@
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
-from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \
+from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \
LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \
ToNumpy
-dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
-dataset.set_cache_manager(PickleCache('../../export'))
+dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
preprocessing = Pipeline([
CommonChannelSet(),
diff --git a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py
index ea7dab1..faee56a 100755
--- a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py
+++ b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py
@@ -21,12 +21,11 @@
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
-from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \
+from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \
LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \
BinarizedSpearmanCorrelation, ToNumpy
-dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
-dataset.set_cache_manager(PickleCache('../../export'))
+dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
preprocessing = Pipeline([
CommonChannelSet(),
diff --git a/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py b/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py
index 3f85e02..cb6e7be 100755
--- a/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py
+++ b/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py
@@ -23,12 +23,11 @@
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
-from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \
+from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \
LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \
CorrelationToAdjacency, Bandpower, GraphWithFeatures, ForkedPreprocessor
-dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
-dataset.set_cache_manager(PickleCache('../../export'))
+dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
preprocessing = Pipeline([
CommonChannelSet(),
diff --git a/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py
index 5bd0c3f..859ff78 100755
--- a/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py
+++ b/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py
@@ -23,13 +23,12 @@
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
-from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \
+from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \
LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \
BinarizedSpearmanCorrelation, CorrelationToAdjacency, Bandpower, \
GraphWithFeatures, ForkedPreprocessor
-dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf')
-dataset.set_cache_manager(PickleCache('../../export'))
+dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/')
preprocessing = Pipeline([
CommonChannelSet(),
diff --git a/pyeeglab/__init__.py b/pyeeglab/__init__.py
index cce3be6..df80b73 100644
--- a/pyeeglab/__init__.py
+++ b/pyeeglab/__init__.py
@@ -1,15 +1,11 @@
import logging
import warnings
-from importlib.util import find_spec
-from mne.utils import set_config
+logging.basicConfig(format="%(asctime)s %(levelname)7s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S")
from .dataset import *
-from .io import Raw
-from .cache import PickleCache
-from .pipeline import Pipeline, ForkedPreprocessor
+from .pipeline import *
from .preprocess import *
logging.getLogger().setLevel(logging.DEBUG)
-
warnings.filterwarnings("ignore", category=RuntimeWarning)
diff --git a/pyeeglab/cache/__init__.py b/pyeeglab/cache/__init__.py
deleted file mode 100644
index eb3bba2..0000000
--- a/pyeeglab/cache/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .cache import Cache, PickleCache
diff --git a/pyeeglab/cache/cache.py b/pyeeglab/cache/cache.py
deleted file mode 100644
index 202b9bd..0000000
--- a/pyeeglab/cache/cache.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import json
-import logging
-from abc import ABC, abstractmethod
-from typing import List, Dict
-
-from os.path import isfile, join
-from pathlib import Path
-from hashlib import md5
-from pickle import load, dump
-
-from ..io import DataLoader
-from ..pipeline import Pipeline
-
-class Cache(ABC):
-
- def __init__(self) -> None:
- logging.debug('Create cache manager')
-
- def _get_cache_key(self, dataset: str, loader: DataLoader, pipeline: Pipeline) -> str:
- if dataset.endswith('dataset'):
- dataset = dataset[:-7]
- key = [loader, pipeline]
- key = [hash(k) for k in key]
- key = [str(k).encode() for k in key]
- key = [md5(k).hexdigest()[:10] for k in key]
- key = list(zip(['loader', 'pipeline'], key))
- key = ['_'.join(k) for k in key]
- key = dataset + '_' + '_'.join(key)
- return key
-
- @abstractmethod
- def load(self, dataset: str, loader: DataLoader, pipeline: Pipeline) -> Dict:
- pass
-
-
-class PickleCache(Cache):
-
- def __init__(self, path: str):
- super().__init__()
- logging.debug('Create single pickle cache manager')
- Path(path).mkdir(parents=True, exist_ok=True)
- self.path = path
-
- def load(self, dataset: str, loader: DataLoader, pipeline: Pipeline):
- logging.debug('Computing cache key')
- key = self._get_cache_key(dataset, loader, pipeline)
- logging.debug('Computed cache key: %s', key)
- key = join(self.path, key + '.pkl')
- if isfile(key):
- logging.debug('Cache file found')
- with open(key, 'rb') as file:
- try:
- logging.debug('Loading cache file')
- data = load(file)
- return data
- except:
- logging.debug('Loading cache file failed')
- pass
- logging.debug('Cache file not found, genereting new one')
- data = loader.get_dataset()
- data = pipeline.run(data)
- with open(key, 'wb') as file:
- logging.debug('Dumping cache file')
- dump(data, file)
- return data
diff --git a/pyeeglab/database/__init__.py b/pyeeglab/database/__init__.py
deleted file mode 100644
index fdcd6bb..0000000
--- a/pyeeglab/database/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .tables import *
diff --git a/pyeeglab/database/tables.py b/pyeeglab/database/tables.py
deleted file mode 100644
index a5fd22b..0000000
--- a/pyeeglab/database/tables.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import logging
-
-from dataclasses import dataclass
-
-from sqlalchemy import Column, Integer, Float, Text, ForeignKey
-from sqlalchemy.ext.declarative import declarative_base
-
-from ..io.raw import Raw
-
-
-BASE_TABLE = declarative_base()
-
-
-@dataclass
-class File(BASE_TABLE):
- """File represents a single file contained in the dataset.
-
- This is an ORM class derived from the BASE_TABLE in a declarative base
- used by SQLAlchemy.
-
- Attributes
- ----------
- id : str
- The primary key generated randomly using an UUID4 generator.
- extension : str
- A not null indexed string reporting the EEG recording format.
- path : str
- A not null string used to point to the relative path of the file respect
- to the current sqlite database location.
- """
- __tablename__ = 'file'
- id: str = Column(Text, primary_key=True)
- extension: str = Column(Text, nullable=False, index=True)
- path: str = Column(Text, nullable=False)
-
-
-@dataclass
-class Metadata(BASE_TABLE):
- """ Metadata represents a single metadata record associated with a single
- file contained in the dataset.
-
- This is an ORM class derived from the BASE_TABLE in a declarative base
- used by SQLAlchemy.
-
- Attributes
- ----------
- file_id : str
- The foreign key related to file_id, used also as a primary key for metadata
- table since this is a one-to-one relationship.
- duration : int
- This is the EEG sample duration reported in seconds. This is not inteded as
- a precise duration estimate, but only a reference for statistical analysis.
- For more precise duration measurement, please, use the Raw record class methods.
- channels_count : int
- This field report the number of channels reported in the EEG header.
- channels_reference : str
- An indexed string describing the EEG channel reference system.
- channels_set : str
- The list of channels saved as a JSON string extracted from the EEG header.
- sampling_frequency : int
- The sampling fequency expressend in Hz extrated from the EEG header.
- max_value : float
- The max value sampled in this record across all channels.
- min_value : float
- The min value sampled in this record across all channels.
- """
- __tablename__ = 'metadata'
- file_id: str = Column(Text, ForeignKey('file.id'), primary_key=True)
- duration: int = Column(Integer, nullable=False)
- channels_count: int = Column(Integer, nullable=False)
- channels_reference: str = Column(Text, nullable=True, index=True)
- channels_set: str = Column(Text, nullable=False, index=True)
- sampling_frequency: int = Column(Integer, nullable=False, index=True)
- max_value: float = Column(Float, nullable=False)
- min_value: float = Column(Float, nullable=False)
-
-
-@dataclass
-class Event(BASE_TABLE):
- """Event represents a single event associated to a single file.
- Multiple events can be associated to multiple files, according to
- the record annotations.
-
- Attributes
- ----------
- id : str
- This is the primary key that is associated to each event. It is
- created from a UUID4.
- file_id : str
- This the foreign key that link the event to the related file.
- begin : float
- The timestep expressed in seconds from which the event begins.
- end : float
- The timestep expressed in seconds at which the event ends.
- duration : float
- The entire duration of the event.
- label : str
- The label of this event as written inthe annotations.
- """
- __tablename__ = 'event'
- id: str = Column(Text, primary_key=True)
- file_id: str = Column(Text, ForeignKey('file.id'), index=True)
- begin: float = Column(Float, nullable=False)
- end: float = Column(Float, nullable=False)
- duration: float = Column(Float, nullable=False)
- label: str = Column(Text, nullable=False, index=True)
diff --git a/pyeeglab/dataset/__init__.py b/pyeeglab/dataset/__init__.py
index 6b1dac2..3880a61 100644
--- a/pyeeglab/dataset/__init__.py
+++ b/pyeeglab/dataset/__init__.py
@@ -1,5 +1,2 @@
-from .tuh_eeg_abnormal import TUHEEGAbnormalLoader, TUHEEGAbnormalDataset
-from .tuh_eeg_artifact import TUHEEGArtifactLoader, TUHEEGArtifactDataset
-from .tuh_eeg_seizure import TUHEEGSeizureLoader, TUHEEGSeizureDataset
-from .eegmmidb import EEGMMIDBLoader, EEGMMIDBDataset
-from .chbmit import CHBMITLoader, CHBMITDataset
+from .physionet import *
+from .tuh_eeg import *
diff --git a/pyeeglab/dataset/annotation.py b/pyeeglab/dataset/annotation.py
new file mode 100644
index 0000000..db8aee8
--- /dev/null
+++ b/pyeeglab/dataset/annotation.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass
+from mne.io import Raw, read_raw
+from sqlalchemy import Column, ForeignKey, Text, Float
+from sqlalchemy.orm import relationship
+from .declarative_base import Base
+
+
+@dataclass
+class Annotation(Base):
+ __tablename__ = "annotation"
+ uuid: str = Column(Text, primary_key=True)
+ file_uuid: str = Column(Text, ForeignKey("file.uuid"), nullable=False)
+ begin: float = Column(Float, nullable=False)
+ end: float = Column(Float, nullable=False)
+ label: str = Column(Text, nullable=False, index=True)
+
+ file = relationship("File", lazy="subquery")
+
+ @property
+ def duration(self) -> float:
+ return self.end - self.begin
+
+ def __enter__(self) -> Raw:
+ self.reader = read_raw(self.file.path)
+ tmax = self.reader.n_times / self.reader.info["sfreq"] - 0.1
+ tmax = tmax if self.end > tmax else self.end
+ self.reader.crop(self.begin, tmax)
+ return self.reader
+
+ def __exit__(self, *args, **kwargs) -> None:
+ self.reader.close()
+ del self.reader
+
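As a usage sketch (assuming an `annotation` row already fetched from the index), the context-manager interface above crops the underlying recording to the annotated span:

    # Hypothetical usage: "annotation" is an Annotation instance from the index.
    with annotation as raw:           # opens file.path via mne.io.read_raw
        data = raw.get_data()         # samples restricted to [begin, end]
    print(annotation.duration)        # end - begin, in seconds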
diff --git a/pyeeglab/dataset/chbmit/__init__.py b/pyeeglab/dataset/chbmit/__init__.py
deleted file mode 100644
index a592069..0000000
--- a/pyeeglab/dataset/chbmit/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .chbmit_loader import CHBMITLoader
-from .chbmit_dataset import CHBMITDataset
diff --git a/pyeeglab/dataset/chbmit/chbmit_dataset.py b/pyeeglab/dataset/chbmit/chbmit_dataset.py
deleted file mode 100644
index 4692ae1..0000000
--- a/pyeeglab/dataset/chbmit/chbmit_dataset.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from typing import Dict
-
-from .chbmit_loader import CHBMITLoader
-from ..dataset import Dataset
-
-
-class CHBMITDataset(Dataset):
-
- def __init__(self, path: str = './data/physionet.org/files/chbmit/1.0.0/') -> None:
- super().__init__(CHBMITLoader(path))
-
- def _get_dataset_env(self) -> Dict:
- env = super()._get_dataset_env()
- env['class_id'] = 'noseizure'
- return env
diff --git a/pyeeglab/dataset/chbmit/chbmit_index.py b/pyeeglab/dataset/chbmit/chbmit_index.py
deleted file mode 100644
index 05a1df9..0000000
--- a/pyeeglab/dataset/chbmit/chbmit_index.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import logging
-
-from uuid import uuid4, uuid5, NAMESPACE_X500
-from os.path import isfile, join, sep
-from wfdb import rdann
-
-from ...io import Index
-from ...database import File, Event
-
-from typing import List
-
-
-class CHBMITIndex(Index):
- def __init__(self, path: str) -> None:
- logging.debug('Create CHB-MIT Scalp EEG Index')
- super().__init__('sqlite:///' + join(path, 'index.db'), path)
- self.index()
-
- def _get_file(self, path: str) -> File:
- length = len(self.path)
- meta = path[length:].split(sep)
- return File(
- id=str(uuid5(NAMESPACE_X500, path[length:])),
- extension=meta[-1].split('.')[-1],
- path=path[length:],
- )
-
- def _get_record_events(self, file: File) -> List[Event]:
- logging.debug('Add file %s raw events to index', file.id)
- path = join(self.path, file.path)
- if isfile(path + '.seizures'):
- events = rdann(path, 'seizures')
- events = list(events.sample / events.fs)
- events = [events[i:i+2] for i in range(0, len(events), 2)]
- events = [
- Event(
- id=str(uuid4()),
- file_id=file.id,
- begin=event[0],
- end=event[1],
- duration=(event[1] - event[0]),
- label='seizure'
- )
- for event in events
- ]
- else:
- events = [
- Event(
- id=str(uuid4()),
- file_id=file.id,
- begin=60,
- end=120,
- duration=60,
- label='noseizure'
- )
- ]
- return events
diff --git a/pyeeglab/dataset/chbmit/chbmit_loader.py b/pyeeglab/dataset/chbmit/chbmit_loader.py
deleted file mode 100644
index 022c057..0000000
--- a/pyeeglab/dataset/chbmit/chbmit_loader.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import logging
-
-from typing import List
-
-from ...io import DataLoader
-from .chbmit_index import CHBMITIndex
-
-
-class CHBMITLoader(DataLoader):
-
- def __init__(self, path: str) -> None:
- exclude_files = [
- 'chb03/chb03_35.edf', # Corrupted data
- 'chb12/chb12_27.edf', # Bad channel names
- 'chb12/chb12_28.edf', # Bad channel names
- 'chb12/chb12_29.edf', # Bad channel names
- ]
- super().__init__(path, exclude_files=exclude_files)
- logging.debug('Create CHB-MIT Scalp EEG Loader')
- self.index = CHBMITIndex(self.path)
diff --git a/pyeeglab/dataset/dataset.py b/pyeeglab/dataset/dataset.py
index 692b3a2..fc67639 100644
--- a/pyeeglab/dataset/dataset.py
+++ b/pyeeglab/dataset/dataset.py
@@ -1,45 +1,286 @@
+import os
+import json
import logging
-from abc import ABC
-from typing import Dict
+import hashlib
+import pickle
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from functools import reduce
+from multiprocessing import Pool, cpu_count
+from operator import add, and_
+from uuid import uuid4, uuid5, NAMESPACE_X500
+
+from typing import Dict, List, Tuple
+
+import mne
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session, sessionmaker, Query
+
+from .declarative_base import Base
+from .file import File
+from .metadata import Metadata
+from .annotation import Annotation
-from ..io import DataLoader
-from ..cache import Cache, PickleCache
from ..pipeline import Pipeline
+
+@dataclass(init=False)
class Dataset(ABC):
+ path: str
+ name: str
+ version: str
- cache: Cache
- loader: DataLoader
- pipeline: Pipeline
+ extensions: List[str]
+ exclude_file: List[str]
+ exclude_channels_set: List[str]
+ exclude_channels_reference: List[str]
+ exclude_sampling_frequency: List[int]
+ minimum_annotation_duration: float
- def __init__(self, loader: DataLoader) -> None:
- logging.debug('Create dataset')
- self.loader = loader
- self.set_cache_manager(PickleCache('export'))
- self.set_pipeline(Pipeline())
+ session: Session
+ query: Query
- def load(self) -> Dict:
- dataset = self.__class__.__name__.lower()
- return self.cache.load(dataset, self.loader, self.pipeline)
+ pipeline: Pipeline = None
+
+ def __init__(
+ self,
+ path: str,
+ name: str,
+ version: str = None,
+ extensions: List[str] = [".edf"],
+ exclude_file: List[str] = None,
+ exclude_channels_set: List[str] = None,
+ exclude_channels_reference: List[str] = None,
+        exclude_sampling_frequency: List[int] = None,
+ minimum_annotation_duration: float = None
+ ) -> None:
+ # Set basic attributes
+ self.path = os.path.abspath(os.path.join(path, version))
+ self.name = name
+ self.version = version
- def _get_dataset_env(self) -> Dict:
+ # Set data set filter attributes
+ self.extensions = extensions if extensions else []
+ self.exclude_file = exclude_file if exclude_file else []
+ self.exclude_channels_set = exclude_channels_set if exclude_channels_set else []
+ self.exclude_channels_reference = exclude_channels_reference if exclude_channels_reference else []
+ self.exclude_sampling_frequency = exclude_sampling_frequency if exclude_sampling_frequency else []
+ self.minimum_annotation_duration = minimum_annotation_duration if minimum_annotation_duration else 0
+
+ logging.info("Init dataset '%s'@'%s' at '%s'", self.name, self.version, self.path)
+
+ # Make workspace directory
+ logging.debug("Make .pyeeglab directory")
+ workspace = os.path.join(self.path, ".pyeeglab")
+ os.makedirs(workspace, exist_ok=True)
+ logging.debug("Make .pyeeglab/cache directory")
+ os.makedirs(os.path.join(workspace, "cache"), exist_ok=True)
+ logging.debug("Set MNE log .pyeeglab/mne.log")
+ mne.set_log_file(os.path.join(workspace, "mne.log"), overwrite=False)
+
+ # Index data set files
+ self.index()
+
+ def __getstate__(self):
+ # Workaround for unpickable sqlalchemy.orm.session
+ # during multiprocess dataset loading
+ state = self.__dict__.copy()
+ for attribute in ["session", "query"]:
+ if hasattr(self, attribute):
+ del state[attribute]
+ return state
+
+ @abstractmethod
+ def download(self, user: str = None, password: str = None) -> None:
+ pass
+
+ def index(self) -> None:
+ # Init index session
+ logging.debug("Make index session")
+ connection = os.path.join(self.path, ".pyeeglab", "index.sqlite3")
+ connection = create_engine("sqlite:///" + connection)
+ Base.metadata.create_all(connection)
+ self.session = sessionmaker(bind=connection)()
+ # Open multiprocess pool
+ logging.info("Index data set directory")
+ pool = Pool(cpu_count())
+        # Get file paths from the data set directory
+ paths = [
+ os.path.join(directory, filename)
+ for directory, _, filenames in os.walk(self.path)
+ for filename in filenames
+ ]
+        # Get File instances from paths, filtering out already indexed files
+ files = self.session.query(File).all()
+ files = [file.uuid for file in files]
+ files = [
+ file
+ for file in pool.map(self._get_file, paths)
+ if file.uuid not in files
+ ]
+ for file in files:
+ logging.debug("Add file %s to index", file.uuid)
+ # Filter raw data files by extension
+ raws = [
+ file
+ for file in files
+ if os.path.splitext(file.path)[-1] in self.extensions
+ ]
+ # Get metadata and annotation for data files
+ metadatas = pool.map(self._get_metadata, raws)
+ annotations = pool.map(self._get_annotation, raws)
+ # Close multiprocess pool
+ pool.close()
+ pool.join()
+ # Commit insertions to index
+ commits = files + metadatas + reduce(add, annotations, [])
+ if commits:
+ logging.info("Commit insertions to index")
+ self.session.add_all(commits)
+ self.session.commit()
+ logging.info("Index data set completed")
+ # Init default query
+ logging.debug("Init default query")
+ self.query = self.session.query(File, Metadata, Annotation).\
+ join(File.meta).\
+ join(File.annotations).\
+ filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)).\
+ filter(~Metadata.sampling_frequency.in_(self.exclude_sampling_frequency)).\
+ filter((Annotation.end - Annotation.begin) >= self.minimum_annotation_duration)
+ # Filter exclude file paths
+ for file in self.exclude_file:
+ self.query = self.query.filter(~File.path.like("%{}%".format(file)))
+ logging.debug("SQL query representation: '%s'", str(self.query).replace("\n", ""))
+
+ def _get_file(self, path: str) -> File:
+ return File(
+ uuid=str(uuid5(NAMESPACE_X500, path)),
+ path=path,
+ extension=os.path.splitext(path)[-1]
+ )
+
+ def _get_metadata(self, file: File) -> Metadata:
+ logging.debug("Add file %s metadata to index", file.uuid)
+ with file as reader:
+ info = reader.info
+ metadata = Metadata(
+ file_uuid=file.uuid,
+ duration=reader.n_times/info["sfreq"],
+ channels_set=json.dumps(info["ch_names"]),
+ sampling_frequency=info["sfreq"],
+ max_value=reader.get_data().max(),
+ min_value=reader.get_data().min(),
+ )
+ return metadata
+
+ def _get_annotation(self, file: File) -> List[Annotation]:
+ logging.debug("Add file %s annotations to index", file.uuid)
+ with file as reader:
+ annotations = [
+ Annotation(
+ uuid=str(uuid4()),
+ file_uuid=file.uuid,
+ begin=annotation[0],
+ end=annotation[0]+annotation[1],
+ label=annotation[2],
+ )
+ for annotation in reader.annotations
+ ]
+ return annotations
+
+ @property
+ def environment(self) -> Dict:
+ min_max = self.signal_min_max_range
return {
- 'channels_set': self.loader.get_channels_set(),
- 'lowest_frequency': self.loader.get_lowest_frequency(),
- 'max_value': self.loader.get_max_value(),
- 'min_value': self.loader.get_min_value(),
+ "channels_set": self.maximal_channels_subset,
+ "lowest_frequency": self.lowest_frequency,
+ "min_value": min_max[0],
+ "max_value": min_max[1],
}
- def set_minimum_event_duration(self, minimum_event_duration: float) -> 'Dataset':
- self.loader.set_minimum_event_duration(minimum_event_duration)
- return self
+ @property
+ def lowest_frequency(self) -> float:
+ frequency = self.query.all()
+ frequency = min([
+ f[1].sampling_frequency
+ for f in frequency
+ ], default=0)
+ return frequency
- def set_cache_manager(self, cache: Cache) -> 'Dataset':
- self.cache = cache
- return self
-
- def set_pipeline(self, pipeline: Pipeline) -> 'Dataset':
+ @property
+ def maximal_channels_subset(self) -> List[str]:
+ channels = self.query.group_by(Metadata.channels_set).all()
+ channels = [
+ frozenset(json.loads(channel[1].channels_set))
+ for channel in channels
+ ]
+ channels = reduce(and_, channels)
+ channels = channels - frozenset(self.exclude_channels_set)
+ channels = sorted(channels)
+ return channels
+
+ @property
+ def signal_min_max_range(self) -> Tuple[float]:
+ min_max = self.query.all()
+ min_max = [m[1] for m in min_max]
+ min_max = tuple([
+ min([m.min_value for m in min_max], default=0),
+ max([m.max_value for m in min_max], default=0),
+ ])
+ return min_max
+
+ def set_pipeline(self, pipeline: Pipeline) -> "Dataset":
self.pipeline = pipeline
- environment = self._get_dataset_env()
- self.pipeline.environment.update(environment)
+ self.pipeline.environment.update(self.environment)
+ return self
+
+ def set_minimum_event_duration(self, duration: float) -> "Dataset":
+ logging.warning("This function will be deprecated in the near future")
+ self.minimum_annotation_duration = duration
return self
+
+ def load(self) -> Dict:
+ # Compute cache path
+ cache = os.path.join(self.path, ".pyeeglab", "cache")
+ # Compute cache key
+ logging.info("Compute cache key")
+ name = self.__class__.__name__.lower()
+ if name.endswith("dataset"):
+ name = name[:-len("dataset")]
+ key = [hash(self), hash(self.pipeline)]
+ key = [str(k).encode() for k in key]
+ key = [hashlib.md5(k).hexdigest()[:10] for k in key]
+ key = list(zip(["loader", "pipeline"], key))
+ key = ["_".join(k) for k in key]
+ key = name + "_" + "_".join(key)
+ logging.info("Computed cache key: %s", key)
+ # Load file cache
+ cache = os.path.join(cache, key + ".pkl")
+ if os.path.exists(cache):
+ logging.info("Cache file found at %s", cache)
+ with open(cache, "rb") as reader:
+ try:
+ logging.info("Loading cache file")
+ return pickle.load(reader)
+ except:
+ logging.error("Loading cache file failed")
+ # Cache file not found, start preprocessing
+        logging.info("Cache file not found, generating new one")
+ data = [row[2] for row in self.query.all()]
+ data = self.pipeline.run(data)
+ with open(cache, "wb") as file:
+ logging.info("Dumping cache file")
+ pickle.dump(data, file)
+ return data
+
+ def __hash__(self) -> int:
+ key = [self.path, self.version, self.minimum_annotation_duration]
+ key += self.exclude_file
+ key += self.exclude_channels_set
+ key += self.exclude_channels_reference
+ key += self.exclude_sampling_frequency
+ key = json.dumps(key).encode()
+ key = hashlib.md5(key).hexdigest()
+ key = int(key, 16)
+ return key
diff --git a/pyeeglab/dataset/declarative_base.py b/pyeeglab/dataset/declarative_base.py
new file mode 100644
index 0000000..c64447d
--- /dev/null
+++ b/pyeeglab/dataset/declarative_base.py
@@ -0,0 +1,2 @@
+from sqlalchemy.ext.declarative import declarative_base
+Base = declarative_base()
diff --git a/pyeeglab/dataset/eegmmidb/__init__.py b/pyeeglab/dataset/eegmmidb/__init__.py
deleted file mode 100644
index 4cbd826..0000000
--- a/pyeeglab/dataset/eegmmidb/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .eegmmidb_loader import EEGMMIDBLoader
-from .eegmmidb_dataset import EEGMMIDBDataset
diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py b/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py
deleted file mode 100644
index a5aa15d..0000000
--- a/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .eegmmidb_loader import EEGMMIDBLoader
-from ..dataset import Dataset
-
-
-class EEGMMIDBDataset(Dataset):
-
- def __init__(self, path: str = './data/physionet.org/files/eegmmidb/1.0.0/') -> None:
- super().__init__(EEGMMIDBLoader(path))
diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_index.py b/pyeeglab/dataset/eegmmidb/eegmmidb_index.py
deleted file mode 100644
index 83af529..0000000
--- a/pyeeglab/dataset/eegmmidb/eegmmidb_index.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import logging
-
-from uuid import uuid5, NAMESPACE_X500
-from os.path import join, sep
-
-from ...io import Index
-from ...database import File
-
-
-class EEGMMIDBIndex(Index):
- def __init__(self, path: str) -> None:
- logging.debug('Create EEG Motor Movement/Imagery Index')
- super().__init__('sqlite:///' + join(path, 'index.db'), path)
- self.index()
-
- def _get_file(self, path: str) -> File:
- length = len(self.path)
- meta = path[length:].split(sep)
- return File(
- id=str(uuid5(NAMESPACE_X500, path[length:])),
- extension=meta[-1].split('.')[-1],
- path=path[length:]
- )
diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py b/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py
deleted file mode 100644
index b9bde7b..0000000
--- a/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import logging
-
-from typing import List
-
-from ...io import DataLoader
-from .eegmmidb_index import EEGMMIDBIndex
-
-
-class EEGMMIDBLoader(DataLoader):
-
- def __init__(self, path: str, exclude_frequency: List[int] = [128]) -> None:
- exclude_files = [
- 'S021/S021R08.edf', # Corrupted data
- 'S104/S104R04.edf', # Corrupted data
- ]
- super().__init__(path, exclude_files = exclude_files, exclude_frequency = exclude_frequency)
- logging.debug('Create EEG Motor Movement/Imagery Loader')
- self.index = EEGMMIDBIndex(self.path)
diff --git a/pyeeglab/dataset/file.py b/pyeeglab/dataset/file.py
new file mode 100644
index 0000000..e4cbbb1
--- /dev/null
+++ b/pyeeglab/dataset/file.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from mne.io import Raw, read_raw
+from sqlalchemy import Column, Text
+from sqlalchemy.orm import relationship
+from .declarative_base import Base
+
+
+@dataclass
+class File(Base):
+ __tablename__ = "file"
+ uuid: str = Column(Text, primary_key=True)
+ path: str = Column(Text, nullable=False)
+ extension: str = Column(Text, nullable=False)
+
+ meta = relationship("Metadata", cascade="all,delete", backref="File")
+ annotations = relationship("Annotation", cascade="all,delete", backref="File")
+
+ def __enter__(self) -> Raw:
+ self.reader = read_raw(self.path)
+ return self.reader
+
+ def __exit__(self, *args, **kwargs) -> None:
+ self.reader.close()
+ del self.reader
diff --git a/pyeeglab/dataset/metadata.py b/pyeeglab/dataset/metadata.py
new file mode 100644
index 0000000..fb770e4
--- /dev/null
+++ b/pyeeglab/dataset/metadata.py
@@ -0,0 +1,15 @@
+from dataclasses import dataclass
+from sqlalchemy import Column, ForeignKey, Text, Float, Integer
+from .declarative_base import Base
+
+
+@dataclass
+class Metadata(Base):
+ __tablename__ = "metadata"
+ file_uuid: str = Column(Text, ForeignKey("file.uuid"), primary_key=True)
+ duration: float = Column(Float, nullable=False)
+ channels_set: str = Column(Text, nullable=False, index=True)
+ channels_reference: str = Column(Text, nullable=True, index=True)
+ sampling_frequency: int = Column(Integer, nullable=False, index=True)
+ max_value: float = Column(Float, nullable=False)
+ min_value: float = Column(Float, nullable=False)
diff --git a/pyeeglab/dataset/physionet/__init__.py b/pyeeglab/dataset/physionet/__init__.py
new file mode 100644
index 0000000..3e3a51d
--- /dev/null
+++ b/pyeeglab/dataset/physionet/__init__.py
@@ -0,0 +1,2 @@
+from .chbmit_dataset import *
+from .eegmmidb_dataset import *
diff --git a/pyeeglab/dataset/physionet/chbmit_dataset.py b/pyeeglab/dataset/physionet/chbmit_dataset.py
new file mode 100644
index 0000000..8fe88ea
--- /dev/null
+++ b/pyeeglab/dataset/physionet/chbmit_dataset.py
@@ -0,0 +1,72 @@
+import os
+import logging
+
+from uuid import uuid4
+
+from typing import List
+
+from .utils import wget
+
+from ..dataset import Dataset
+from ..file import File
+from ..annotation import Annotation
+
+import wfdb
+
+
+class PhysioNetCHBMITDataset(Dataset):
+
+ def __init__(
+ self,
+ path: str = "./data/physionet.org/files/chbmit/",
+ version: str = "1.0.0",
+ exclude_file: List[str] = [
+ "chb03/chb03_35.edf", # Corrupted data
+ "chb12/chb12_27.edf", # Bad channel names
+ "chb12/chb12_28.edf", # Bad channel names
+ "chb12/chb12_29.edf", # Bad channel names
+ ],
+ exclude_channels_set: List[str] = [
+ "ECG",
+ "LOC-ROC",
+ "LUE-RAE",
+ "VNS",
+ ],
+ ) -> None:
+ super().__init__(
+ path=path,
+ name="CHB-MIT Scalp EEG Database",
+ version=version,
+ exclude_file=exclude_file,
+ exclude_channels_set=exclude_channels_set,
+ )
+
+ def download(self, user: str = None, password: str = None) -> None:
+ wget(self.path, user, password, "chbmit", self.version)
+
+ def _get_annotation(self, file: File) -> List[Annotation]:
+ logging.debug("Add file %s annotations to index", file.uuid)
+ annotations = [
+ Annotation(
+ uuid=str(uuid4()),
+ file_uuid=file.uuid,
+ begin=60,
+ end=120,
+ label="noseizure"
+ )
+ ]
+ if os.path.isfile(file.path + ".seizures"):
+ annotations = wfdb.rdann(file.path, "seizures")
+ annotations = list(annotations.sample / annotations.fs)
+ annotations = [annotations[i:i+2] for i in range(0, len(annotations), 2)]
+ annotations = [
+ Annotation(
+ uuid=str(uuid4()),
+ file_uuid=file.uuid,
+ begin=annotation[0],
+ end=annotation[1],
+ label="seizure"
+ )
+ for annotation in annotations
+ ]
+ return annotations
diff --git a/pyeeglab/dataset/physionet/eegmmidb_dataset.py b/pyeeglab/dataset/physionet/eegmmidb_dataset.py
new file mode 100644
index 0000000..f6da608
--- /dev/null
+++ b/pyeeglab/dataset/physionet/eegmmidb_dataset.py
@@ -0,0 +1,64 @@
+import os
+import logging
+
+from uuid import uuid4
+
+from typing import List
+
+from .utils import wget
+
+from ..dataset import Dataset
+from ..file import File
+from ..annotation import Annotation
+
+
+class PhysioNetEEGMMIDBDataset(Dataset):
+
+ def __init__(
+ self,
+ path: str = "./data/physionet.org/files/eegmmidb/",
+ version: str = "1.0.0",
+ exclude_file: List[str] = [
+ "S021/S021R08.edf", # Corrupted data
+ "S104/S104R04.edf", # Corrupted data
+ ],
+ exclude_sampling_frequency: List[int] = [128],
+ ) -> None:
+ super().__init__(
+ path=path,
+ name="EEG Motor Movement/Imagery Dataset",
+ version=version,
+ exclude_file=exclude_file,
+ exclude_sampling_frequency=exclude_sampling_frequency,
+ )
+
+ def download(self, user: str = None, password: str = None) -> None:
+ wget(self.path, user, password, "eegmmidb", self.version)
+
+ def _get_annotation(self, file: File) -> List[Annotation]:
+ logging.debug("Add file %s annotations to index", file.uuid)
+ with file as reader:
+ try:
+ annotations = [
+ Annotation(
+ uuid=str(uuid4()),
+ file_uuid=file.uuid,
+ begin=annotation[0],
+ end=annotation[0]+annotation[1],
+ label=annotation[2],
+ )
+ for annotation in reader.annotations
+ ]
+ except KeyError:
+ # Alternative annotation format
+ annotations = [
+ Annotation(
+ uuid=str(uuid4()),
+ file_uuid=file.uuid,
+ begin=annotation["onset"],
+ end=annotation["onset"]+annotation["duration"],
+ label=annotation["description"],
+ )
+ for annotation in reader.annotations
+ ]
+ return annotations
diff --git a/pyeeglab/dataset/physionet/utils.py b/pyeeglab/dataset/physionet/utils.py
new file mode 100644
index 0000000..c379903
--- /dev/null
+++ b/pyeeglab/dataset/physionet/utils.py
@@ -0,0 +1,30 @@
+import logging
+import subprocess
+
+
+def wget(path: str, user: str, password: str, slug: str, version: str) -> None:
+ logging.info("Download started, it will take some time")
+ url = "https://physionet.org/files/" + slug + "/" + version + "/"
+ process = subprocess.Popen(
+ [
+ "wget",
+ "-r",
+ "-N",
+ "-c",
+ "-np",
+ "-nH",
+ "--cut-dirs=3",
+ url,
+ "-P",
+ path
+ ],
+ stdout=subprocess.PIPE,
+ universal_newlines=True,
+ )
+ while True:
+ output = process.stdout.readline()
+ logging.info(output.strip())
+ if process.poll() is not None:
+ for output in process.stdout.readlines():
+ logging.info(output.strip())
+ break
diff --git a/pyeeglab/dataset/tuh_eeg/__init__.py b/pyeeglab/dataset/tuh_eeg/__init__.py
new file mode 100644
index 0000000..d3db617
--- /dev/null
+++ b/pyeeglab/dataset/tuh_eeg/__init__.py
@@ -0,0 +1,3 @@
+from .abnormal_dataset import *
+from .artifact_dataset import *
+from .seizure_dataset import *
diff --git a/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py b/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py
new file mode 100644
index 0000000..62e885d
--- /dev/null
+++ b/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py
@@ -0,0 +1,59 @@
+import os
+import logging
+
+from uuid import uuid4
+
+from typing import List
+
+from .utils import rsync
+
+from ..dataset import Dataset
+from ..file import File
+from ..metadata import Metadata
+from ..annotation import Annotation
+
+
+class TUHEEGAbnormalDataset(Dataset):
+
+ def __init__(
+ self,
+ path: str = "./data/tuh_eeg_abnormal/",
+ version: str = "2.0.0",
+ exclude_channels_set: List[str] = [
+ "BURSTS",
+ "ECG EKG-REF",
+ "EMG-REF",
+ "IBI",
+ "PHOTIC-REF",
+ "PULSE RATE",
+ "STI 014",
+ "SUPPR"
+ ],
+ ) -> None:
+ super().__init__(
+ path=path,
+ name="Temple University Hospital EEG Abnormal Dataset",
+ version="v"+version,
+ exclude_channels_set=exclude_channels_set,
+ )
+
+ def download(self, user: str = None, password: str = None) -> None:
+ rsync(self.path, user, password, "tuh_eeg_abnormal", self.version)
+
+ def _get_metadata(self, file: File) -> Metadata:
+ meta = file.path.split(os.path.sep)
+ metadata = super()._get_metadata(file)
+ metadata.channels_reference = meta[-5]
+ return metadata
+
+ def _get_annotation(self, file: File) -> List[Annotation]:
+ logging.debug("Add file %s annotations to index", file.uuid)
+ return [
+ Annotation(
+ uuid=str(uuid4()),
+ file_uuid=file.uuid,
+ begin=60,
+ end=120,
+ label=file.path.split(os.path.sep)[-6],
+ )
+ ]
diff --git a/pyeeglab/dataset/tuh_eeg/artifact_dataset.py b/pyeeglab/dataset/tuh_eeg/artifact_dataset.py
new file mode 100644
index 0000000..23ee37e
--- /dev/null
+++ b/pyeeglab/dataset/tuh_eeg/artifact_dataset.py
@@ -0,0 +1,51 @@
+import os
+import logging
+
+from typing import List
+
+from .utils import rsync, parse_tse
+
+from ..dataset import Dataset
+from ..file import File
+from ..metadata import Metadata
+from ..annotation import Annotation
+
+
+class TUHEEGArtifactDataset(Dataset):
+
+ def __init__(
+ self,
+ path: str = "./data/tuh_eeg_artifact/",
+ version: str = "1.0.0",
+ exclude_channels_set: List[str] = [
+ "BURSTS",
+ "EMG-REF",
+ "IBI",
+ "PHOTIC-REF",
+ "SUPPR"
+ ],
+ exclude_channels_reference: List[str] = [
+ "02_tcp_le",
+ "03_tcp_ar_a"
+ ],
+ ) -> None:
+ super().__init__(
+ path=path,
+ name="Temple University Hospital EEG Artifact Dataset",
+ version="v"+version,
+ exclude_channels_set=exclude_channels_set,
+ exclude_channels_reference=exclude_channels_reference,
+ )
+
+ def download(self, user: str = None, password: str = None) -> None:
+ rsync(self.path, user, password, "tuh_eeg_artifact", self.version)
+
+ def _get_metadata(self, file: File) -> Metadata:
+ meta = file.path.split(os.path.sep)
+ metadata = super()._get_metadata(file)
+ metadata.channels_reference = meta[-5]
+ return metadata
+
+ def _get_annotation(self, file: File) -> List[Annotation]:
+ logging.debug("Add file %s annotations to index", file.uuid)
+ return parse_tse(file)
diff --git a/pyeeglab/dataset/tuh_eeg/seizure_dataset.py b/pyeeglab/dataset/tuh_eeg/seizure_dataset.py
new file mode 100644
index 0000000..593ef3d
--- /dev/null
+++ b/pyeeglab/dataset/tuh_eeg/seizure_dataset.py
@@ -0,0 +1,57 @@
+import os
+import re
+import logging
+
+from uuid import uuid4
+
+from typing import List
+
+from .utils import rsync, parse_lbl
+
+from ..dataset import Dataset
+from ..file import File
+from ..metadata import Metadata
+from ..annotation import Annotation
+
+
+class TUHEEGSeizureDataset(Dataset):
+
+ def __init__(
+ self,
+ path: str = "./data/tuh_eeg_seizure/",
+ version: str = "1.5.2",
+ exclude_channels_set: List[str] = [
+ "BURSTS",
+ "ECG EKG-REF",
+ "EMG-REF",
+ "IBI",
+ "PHOTIC-REF",
+ "PULSE RATE",
+ "RESP ABDOMEN-RE",
+ "SUPPR"
+ ],
+ exclude_channels_reference: List[str] = [
+ "02_tcp_le",
+ "03_tcp_ar_a"
+ ],
+ ) -> None:
+ super().__init__(
+ path=path,
+ name="Temple University Hospital EEG Seizure Dataset",
+ version="v"+version,
+ exclude_channels_set=exclude_channels_set,
+ exclude_channels_reference=exclude_channels_reference,
+ )
+
+ def download(self, user: str = None, password: str = None) -> None:
+ rsync(self.path, user, password, "tuh_eeg_seizure", self.version)
+
+ def _get_metadata(self, file: File) -> Metadata:
+ meta = file.path.split(os.path.sep)
+ metadata = super()._get_metadata(file)
+ metadata.channels_reference = meta[-5]
+ return metadata
+
+ def _get_annotation(self, file: File) -> List[Annotation]:
+ logging.debug("Add file %s annotations to index", file.uuid)
+ return parse_lbl(file)
diff --git a/pyeeglab/dataset/tuh_eeg/utils.py b/pyeeglab/dataset/tuh_eeg/utils.py
new file mode 100644
index 0000000..d4a1d4c
--- /dev/null
+++ b/pyeeglab/dataset/tuh_eeg/utils.py
@@ -0,0 +1,96 @@
+import re
+import logging
+import subprocess
+
+from uuid import uuid4
+
+from typing import List
+
+from ..file import File
+from ..annotation import Annotation
+
+
+def rsync(path: str, user: str, password: str, slug: str, version: str) -> None:
+ if user is not None and password is not None:
+ logging.info("Download started, it will take some time")
+ url = user + "@" + "www.isip.piconepress.com:~/data/"
+ url = url + slug + "/v" + version + "/"
+ process = subprocess.Popen(
+ [
+ "sshpass",
+ "-p",
+ password,
+ "rsync",
+ "-auxvL",
+ url,
+ path
+ ],
+ stdout=subprocess.PIPE,
+ universal_newlines=True,
+ )
+ while True:
+ output = process.stdout.readline()
+ logging.info(output.strip())
+ if process.poll() is not None:
+ for output in process.stdout.readlines():
+ logging.info(output.strip())
+ break
+ else:
+ logging.warn("Download disabled, add 'user' and 'password' as optional parameters during data set creation")
+ logging.warn("Request your login credentials at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php")
+
+
+def parse_lbl(file: File) -> List[Annotation]:
+ path = file.path[:-4] + ".lbl"
+ with open(path, "r") as reader:
+ annotations = reader.read()
+ symbols = re.compile(
+ r"^symbols\[0\] = ({.*})$",
+ re.MULTILINE
+ )
+ symbols = re.findall(symbols, annotations)
+ symbols = eval(symbols[0])
+ pattern = re.compile(
+ r"^label = {([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^}]*)}$",
+ re.MULTILINE
+ )
+ annotations = re.findall(pattern, annotations)
+ annotations = {
+ (annotation[2], annotation[3], symbols[index])
+ for annotation in annotations
+ for index, value in enumerate(eval(annotation[5]))
+ if value > 0
+ }
+ annotations = [
+ Annotation(
+ uuid=str(uuid4()),
+ file_uuid=file.uuid,
+ begin=float(annotation[0]),
+ end=float(annotation[1]),
+ label=annotation[2]
+ )
+ for annotation in annotations
+ ]
+ return annotations
+
+
+def parse_tse(file: File) -> List[Annotation]:
+ path = file.path[:-4] + ".tse"
+ with open(path, "r") as reader:
+ annotations = reader.read()
+ pattern = re.compile(
+ r"^(\d+.\d+) (\d+.\d+) (\w+) (\d.\d+)$",
+ re.MULTILINE
+ )
+ annotations = re.findall(pattern, annotations)
+ annotations = [
+ Annotation(
+ uuid=str(uuid4()),
+ file_uuid=file.uuid,
+ begin=float(annotation[0]),
+ end=float(annotation[1]),
+ label=annotation[2]
+ )
+ for annotation in annotations
+ ]
+ return annotations
diff --git a/pyeeglab/dataset/tuh_eeg_abnormal/__init__.py b/pyeeglab/dataset/tuh_eeg_abnormal/__init__.py
deleted file mode 100644
index 926acfd..0000000
--- a/pyeeglab/dataset/tuh_eeg_abnormal/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .tuh_eeg_abnormal_loader import TUHEEGAbnormalLoader
-from .tuh_eeg_abnormal_dataset import TUHEEGAbnormalDataset
diff --git a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_dataset.py b/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_dataset.py
deleted file mode 100644
index dcad1d8..0000000
--- a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_dataset.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import Dict
-
-from .tuh_eeg_abnormal_loader import TUHEEGAbnormalLoader
-from ..dataset import Dataset
-
-
-class TUHEEGAbnormalDataset(Dataset):
-
- def __init__(self, path: str = './data/tuh_eeg_abnormal/v2.0.0/edf/') -> None:
- super().__init__(TUHEEGAbnormalLoader(path))
-
- def _get_dataset_env(self) -> Dict:
- env = super()._get_dataset_env()
- blacklist = ['IBI', 'BURSTS', 'STI 014', 'SUPPR']
- env['channels_set'] = list(set(env['channels_set']) - set(blacklist))
- env['class_id'] = 'normal'
- return env
diff --git a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_index.py b/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_index.py
deleted file mode 100644
index 58ee4ba..0000000
--- a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_index.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import logging
-import json
-
-from typing import List
-
-from uuid import uuid4, uuid5, NAMESPACE_X500
-from os.path import join, sep
-
-from ...io import Index, Raw
-from ...database import File, Metadata, Event
-
-
-class TUHEEGAbnormalIndex(Index):
- def __init__(self, path: str) -> None:
- logging.debug('Create TUH EEG Corpus Index')
- super().__init__('sqlite:///' + join(path, 'index.db'), path)
- self.index()
-
- def _get_file(self, path: str) -> File:
- length = len(self.path)
- meta = path[length:].split(sep)
- return File(
- id=str(uuid5(NAMESPACE_X500, path[length:])),
- extension=meta[-1].split('.')[-1],
- path=path[length:]
- )
-
- def _get_record_metadata(self, file: File) -> Metadata:
- meta = file.path.split(sep)
- metadata = super()._get_record_metadata(file)
- metadata.channels_reference = meta[2]
- return metadata
-
- def _get_record_events(self, file: File) -> List[Event]:
- logging.debug('Add file %s raw events to index', file.id)
- raw = Raw(file.id, join(self.path, file.path))
- return [
- Event(
- id=str(uuid4()),
- file_id=raw.id,
- begin=60,
- end=120,
- duration=60,
- label=raw.path.split(sep)[-6]
- )
- ]
diff --git a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_loader.py b/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_loader.py
deleted file mode 100644
index 2adabcc..0000000
--- a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_loader.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import logging
-
-from ...io import DataLoader
-from .tuh_eeg_abnormal_index import TUHEEGAbnormalIndex
-
-from typing import List
-
-
-class TUHEEGAbnormalLoader(DataLoader):
-
- def __init__(self, path: str) -> None:
- super().__init__(path)
- logging.debug('Create TUH EEG Corpus Loader')
- self.index = TUHEEGAbnormalIndex(self.path)
diff --git a/pyeeglab/dataset/tuh_eeg_artifact/__init__.py b/pyeeglab/dataset/tuh_eeg_artifact/__init__.py
deleted file mode 100644
index e61f493..0000000
--- a/pyeeglab/dataset/tuh_eeg_artifact/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .tuh_eeg_artifact_loader import TUHEEGArtifactLoader
-from .tuh_eeg_artifact_dataset import TUHEEGArtifactDataset
diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py
deleted file mode 100644
index b32af2f..0000000
--- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from typing import Dict
-
-from .tuh_eeg_artifact_loader import TUHEEGArtifactLoader
-from ..dataset import Dataset
-
-
-class TUHEEGArtifactDataset(Dataset):
-
- def __init__(self, path: str = './data/tuh_eeg_artifact/v1.0.0/edf/') -> None:
- super().__init__(TUHEEGArtifactLoader(path))
-
- def _get_dataset_env(self) -> Dict:
- env = super()._get_dataset_env()
- env['class_id'] = 'null'
- return env
diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py
deleted file mode 100644
index b9c74d3..0000000
--- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import logging
-import re
-import json
-
-from typing import List
-
-from uuid import uuid4, uuid5, NAMESPACE_X500
-from os.path import join, sep
-
-from ...io import Index, Raw
-from ...database import File, Metadata, Event
-
-
-class TUHEEGArtifactIndex(Index):
-
- def __init__(self, path: str) -> None:
- logging.debug('Create TUH EEG Corpus Index')
- super().__init__('sqlite:///' + join(path, 'index.db'), path)
- self.index()
-
- def _get_file(self, path: str) -> File:
- length = len(self.path)
- meta = path[length:].split(sep)
- return File(
- id=str(uuid5(NAMESPACE_X500, path[length:])),
- extension=meta[-1].split('.')[-1],
- path=path[length:]
- )
-
- def _get_record_metadata(self, file: File) -> Metadata:
- meta = file.path.split(sep)
- metadata = super()._get_record_metadata(file)
- metadata.channels_reference = meta[0]
- return metadata
-
- def _get_record_events(self, file: File) -> List[Event]:
- logging.debug('Add file %s raw events to index', file.id)
- raw = Raw(file.id, join(self.path, file.path))
- path = raw.path[:-4] + '.tse'
- with open(path, 'r') as file:
- annotations = file.read()
- pattern = re.compile(
- r'^(\d+.\d+) (\d+.\d+) (\w+) (\d.\d+)$', re.MULTILINE)
- events = re.findall(pattern, annotations)
- events = [
- Event(
- id=str(uuid4()),
- file_id=raw.id,
- begin=float(event[0]),
- end=float(event[1]),
- duration=(float(event[1])-float(event[0])),
- label=event[2]
- )
- for event in events
- ]
- return events
diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py
deleted file mode 100644
index e09e7ac..0000000
--- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import logging
-
-from typing import List
-
-from ...io import DataLoader
-from .tuh_eeg_artifact_index import TUHEEGArtifactIndex
-
-
-class TUHEEGArtifactLoader(DataLoader):
-
- def __init__(self, path: str, exclude_channels_reference: List[str] = ['02_tcp_le', '02_tcp_ar_a']) -> None:
- super().__init__(path, exclude_channels_reference=exclude_channels_reference)
- logging.debug('Create TUH EEG Corpus Loader')
- self.index = TUHEEGArtifactIndex(self.path)
diff --git a/pyeeglab/dataset/tuh_eeg_seizure/__init__.py b/pyeeglab/dataset/tuh_eeg_seizure/__init__.py
deleted file mode 100644
index abf1a6c..0000000
--- a/pyeeglab/dataset/tuh_eeg_seizure/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .tuh_eeg_seizure_loader import TUHEEGSeizureLoader
-from .tuh_eeg_seizure_dataset import TUHEEGSeizureDataset
diff --git a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py
deleted file mode 100644
index 8f391cd..0000000
--- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from typing import Dict
-
-from .tuh_eeg_seizure_loader import TUHEEGSeizureLoader
-from ..dataset import Dataset
-
-
-class TUHEEGSeizureDataset(Dataset):
-
- def __init__(self, path: str = './data/tuh_eeg_seizure/v1.5.2/edf/') -> None:
- super().__init__(TUHEEGSeizureLoader(path))
-
- def _get_dataset_env(self) -> Dict:
- env = super()._get_dataset_env()
- env['class_id'] = 'fnsz'
- return env
diff --git a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py
deleted file mode 100644
index 1d243d2..0000000
--- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import logging
-import re
-import json
-import numpy as np
-
-from typing import List
-
-from uuid import uuid4, uuid5, NAMESPACE_X500
-from os.path import join, sep
-
-from ...io import Index, Raw
-from ...database import File, Metadata, Event
-
-
-class TUHEEGSeizureIndex(Index):
-
- def __init__(self, path: str, exclude_events: List = ['bckg', 'spsz']) -> None:
- logging.debug('Create TUH EEG Corpus Index')
- super().__init__('sqlite:///' + join(path, 'index.db'), path, exclude_events=exclude_events)
- self.index()
-
- def _get_file(self, path: str) -> File:
- length = len(self.path)
- meta = path[length:].split(sep)
- return File(
- id=str(uuid5(NAMESPACE_X500, path[length:])),
- extension=meta[-1].split('.')[-1],
- path=path[length:]
- )
-
- def _get_record_metadata(self, file: File) -> Metadata:
- meta = file.path.split(sep)
- metadata = super()._get_record_metadata(file)
- metadata.channels_reference = meta[1]
- return metadata
-
- def _get_record_events(self, file: File) -> List[Event]:
- logging.debug('Add file %s raw events to index', file.id)
- raw = Raw(file.id, join(self.path, file.path))
- path = raw.path[:-4] + '.lbl'
- with open(path, 'r') as file:
- annotations = file.read()
- pattern = re.compile(
- r'^symbols\[0\] = ({.*})$',
- re.MULTILINE
- )
- mapper = re.findall(pattern, annotations)
- mapper = eval(mapper[0])
- events = re.findall(pattern, annotations)
- pattern = re.compile(
- r'^label = {([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^}]*)}$',
- re.MULTILINE
- )
- events = re.findall(pattern, annotations)
- labels = {}
- for event in events:
- intervall = (event[2], event[3])
- if intervall not in labels:
- labels[intervall] = {}
- channel = event[4]
- if channel not in labels[intervall]:
- labels[intervall][channel] = []
- label = np.array(json.loads(event[5]))
- for i in list(np.nonzero(label)[0]):
- labels[intervall][channel].append(mapper[i])
- labels = {
- (intervall[0], intervall[1], label)
- for intervall, channels in labels.items()
- for channel, labels in channels.items()
- for label in labels
- }
- events = [
- Event(
- id=str(uuid4()),
- file_id=raw.id,
- begin=float(label[0]),
- end=float(label[1]),
- duration=(float(label[1])-float(label[0])),
- label=label[2]
- )
- for label in labels
- ]
- return events
diff --git a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py
deleted file mode 100644
index 5f2627a..0000000
--- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import logging
-
-from typing import List
-
-from ...io import DataLoader
-from .tuh_eeg_seizure_index import TUHEEGSeizureIndex
-
-
-class TUHEEGSeizureLoader(DataLoader):
-
- def __init__(self, path: str, exclude_channels_reference: List[str] = ['02_tcp_le', '02_tcp_ar_a']) -> None:
- super().__init__(path, exclude_channels_reference=exclude_channels_reference)
- logging.debug('Create TUH EEG Corpus Loader')
- self.index = TUHEEGSeizureIndex(self.path)
diff --git a/pyeeglab/io/__init__.py b/pyeeglab/io/__init__.py
deleted file mode 100644
index 12695ed..0000000
--- a/pyeeglab/io/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .index import Index
-from .loader import DataLoader
-from .raw import Raw
diff --git a/pyeeglab/io/index.py b/pyeeglab/io/index.py
deleted file mode 100644
index 9b13638..0000000
--- a/pyeeglab/io/index.py
+++ /dev/null
@@ -1,139 +0,0 @@
-import logging
-import json
-
-from abc import ABC, abstractmethod
-
-from os import walk
-from os.path import join, splitext
-from uuid import uuid4
-from multiprocessing import Pool, cpu_count
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-from mne import set_log_file
-
-from ..database import BASE_TABLE, File, Metadata, Event
-from .raw import Raw
-
-from typing import List
-
-
-class Index(ABC):
- """ An abstract class representing the Index mechanism that is used to
- discover the dataset structure.
-
- It is the first component that must be implemented in order to provide
- full support to a specific dataset. If the dataset structure is regular,
- the only method that must be implemented id "_get_file(path)".
- """
-
- def __init__(self, db: str, path: str, include_extensions: List[str] = ['edf'], exclude_events: List[str] = None) -> None:
- """
- Parameters
- ----------
- db : str
- The database connection handle expressed as string. This is usually
- configured by the Loader class as a sqlite handle, but in theory it
- could be used with any type of database connection.
-
- """
- logging.debug('Create index at %s', db)
- logging.debug('Load index at %s', db)
- engine = create_engine(db)
- BASE_TABLE.metadata.create_all(engine)
- self.db = sessionmaker(bind=engine)()
- self.path = path
- self.include_extensions = include_extensions
- self.exclude_events = exclude_events
- logging.debug('Redirect MNE logging interface to file')
- set_log_file(join(path, 'mne.log'), overwrite=False)
-
- def __getstate__(self):
- # Workaround for unpickable sqlalchemy.orm.session
- # during multiprocess dataset loading
- state = self.__dict__.copy()
- del state['db']
- return state
-
- def _get_paths(self, exclude_extensions: List[str] = ['.db', '.gz', '.log']) -> List[str]:
- logging.debug('Get files from path')
- paths = [
- join(dirpath, filename)
- for dirpath, _, filenames in walk(self.path)
- for filename in filenames
- if splitext(filename)[1] not in exclude_extensions
- ]
- return paths
-
- @abstractmethod
- def _get_file(self, path: str) -> File:
- pass
-
- def _get_files(self, paths: List[str]) -> List[File]:
- pool = Pool(cpu_count())
- files = pool.map(self._get_file, paths)
- pool.close()
- pool.join()
- return files
-
- def _get_record_metadata(self, file: File) -> Metadata:
- logging.debug('Add file %s raw metadata to index', file.id)
- raw = Raw(file.id, join(self.path, file.path))
- return Metadata(
- file_id=raw.id,
- duration=raw.open().n_times/raw.open().info['sfreq'],
- channels_count=raw.open().info['nchan'],
- channels_set=json.dumps(raw.open().info['ch_names']),
- sampling_frequency=raw.open().info['sfreq'],
- max_value=raw.open().get_data().max(),
- min_value=raw.open().get_data().min(),
- )
-
- def _parallel_record_metadata(self, files: List[File]) -> List[Metadata]:
- pool = Pool(cpu_count())
- metadata = pool.map(self._get_record_metadata, files)
- pool.close()
- pool.join()
- return metadata
-
- def _get_record_events(self, file: File) -> List[Event]:
- logging.debug('Add file %s raw events to index', file.id)
- raw = Raw(file.id, join(self.path, file.path))
- events = raw.get_events()
- for event in events:
- event['id'] = str(uuid4())
- event['file_id'] = raw.id
- events = [Event(**event) for event in events]
- return events
-
- def _parallel_record_events(self, files: List[File]) -> List[Event]:
- pool = Pool(cpu_count())
- events = pool.map(self._get_record_events, files)
- pool.close()
- pool.join()
- events = [e for event in events for e in event]
- if self.exclude_events:
- events = [
- e for e in events
- if e.label not in self.exclude_events
- ]
- return events
-
- def index(self) -> None:
- logging.debug('Index files')
- files = self._get_paths()
- files = self._get_files(files)
- files = [
- file
- for file in files
- if not self.db.query(File).filter(File.id == file.id).all()
- ]
- for file in files:
- logging.debug('Add file %s raw to index', file.id)
- if self.include_extensions:
- raws = [
- file for file in files if file.extension in self.include_extensions]
- metadata = self._parallel_record_metadata(raws)
- events = self._parallel_record_events(raws)
- self.db.add_all(files + metadata + events)
- self.db.commit()
- logging.debug('Index files completed')
diff --git a/pyeeglab/io/loader.py b/pyeeglab/io/loader.py
deleted file mode 100644
index 73e732a..0000000
--- a/pyeeglab/io/loader.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import logging
-import json
-
-from abc import ABC
-from os.path import isfile, join, sep
-from hashlib import md5
-from multiprocessing import Pool, cpu_count
-
-from ..database import File, Metadata, Event
-from .raw import Raw
-from .index import Index
-
-from typing import List, Dict
-
-
-class DataLoader(ABC):
-
- index: Index
-
- def __init__(self, path: str, exclude_channels_reference: List[str] = None, exclude_frequency: List[int] = None,
- exclude_files: List[str] = None, minimum_event_duration: float = -1) -> None:
- logging.debug('Create data loader')
- if path[-1] != sep:
- path = path + sep
- self.path = path
- self.exclude_channels_reference = exclude_channels_reference
- self.exclude_frequency = exclude_frequency
- self.exclude_files = exclude_files
- self.minimum_event_duration = minimum_event_duration
-
- def __getstate__(self):
- # Workaround for unpickable sqlalchemy.orm.session
- # during multiprocess dataset loading
- state = self.__dict__.copy()
- del state['index']
- return state
-
- def _get_data_by_event(self, f: File, e: Event) -> Raw:
- path_edf = join(self.path, f.path)
- path_fif = path_edf + '_' + e.id + '_raw.fif.gz'
- if not isfile(path_fif):
- edf = Raw(f.id, path_edf, e.label)
- edf.crop(e.begin, e.end-e.begin)
- edf.open().save(path_fif)
- fif = Raw(f.id, path_fif, e.label)
- return fif
-
- def set_minimum_event_duration(self, minimum_event_duration: float) -> 'DataLoader':
- self.minimum_event_duration = minimum_event_duration
-
- def get_dataset(self) -> List[Raw]:
- files = self.index.db.query(File, Metadata, Event)
- files = files.filter(File.id == Metadata.file_id)
- files = files.filter(File.id == Event.file_id)
- if self.index.include_extensions:
- files = files.filter(File.extension.in_(self.index.include_extensions))
- if self.exclude_channels_reference:
- files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference))
- if self.exclude_frequency:
- files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency))
- if self.exclude_files:
- files = files.filter(~File.path.in_(self.exclude_files))
- if self.minimum_event_duration > 0:
- files = files.filter(Event.duration >= self.minimum_event_duration)
- files = files.all()
- files = [(file[0], file[2]) for file in files]
- pool = Pool(cpu_count())
- fifs = pool.starmap(self._get_data_by_event, files)
- pool.close()
- pool.join()
- return fifs
-
- def get_dataset_text(self) -> Dict:
- txts = self.index.db.query(File, Event)
- txts = txts.filter(File.id == Event.file_id)
- txts = txts.filter(File.extension == 'txt')
- txts = txts.all()
- txts = {f.id: (join(self.index.path, f.path), e.label) for f, e in txts}
- return txts
-
- def get_channels_set(self) -> List[str]:
- files = self.index.db.query(File, Metadata)
- files = files.filter(File.id == Metadata.file_id)
- if self.exclude_channels_reference:
- files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference))
- if self.exclude_frequency:
- files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency))
- if self.exclude_files:
- files = files.filter(~File.path.in_(self.exclude_files))
- if self.minimum_event_duration > 0:
- files = files.filter(Event.duration >= self.minimum_event_duration)
- files = files.group_by(Metadata.channels_set)
- files = files.all()
- files = [file[1] for file in files]
- files = [set(json.loads(file.channels_set)) for file in files]
- channels_set = files[0]
- for file in files[1:]:
- channels_set = channels_set.intersection(file)
- return sorted(channels_set)
-
- def get_lowest_frequency(self) -> float:
- frequency = self.index.db.query(Metadata)
- if self.exclude_frequency:
- frequency = frequency.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency))
- frequency = frequency.all()
- frequency = min([f.sampling_frequency for f in frequency], default=0)
- return frequency
-
- def get_max_value(self) -> float:
- files = self.index.db.query(File, Metadata)
- files = files.filter(File.id == Metadata.file_id)
- if self.exclude_channels_reference:
- files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference))
- if self.exclude_frequency:
- files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency))
- if self.exclude_files:
- files = files.filter(~File.path.in_(self.exclude_files))
- if self.minimum_event_duration > 0:
- files = files.filter(Event.duration >= self.minimum_event_duration)
- files = files.group_by(Metadata.channels_set)
- files = files.all()
- max_value = max([f.max_value for _, f in files], default=0)
- return max_value
-
- def get_min_value(self) -> float:
- files = self.index.db.query(File, Metadata)
- files = files.filter(File.id == Metadata.file_id)
- if self.exclude_channels_reference:
- files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference))
- if self.exclude_frequency:
- files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency))
- if self.exclude_files:
- files = files.filter(~File.path.in_(self.exclude_files))
- if self.minimum_event_duration > 0:
- files = files.filter(Event.duration >= self.minimum_event_duration)
- files = files.group_by(Metadata.channels_set)
- files = files.all()
- min_value = min([f.min_value for _, f in files], default=0)
- return min_value
-
- def __eq__(self, other):
- return hash(self) == hash(other)
-
- def __hash__(self):
- value = [self.path] + [self.minimum_event_duration]
- if self.exclude_channels_reference:
- value += self.exclude_channels_reference
- if self.exclude_frequency:
- value += self.exclude_frequency
- value = json.dumps(value).encode()
- value = md5(value).hexdigest()
- value = int(value, 16)
- return value
diff --git a/pyeeglab/io/raw.py b/pyeeglab/io/raw.py
deleted file mode 100644
index 75a41ff..0000000
--- a/pyeeglab/io/raw.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import logging
-from typing import List
-from importlib.util import find_spec
-
-from mne.io import Raw as Reader
-from mne.io import read_raw_edf, read_raw_fif
-
-class Raw():
-
- reader: Reader = None
-
- def __init__(self, fid: str, path: str, label: str = None) -> None:
- self.id = fid
- self.path = path
- self.label = label
- n_jobs = 1
- if find_spec('cupy') is not None:
- n_jobs = 'cuda'
- self.n_jobs = n_jobs
-
- def close(self) -> 'Raw':
- if self.reader is not None:
- logging.debug('Close Raw %s reader', self.id)
- self.reader.close()
- self.reader = None
- return self
-
- def crop(self, offset: int, length: int) -> 'Raw':
- logging.debug('Crop %s data to %s seconds from %s', self.id, length, offset)
- tmax = self.open().n_times / self.open().info['sfreq'] - 0.1
- if offset + length < tmax:
- tmax = offset + length
- self.reader = self.open().crop(offset, tmax)
- return self
-
- def open(self) -> Reader:
- if self.reader is None:
- if self.path.endswith('.edf'):
- logging.debug('Open RawEDF %s reader', self.id)
- try:
- self.reader = read_raw_edf(self.path)
- except RuntimeError:
- logging.debug('Preload RawEDF %s reader', self.id)
- self.reader = read_raw_edf(self.path, preload=True)
- if self.path.endswith('.fif.gz'):
- logging.debug('Open RawFIF %s reader', self.id)
- self.reader = read_raw_fif(self.path)
- return self.reader
-
- def get_events(self) -> List:
- events = self.open().annotations
- events = list(zip(events.onset, events.duration, events.description))
- events = [(event[0], event[0] + event[1], event[1], event[2]) for event in events]
- keys = ['begin', 'end', 'duration', 'label']
- events = [dict(zip(keys, event)) for event in events]
- return events
-
- def set_channels(self, channels: List[str]) -> 'Raw':
- channels = set(self.open().ch_names) - set(channels)
- channels = list(channels)
- if len(channels) > 0:
- logging.debug('Drop %s channels %s', self.id, '|'.join(channels))
- self.reader = self.open().drop_channels(channels)
- return self
-
- def set_frequency(self, frequency: float) -> 'Raw':
- sfreq = self.open().info['sfreq']
- if sfreq > frequency:
- logging.debug('Downsample %s from %s Hz to %s Hz', self.id, sfreq, frequency)
- self.reader = self.open().resample(frequency, n_jobs=self.n_jobs)
- return self
-
- def set_filter(self, low_freq: float = None, high_freq: float = None) -> 'Raw':
- if low_freq is not None or high_freq is not None:
- logging.debug('Filter %s with low_req: %s Hz and high_freq: %s Hz', self.id, low_freq, high_freq)
- self.reader = self.open().filter(low_freq, high_freq, n_jobs=self.n_jobs)
- return self
-
- def notch_filter(self, freq: float) -> 'Raw':
- if freq is not None:
- logging.debug('Notch filter %s with freq: %s Hz ', self.id, freq)
- self.reader = self.open().notch_filter(freq, n_jobs=self.n_jobs)
- return self
diff --git a/pyeeglab/pipeline/pipeline.py b/pyeeglab/pipeline/pipeline.py
index f778b2e..ec082ff 100644
--- a/pyeeglab/pipeline/pipeline.py
+++ b/pyeeglab/pipeline/pipeline.py
@@ -1,18 +1,17 @@
-import logging
import json
-
-import numpy as np
-import pandas as pd
+import logging
from hashlib import md5
from os.path import join
from multiprocessing import Pool, cpu_count
-from ..io import Raw
-from .preprocessor import Preprocessor
-
from typing import Dict, List
+import numpy as np
+import pandas as pd
+
+from .preprocessor import Preprocessor
+
class Pipeline():
@@ -32,11 +31,13 @@ def _check_nans(self, data):
nans = data.isnull().values.any()
return nans
- def _trigger_pipeline(self, data: Raw, kwargs):
- file_id = data.id
- data.open().load_data()
- for preprocessor in self.pipeline:
- data = preprocessor.run(data, **kwargs)
+ def _trigger_pipeline(self, annotation, kwargs):
+ data = None
+
+ with annotation as reader:
+ data = reader.load_data()
+ for preprocessor in self.pipeline:
+ data = preprocessor.run(data, **kwargs)
nans = False
if isinstance(data, list):
@@ -44,11 +45,11 @@ def _trigger_pipeline(self, data: Raw, kwargs):
else:
nans = self._check_nans(data)
if nans:
- raise ValueError('Nans found in file with id {}'.format(file_id))
+ raise ValueError('NaNs found in file with id {}'.format(annotation.file_uuid))
return data
- def run(self, data: List[Raw]) -> Dict:
+ def run(self, data) -> Dict:
logging.debug('Environment variables: {}'.format(
str(self.environment)
))
@@ -62,7 +63,7 @@ def run(self, data: List[Raw]) -> Dict:
labels = [self.labels_mapping[label] for label in labels]
onehot_encoder = sorted(set(labels))
class_id = self.environment.get('class_id', None)
- if class_id:
+ if class_id in onehot_encoder:
onehot_encoder.remove(class_id)
onehot_encoder = [class_id] + onehot_encoder
labels = np.array([onehot_encoder.index(label) for label in labels])
diff --git a/pyeeglab/preprocess/features/brain_connectivity.py b/pyeeglab/preprocess/features/brain_connectivity.py
index 96b96ac..ff3af4c 100644
--- a/pyeeglab/preprocess/features/brain_connectivity.py
+++ b/pyeeglab/preprocess/features/brain_connectivity.py
@@ -1,16 +1,15 @@
-import logging
import json
+import logging
+
+from itertools import product
+from typing import List
import numpy as np
import pandas as pd
-
from yasa import bandpower
-from itertools import product
from ...pipeline import Preprocessor
-from typing import List
-
class SpearmanCorrelation(Preprocessor):
diff --git a/pyeeglab/preprocess/features/stat_features.py b/pyeeglab/preprocess/features/stat_features.py
index 234ca03..5121fec 100644
--- a/pyeeglab/preprocess/features/stat_features.py
+++ b/pyeeglab/preprocess/features/stat_features.py
@@ -1,14 +1,11 @@
-import logging
+from typing import List
import numpy as np
import pandas as pd
-
from scipy.integrate import simps
from ...pipeline import Preprocessor
-from typing import List
-
class Mean(Preprocessor):
def run(self, data: List[pd.DataFrame], **kwargs) -> List[pd.DataFrame]:
diff --git a/pyeeglab/preprocess/signal/channel_selector.py b/pyeeglab/preprocess/signal/channel_selector.py
index 124ffaf..9f1fb4a 100644
--- a/pyeeglab/preprocess/signal/channel_selector.py
+++ b/pyeeglab/preprocess/signal/channel_selector.py
@@ -1,18 +1,18 @@
-import logging
import json
-
-from ...io import Raw
-from ...pipeline import Preprocessor
+import logging
from typing import List
+from mne.io import Raw
+
+from ...pipeline import Preprocessor
class CommonChannelSet(Preprocessor):
- def __init__(self, blacklist: List[str] = []) -> None:
+ def __init__(self, blacklist: List[str] = None) -> None:
super().__init__()
logging.debug('Create common channels_set preprocessor')
- self.blacklist = blacklist
+ self.blacklist = blacklist if blacklist else []
def to_json(self) -> str:
out = {
@@ -24,7 +24,9 @@ def to_json(self) -> str:
return out
def run(self, data: Raw, **kwargs) -> Raw:
- channels = set(kwargs['channels_set']) - set(self.blacklist)
- channels = list(channels)
- data.set_channels(channels)
+ channels = set(data.ch_names)
+ channels = channels.difference(set(kwargs['channels_set']))
+ data = data.drop_channels(channels)
+ data = data.drop_channels(self.blacklist)
+ data = data.reorder_channels(kwargs['channels_set'])
return data
diff --git a/pyeeglab/preprocess/signal/filter_selector.py b/pyeeglab/preprocess/signal/filter_selector.py
index 33c7b15..2007338 100644
--- a/pyeeglab/preprocess/signal/filter_selector.py
+++ b/pyeeglab/preprocess/signal/filter_selector.py
@@ -1,7 +1,8 @@
-import logging
import json
+import logging
+
+from mne.io import Raw
-from ...io import Raw
from ...pipeline import Preprocessor
@@ -24,8 +25,7 @@ def to_json(self) -> str:
return out
def run(self, data: Raw, **kwargs) -> Raw:
- data.set_filter(self.low_freq, self.high_freq)
- return data
+ return data.filter(self.low_freq, self.high_freq)
class NotchFrequency(Preprocessor):
@@ -45,5 +45,4 @@ def to_json(self) -> str:
return out
def run(self, data: Raw, **kwargs) -> Raw:
- data.notch_filter(self.freq)
- return data
+ return data.notch_filter(self.freq)
diff --git a/pyeeglab/preprocess/signal/frequency_selector.py b/pyeeglab/preprocess/signal/frequency_selector.py
index d3aca97..fc606c1 100644
--- a/pyeeglab/preprocess/signal/frequency_selector.py
+++ b/pyeeglab/preprocess/signal/frequency_selector.py
@@ -1,6 +1,7 @@
import logging
-from ...io import Raw
+from mne.io import Raw
+
from ...pipeline import Preprocessor
@@ -11,5 +12,7 @@ def __init__(self) -> None:
logging.debug('Create lowest_frequency preprocessor')
def run(self, data: Raw, **kwargs) -> Raw:
- data.set_frequency(kwargs['lowest_frequency'])
+ lowest_frequency = kwargs['lowest_frequency']
+ if data.info['sfreq'] > lowest_frequency:
+ data = data.resample(lowest_frequency)
return data
diff --git a/pyeeglab/preprocess/signal/normalization.py b/pyeeglab/preprocess/signal/normalization.py
index 7670ca2..11b07d4 100644
--- a/pyeeglab/preprocess/signal/normalization.py
+++ b/pyeeglab/preprocess/signal/normalization.py
@@ -1,12 +1,12 @@
import logging
+from typing import List
+
import numpy as np
import pandas as pd
from ...pipeline import Preprocessor
-from typing import List
-
class MinMaxNormalization(Preprocessor):
def run(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
diff --git a/pyeeglab/preprocess/transform/__init__.py b/pyeeglab/preprocess/transform/__init__.py
index ec636cc..d28947a 100644
--- a/pyeeglab/preprocess/transform/__init__.py
+++ b/pyeeglab/preprocess/transform/__init__.py
@@ -1,3 +1,2 @@
from .data_converter import ToDataframe, ToNumpy, ToMergedDataframes, CorrelationToAdjacency
from .frame_generator import StaticWindow, DynamicWindow, StaticWindowOverlap, DynamicWindowOverlap
-from .graph_generator import GraphGenerator, GraphWithFeatures
diff --git a/pyeeglab/preprocess/transform/data_converter.py b/pyeeglab/preprocess/transform/data_converter.py
index 3e2124f..06f747a 100644
--- a/pyeeglab/preprocess/transform/data_converter.py
+++ b/pyeeglab/preprocess/transform/data_converter.py
@@ -1,13 +1,14 @@
-import logging
import json
+import logging
+
+from typing import List
import numpy as np
import pandas as pd
-from ...io import Raw
-from ...pipeline import Preprocessor
+from mne.io import Raw
-from typing import List
+from ...pipeline import Preprocessor
class ToDataframe(Preprocessor):
@@ -17,8 +18,7 @@ def __init__(self) -> None:
logging.debug('Create DataFrame converter preprocessor')
def run(self, data: Raw, **kwargs) -> pd.DataFrame:
- dataframe = data.open().to_data_frame().drop('time', axis=1)
- data.close()
+ dataframe = data.to_data_frame().drop('time', axis=1)
return dataframe
diff --git a/pyeeglab/preprocess/transform/frame_generator.py b/pyeeglab/preprocess/transform/frame_generator.py
index 9c10a20..9fd9ee1 100644
--- a/pyeeglab/preprocess/transform/frame_generator.py
+++ b/pyeeglab/preprocess/transform/frame_generator.py
@@ -1,13 +1,12 @@
-import logging
import json
-
-import pandas as pd
+import logging
from math import floor
+from typing import List
-from ...pipeline import Preprocessor
+import pandas as pd
-from typing import List
+from ...pipeline import Preprocessor
class StaticWindow(Preprocessor):
diff --git a/pyeeglab/preprocess/transform/graph_generator.py b/pyeeglab/preprocess/transform/graph_generator.py
deleted file mode 100644
index 59339f7..0000000
--- a/pyeeglab/preprocess/transform/graph_generator.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import logging
-import json
-
-import pandas as pd
-import networkx as nx
-from networkx.convert_matrix import from_pandas_edgelist
-
-from ...pipeline import Preprocessor
-
-from typing import List
-
-
-class GraphGenerator(Preprocessor):
-
- def __init__(self) -> None:
- super().__init__()
- logging.debug('Create new graph generator preprocessor')
-
- def run(self, data: List[pd.DataFrame], **kwargs) -> List[nx.Graph]:
- nodes = [set(d.From.to_list() + d.To.to_list()) for d in data]
- edges = [d.where(d.Weight != 0).dropna().reset_index(drop=True) for d in data]
- graphs = []
- for i in range(len(data)):
- graph = from_pandas_edgelist(edges[i], 'From', 'To')
- graph.add_nodes_from(nodes[i])
- graphs.append(graph)
- return graphs
-
-
-class GraphWithFeatures(GraphGenerator):
-
- def __init__(self):
- super().__init__()
- logging.debug('Create new graph with features preprocessor')
-
- def _run(self, adjacency: List[pd.DataFrame], features: List[pd.DataFrame], **kwargs) -> List[nx.Graph]:
- graphs = super().run(adjacency, **kwargs)
- for i, graph in enumerate(graphs):
- feature = {
- node: {'features': features[i].loc[node, :].to_numpy()}
- for node in graph.nodes
- }
- nx.set_node_attributes(graph, feature)
- return graphs
-
- def run(self, data, **kwargs) -> List[nx.Graph]:
- return self._run(*data, **kwargs)
diff --git a/requirements.txt b/requirements.txt
index a02381a..b2628f1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
-mne==0.20.0
-networkx>=2.2
numpy
scipy
pandas
-sqlalchemy>=1.2
wfdb
+sqlalchemy
+dataclasses
+mne>=0.21.0
yasa>=0.1.6
diff --git a/setup.py b/setup.py
index a12c589..d6153c7 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
setuptools.setup(
name='PyEEGLab',
- version='0.9.4',
+ version='0.10.0',
author='Alessio Zanga',
author_email='alessio.zanga@outlook.it',
license='GNU GENERAL PUBLIC LICENSE - Version 3, 29 June 2007',
@@ -32,13 +32,13 @@
url='https://github.com/AlessioZanga/PyEEGLab',
packages=setuptools.find_packages(),
install_requires=[
- 'mne==0.20.0',
- 'networkx>=2.2',
'numpy',
'scipy',
'pandas',
- 'sqlalchemy>=1.2',
'wfdb',
+ 'sqlalchemy',
+ 'dataclasses',
+ 'mne>=0.21.0',
'yasa>=0.1.6',
],
)
diff --git a/tests/test_chbmit.py b/tests/test_chbmit.py
index b4b0f48..69ec621 100644
--- a/tests/test_chbmit.py
+++ b/tests/test_chbmit.py
@@ -3,38 +3,29 @@
import unittest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-from pyeeglab import CHBMITLoader, CHBMITDataset, \
- Pipeline, CommonChannelSet, LowestFrequency, BandPassFrequency, ToDataframe, \
- DynamicWindow, ForkedPreprocessor, BinarizedSpearmanCorrelation, \
- CorrelationToAdjacency, Bandpower, GraphWithFeatures
+from pyeeglab import *
class TestCHBMIT(unittest.TestCase):
- PATH = './tests/samples/physionet.org/files/chbmit/1.0.0'
+ PATH = './tests/samples/physionet.org/files/chbmit/'
def test_index(self):
- CHBMITLoader(self.PATH)
+ PhysioNetCHBMITDataset(self.PATH)
def test_loader(self):
- loader = CHBMITLoader(self.PATH)
- loader.get_dataset()
- loader.get_dataset_text()
- loader.get_channels_set()
- loader.get_lowest_frequency()
+ loader = PhysioNetCHBMITDataset(self.PATH)
+ loader.maximal_channels_subset
+ loader.lowest_frequency
+ loader.signal_min_max_range
def test_dataset(self):
- dataset = CHBMITDataset(self.PATH)
+ dataset = PhysioNetCHBMITDataset(self.PATH)
preprocessing = Pipeline([
CommonChannelSet(),
LowestFrequency(),
BandPassFrequency(0.1, 47),
ToDataframe(),
DynamicWindow(4),
- ForkedPreprocessor(
- inputs=[
- [BinarizedSpearmanCorrelation(), CorrelationToAdjacency()],
- Bandpower()
- ],
- output=GraphWithFeatures()
- )
+ Skewness(),
+ ToNumpy()
])
dataset = dataset.set_pipeline(preprocessing).load()
diff --git a/tests/test_eegmmidb.py b/tests/test_eegmmidb.py
index 8b29b66..4d9248e 100644
--- a/tests/test_eegmmidb.py
+++ b/tests/test_eegmmidb.py
@@ -3,38 +3,29 @@
import unittest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-from pyeeglab import EEGMMIDBLoader, EEGMMIDBDataset, \
- Pipeline, CommonChannelSet, LowestFrequency, BandPassFrequency, ToDataframe, \
- DynamicWindow, ForkedPreprocessor, BinarizedSpearmanCorrelation, \
- CorrelationToAdjacency, Bandpower, GraphWithFeatures
+from pyeeglab import *
class TestEEGMMIDB(unittest.TestCase):
- PATH = './tests/samples/physionet.org/files/eegmmidb/1.0.0'
+ PATH = './tests/samples/physionet.org/files/eegmmidb/'
def test_index(self):
- EEGMMIDBLoader(self.PATH)
+ PhysioNetEEGMMIDBDataset(self.PATH)
def test_loader(self):
- loader = EEGMMIDBLoader(self.PATH)
- loader.get_dataset()
- loader.get_dataset_text()
- loader.get_channels_set()
- loader.get_lowest_frequency()
+ loader = PhysioNetEEGMMIDBDataset(self.PATH)
+ loader.maximal_channels_subset
+ loader.lowest_frequency
+ loader.signal_min_max_range
def test_dataset(self):
- dataset = EEGMMIDBDataset(self.PATH)
+ dataset = PhysioNetEEGMMIDBDataset(self.PATH)
preprocessing = Pipeline([
CommonChannelSet(),
LowestFrequency(),
BandPassFrequency(0.1, 47),
ToDataframe(),
DynamicWindow(4),
- ForkedPreprocessor(
- inputs=[
- [BinarizedSpearmanCorrelation(), CorrelationToAdjacency()],
- Bandpower()
- ],
- output=GraphWithFeatures()
- )
+ Skewness(),
+ ToNumpy()
])
dataset = dataset.set_pipeline(preprocessing).load()
diff --git a/tests/test_tuh_eeg_abnormal.py b/tests/test_tuh_eeg_abnormal.py
index 159f72a..a9e8bcb 100644
--- a/tests/test_tuh_eeg_abnormal.py
+++ b/tests/test_tuh_eeg_abnormal.py
@@ -3,23 +3,19 @@
import unittest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-from pyeeglab import TUHEEGAbnormalLoader, TUHEEGAbnormalDataset, \
- Pipeline, CommonChannelSet, LowestFrequency, BandPassFrequency, ToDataframe, \
- DynamicWindow, ForkedPreprocessor, BinarizedSpearmanCorrelation, \
- CorrelationToAdjacency, Bandpower, GraphWithFeatures
+from pyeeglab import *
class TestTUHEEGAbnormal(unittest.TestCase):
- PATH = './tests/samples/tuh_eeg_abnormal/v2.0.0/edf'
+ PATH = './tests/samples/tuh_eeg_abnormal/'
def test_index(self):
- TUHEEGAbnormalLoader(self.PATH)
+ TUHEEGAbnormalDataset(self.PATH)
def test_loader(self):
- loader = TUHEEGAbnormalLoader(self.PATH)
- loader.get_dataset()
- loader.get_dataset_text()
- loader.get_channels_set()
- loader.get_lowest_frequency()
+ loader = TUHEEGAbnormalDataset(self.PATH)
+ loader.maximal_channels_subset
+ loader.lowest_frequency
+ loader.signal_min_max_range
def test_dataset(self):
dataset = TUHEEGAbnormalDataset(self.PATH)
@@ -29,12 +25,7 @@ def test_dataset(self):
BandPassFrequency(0.1, 47),
ToDataframe(),
DynamicWindow(4),
- ForkedPreprocessor(
- inputs=[
- [BinarizedSpearmanCorrelation(), CorrelationToAdjacency()],
- Bandpower()
- ],
- output=GraphWithFeatures()
- )
+ Skewness(),
+ ToNumpy()
])
dataset = dataset.set_pipeline(preprocessing).load()
diff --git a/tests/test_tuh_eeg_artifact.py b/tests/test_tuh_eeg_artifact.py
index b03ecc1..595ca91 100644
--- a/tests/test_tuh_eeg_artifact.py
+++ b/tests/test_tuh_eeg_artifact.py
@@ -3,40 +3,29 @@
import unittest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-from pyeeglab import TUHEEGArtifactLoader, TUHEEGArtifactDataset, \
- Pipeline, CommonChannelSet, LowestFrequency, BandPassFrequency, ToDataframe, \
- DynamicWindow, ForkedPreprocessor, BinarizedSpearmanCorrelation, \
- CorrelationToAdjacency, Bandpower, GraphWithFeatures
+from pyeeglab import *
class TestTUHEEGArtifact(unittest.TestCase):
- PATH = './tests/samples/tuh_eeg_artifact/v1.0.0/edf'
+ PATH = './tests/samples/tuh_eeg_artifact/'
def test_index(self):
- TUHEEGArtifactLoader(self.PATH)
+ TUHEEGArtifactDataset(self.PATH)
def test_loader(self):
- loader = TUHEEGArtifactLoader(self.PATH)
- loader.get_dataset()
- loader.get_dataset_text()
- loader.get_channels_set()
- loader.get_lowest_frequency()
+ loader = TUHEEGArtifactDataset(self.PATH)
+ loader.maximal_channels_subset
+ loader.lowest_frequency
+ loader.signal_min_max_range
def test_dataset(self):
dataset = TUHEEGArtifactDataset(self.PATH)
- """
preprocessing = Pipeline([
- CommonChannelSet(['EEG T1-REF', 'EEG T2-REF']),
+ CommonChannelSet(),
LowestFrequency(),
BandPassFrequency(0.1, 47),
ToDataframe(),
DynamicWindow(4),
- ForkedPreprocessor(
- inputs=[
- [BinarizedSpearmanCorrelation(), CorrelationToAdjacency()],
- Bandpower()
- ],
- output=GraphWithFeatures()
- )
+ Skewness(),
+ ToNumpy()
])
dataset = dataset.set_pipeline(preprocessing).load()
- """