diff --git a/.travis.yml b/.travis.yml index 43400cb..1f9c0d4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,6 +2,7 @@ language: python cache: pip jobs: allow_failures: + - os: windows - os: osx include: - name: "Python 3.6.0 on Xenial Linux" diff --git a/CHANGELOG.md b/CHANGELOG.md index cfbe08d..de57125 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.10.0] - 2020-09-28 +### Added +* Added new dataset versioning system, + you can work with multiple versions + of the same dataset during creation +* Added new in-Python download system, + you can download a specific version + of the dataset during initialization + +### Changed +* Refactored dataset indexing system, + a dedicated directory will be created + for the index database during initialization +* Refactored preprocessing cache system, + a dedicated directory will be created + and a faster version for cache coherence + check has been implemented +* Refactored logging system +* Update package requirements +* Update README + +### Removed +* Deprecated Makefile + + ## [0.9.3] - 2020-07-05 ### Change diff --git a/Makefile b/Makefile deleted file mode 100755 index e9c93da..0000000 --- a/Makefile +++ /dev/null @@ -1,21 +0,0 @@ -.PHONY: tuh_eeg_abnormal tuh_eeg_artifact eegmmidb clean - -tuh_eeg_abnormal: - echo "Request your access password at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php" - rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_abnormal/ data/tuh_eeg_abnormal - -tuh_eeg_artifact: - echo "Request your access password at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php" - rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_artifact/ data/tuh_eeg_artifact - -tuh_eeg_seizure: - echo "Request your access password at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php" - rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_seizure/v1.5.2 data/tuh_eeg_seizure - -eegmmidb: - wget -r -N -c -np https://physionet.org/files/eegmmidb/1.0.0/ -P data - -clean: - find -L data -iname "*.db" -type f -delete - find -L data -iname "*.log" -type f -delete - find -L data -iname "*.fif.gz" -type f -delete diff --git a/README.md b/README.md index 21d59e3..205ed9f 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Here is a simple quickstart: from pyeeglab import * dataset = TUHEEGAbnormalDataset() - pipeline = Pipeline([ + preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), ToDataframe(), @@ -35,10 +35,6 @@ PyEEGLab is distributed using the pip repository: pip install PyEEGLab -If you use Python 3.6, the dataclasses package must be installed as backport of Python 3.7 dataclasses: - - pip install dataclasses - If you need a bleeding edge version, you can install it directly from GitHub: pip install git+https://github.com/AlessioZanga/PyEEGLab@develop @@ -57,6 +53,8 @@ The following datasets will work upon downloading: ## How to Class Meaning - From the TUH Seizure docs +
+ | **Class Code** | **Event Name** | **Description** | | -------------- | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------ | | _NULL_ | No Event | An unclassified event | @@ -91,15 +89,24 @@ The following datasets will work upon downloading: | _TRIP_ | Triphasic Wave | Large, three-phase waves frequently caused by an underlying metabolic condition. | | _ELEC_ | Electrode Artifact | Electrode pop, Electrostatic artifacts, Lead artifacts. | +
+ ## How to Get a Dataset > **WARNING**: Retriving the TUH EEG datasets require valid credentials, you can get your own at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php. -In the root directory of this project there is a Makefile, by typing: +Given the dataset instance, trigger the download using the "download" method: + + from pyeeglab import * + dataset = TUHEEGAbnormalDataset() + dataset.download(user='USER', password='PASSWORD') + dataset.index() + +then index the new downloaded files. - make tuh_eeg_abnormal +It should be noted that the download mechanism work on Unix-like systems given the following packages: -you will trigger the dataset download. + sudo apt install sshpass rsync wget ## Documentation diff --git a/examples/tensorboard/example_tensorboard.py b/examples/tensorboard/example_tensorboard.py index 448ac8e..de8ae71 100755 --- a/examples/tensorboard/example_tensorboard.py +++ b/examples/tensorboard/example_tensorboard.py @@ -26,7 +26,6 @@ import numpy as np from random import shuffle from itertools import product -from networkx import to_numpy_matrix from sklearn.model_selection import train_test_split from tensorflow.python.keras.utils.np_utils import to_categorical @@ -38,8 +37,6 @@ from pyeeglab import * def build_data(dataset): - dataset.set_cache_manager(PickleCache('../../export')) - preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), @@ -96,7 +93,7 @@ def get_correlation_matrix(x, frame, N, F): input_0 = tf.keras.Input((frames, N, F + N)) - gans = [] + layers = [] for frame in range(frames): feature_matrix = tf.keras.layers.Lambda( get_feature_matrix, @@ -110,9 +107,9 @@ def get_correlation_matrix(x, frame, N, F): x = sp.layers.GraphAttention(hparams['output_shape'])([feature_matrix, correlation_matrix]) x = tf.keras.layers.Flatten()(x) - gans.append(x) + layers.append(x) - combine = tf.keras.layers.Concatenate()(gans) + combine = tf.keras.layers.Concatenate()(layers) reshape = tf.keras.layers.Reshape((frames, N * hparams['output_shape']))(combine) lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape) dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm) @@ -125,20 +122,25 @@ def get_correlation_matrix(x, frame, N, F): metrics=[ 'accuracy', Recall(class_id=0, name='recall'), + Specificity(class_id=0, name='specificity'), Precision(class_id=0, name='precision'), + F1Score(class_id=0, name='f1score'), ] ) model.summary() + model.save('logs/plot_gat.h5') return model def run_trial(path, step, model, hparams, x_train, y_train, x_val, y_val, x_test, y_test, epochs): with tf.summary.create_file_writer(path).as_default(): hp.hparams(hparams) model.fit(x_train, y_train, epochs=epochs, batch_size=32, shuffle=True, validation_data=(x_val, y_val)) - loss, accuracy, recall, precision = model.evaluate(x_test, y_test) + loss, accuracy, recall, specificity, precision, f1score = model.evaluate(x_test, y_test) tf.summary.scalar('accuracy', accuracy, step=step) tf.summary.scalar('recall', recall, step=step) + tf.summary.scalar('specificity', specificity, step=step) tf.summary.scalar('precision', precision, step=step) + tf.summary.scalar('f1score', f1score, step=step) def hparams_combinations(hparams): hp.hparams_config( @@ -146,7 +148,9 @@ def hparams_combinations(hparams): metrics=[ hp.Metric('accuracy', display_name='Accuracy'), hp.Metric('recall', display_name='Recall'), + hp.Metric('specificity', display_name='Specificity'), hp.Metric('precision', display_name='Precision'), + hp.Metric('f1score', display_name='F1Score'), ] ) hparams_keys = list(hparams.keys()) @@ -162,7 +166,7 @@ def hparams_combinations(hparams): return hparams def tune_model(dataset_name, data): - LOGS_DIR = join('./logs/generic', dataset_name) + LOGS_DIR = join('./logs/gat', dataset_name) os.makedirs(LOGS_DIR, exist_ok=True) # Prepare the data x_train, y_train, x_val, y_val, x_test, y_test = adapt_data(data) @@ -206,18 +210,18 @@ def tune_model(dataset_name, data): if __name__ == '__main__': dataset = {} - dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') + dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') - dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf') + dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) - dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf') + dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) - # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0') + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') # dataset['eegmmidb'].set_minimum_event_duration(4) - dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0') + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') dataset['chbmit'].set_minimum_event_duration(4) """ diff --git a/examples/tensorboard/example_tensorboard_cnn.py b/examples/tensorboard/example_tensorboard_cnn.py new file mode 100755 index 0000000..b345afc --- /dev/null +++ b/examples/tensorboard/example_tensorboard_cnn.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python + +# Ignore MNE and TensorFlow warnings +import warnings +warnings.simplefilter(action='ignore') + +# Import TensorFlow with GPU memory settings +import tensorflow as tf +gpus = tf.config.experimental.list_physical_devices('GPU') +try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) +except RuntimeError as e: + print(e) + +# Import TensorBoard params and metrics +from tensorboard.plugins.hparams import api as hp +from tensorflow.keras.metrics import CategoricalAccuracy, Precision, Recall + +# Import Spektral for GraphAttention +import spektral as sp + +# Others imports +import os +import pickle +import numpy as np +from random import shuffle +from itertools import product +from scipy.sparse import csc_matrix +from spektral.layers.ops import sp_matrix_to_sp_tensor +from sklearn.model_selection import train_test_split +from tensorflow.python.keras.utils.np_utils import to_categorical + +# Relative import PyEEGLab +import sys +from os.path import abspath, dirname, join + +sys.path.insert(0, abspath(join(dirname(__file__), '../..'))) +from pyeeglab import * + +def build_data(dataset): + preprocessing = Pipeline([ + CommonChannelSet(), + LowestFrequency(), + ToDataframe(), + MinMaxCentralizedNormalization(), + DynamicWindow(8), + ForkedPreprocessor( + inputs=[ + SpearmanCorrelation(), + Mean(), + Variance(), + Skewness(), + Kurtosis(), + ZeroCrossing(), + AbsoluteArea(), + PeakToPeak(), + Bandpower(['Delta', 'Theta', 'Alpha', 'Beta']) + ], + output=ToMergedDataframes() + ), + ToNumpy() + ]) + + return dataset.set_pipeline(preprocessing).load() + +def adapt_data(data, test_size=0.1, shuffle=True): + if isinstance(data, str): + with open(data, 'rb') as f: + data = pickle.load(f) + samples, labels = data['data'], data['labels'] + x_train, x_test, y_train, y_test = train_test_split(samples, labels, test_size=test_size, shuffle=shuffle, stratify=labels) + x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=test_size, shuffle=shuffle, stratify=y_train) + classes = np.sort(np.unique(labels)) + y_train = to_categorical(y_train, num_classes=len(classes)) + y_test = to_categorical(y_test, num_classes=len(classes)) + y_val = to_categorical(y_val, num_classes=len(classes)) + return x_train, y_train, x_val, y_val, x_test, y_test + +def build_model(shape, classes, hparams): + print(hparams) + N = shape[2] + F = shape[3] - N + frames = shape[1] + + def get_frame(x, frame, N, F): + x = tf.slice(x, [0, frame, 0, 0], [-1, 1, N, F]) + x = tf.squeeze(x, axis=[1]) + return x + + input_0 = tf.keras.Input((frames, N, F + N)) + + layers = [] + for frame in range(frames): + frame_matrix = tf.keras.layers.Lambda( + get_frame, + arguments={'frame': frame, 'N': N, 'F': F} + )(input_0) + + x = tf.keras.layers.Conv1D(hparams['filters'], hparams['kernel'], data_format='channels_first')(frame_matrix) + x = tf.keras.layers.MaxPooling1D(hparams['pool'])(x) + x = tf.keras.layers.Flatten()(x) + layers.append(x) + + combine = tf.keras.layers.Concatenate()(layers) + reshape = tf.keras.layers.Reshape((-1, frames))(combine) + lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape) + dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm) + out = tf.keras.layers.Dense(classes, activation='softmax')(dropout) + + model = tf.keras.Model(inputs=[input_0], outputs=out) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=hparams['learning_rate']), + loss='categorical_crossentropy', + metrics=[ + 'accuracy', + Recall(class_id=0, name='recall'), + Precision(class_id=0, name='precision'), + ] + ) + model.summary() + model.save('logs/plot_cnn.h5') + return model + +def run_trial(path, step, model, hparams, x_train, y_train, x_val, y_val, x_test, y_test, epochs): + with tf.summary.create_file_writer(path).as_default(): + hp.hparams(hparams) + model.fit(x_train, y_train, epochs=epochs, batch_size=32, shuffle=True, validation_data=(x_val, y_val)) + loss, accuracy, recall, precision = model.evaluate(x_test, y_test) + tf.summary.scalar('accuracy', accuracy, step=step) + tf.summary.scalar('recall', recall, step=step) + tf.summary.scalar('precision', precision, step=step) + +def hparams_combinations(hparams): + hp.hparams_config( + hparams=list(hparams.values()), + metrics=[ + hp.Metric('accuracy', display_name='Accuracy'), + hp.Metric('recall', display_name='Recall'), + hp.Metric('precision', display_name='Precision'), + ] + ) + hparams_keys = list(hparams.keys()) + hparams_values = list(product(*[ + h.domain.values + for h in hparams.values() + ])) + hparams = [ + dict(zip(hparams_keys, values)) + for values in hparams_values + ] + shuffle(hparams) + return hparams + +def tune_model(dataset_name, data): + LOGS_DIR = join('./logs/cnn', dataset_name) + os.makedirs(LOGS_DIR, exist_ok=True) + # Prepare the data + x_train, y_train, x_val, y_val, x_test, y_test = adapt_data(data) + # Set tuning session + counter = 0 + # Parameters to be tuned + hparams = { + 'learning_rate': [1e-4], + 'hidden_units': [8, 16, 32, 64], + 'filters': [8, 16, 32], + 'kernel': [3, 5, 7], + 'pool': [2, 3], + 'dropout': [0.00, 0.05, 0.10, 0.15, 0.20], + } + hparams = { + key: hp.HParam(key, hp.Discrete(value)) + for key, value in hparams.items() + } + hparams = hparams_combinations(hparams) + for hparam in hparams: + # Build the model + model = build_model(data['data'].shape, len(data['labels_encoder']), hparam) + # Run session + run_name = f'run-{counter}' + print(f'--- Starting trial: {run_name}') + print(hparam) + run_trial( + join(LOGS_DIR, run_name), + counter, + model, + hparam, + x_train, + y_train, + x_val, + y_val, + x_test, + y_test, + epochs=50 + ) + counter += 1 + + +if __name__ == '__main__': + dataset = {} + + # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') + + # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') + # dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) + + dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') + dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) + + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') + # dataset['eegmmidb'].set_minimum_event_duration(4) + + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') + dataset['chbmit'].set_minimum_event_duration(4) + + """ + Note: You can just use paths as values in the dictionary + and comment-out the first line of the following for cycle ;) + """ + + for key, value in dataset.items(): + value = build_data(value) + tune_model(key, value) diff --git a/examples/tensorboard/example_tensorboard_gat.py b/examples/tensorboard/example_tensorboard_gat.py index cb839a9..f74c1dc 100755 --- a/examples/tensorboard/example_tensorboard_gat.py +++ b/examples/tensorboard/example_tensorboard_gat.py @@ -26,7 +26,6 @@ import numpy as np from random import shuffle from itertools import product -from networkx import to_numpy_matrix from sklearn.model_selection import train_test_split from tensorflow.python.keras.utils.np_utils import to_categorical @@ -38,8 +37,6 @@ from pyeeglab import * def build_data(dataset): - dataset.set_cache_manager(PickleCache('../../export')) - preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), @@ -96,7 +93,7 @@ def get_correlation_matrix(x, frame, N, F): input_0 = tf.keras.Input((frames, N, F + N)) - gans = [] + layers = [] for frame in range(frames): feature_matrix = tf.keras.layers.Lambda( get_feature_matrix, @@ -110,9 +107,9 @@ def get_correlation_matrix(x, frame, N, F): x = sp.layers.GraphAttention(hparams['output_shape'])([feature_matrix, correlation_matrix]) x = tf.keras.layers.Flatten()(x) - gans.append(x) + layers.append(x) - combine = tf.keras.layers.Concatenate()(gans) + combine = tf.keras.layers.Concatenate()(layers) reshape = tf.keras.layers.Reshape((frames, N * hparams['output_shape']))(combine) lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape) dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm) @@ -129,6 +126,7 @@ def get_correlation_matrix(x, frame, N, F): ] ) model.summary() + model.save('logs/plot_gat.h5') return model def run_trial(path, step, model, hparams, x_train, y_train, x_val, y_val, x_test, y_test, epochs): @@ -162,7 +160,7 @@ def hparams_combinations(hparams): return hparams def tune_model(dataset_name, data): - LOGS_DIR = join('./logs/gan', dataset_name) + LOGS_DIR = join('./logs/gat', dataset_name) os.makedirs(LOGS_DIR, exist_ok=True) # Prepare the data x_train, y_train, x_val, y_val, x_test, y_test = adapt_data(data) @@ -206,18 +204,18 @@ def tune_model(dataset_name, data): if __name__ == '__main__': dataset = {} - dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') + dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') - dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf') + dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) - dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf') + dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) - # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0') + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') # dataset['eegmmidb'].set_minimum_event_duration(4) - dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0') + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') dataset['chbmit'].set_minimum_event_duration(4) """ diff --git a/examples/tensorboard/example_tensorboard_gnn.py b/examples/tensorboard/example_tensorboard_gnn.py index 551e69c..40c9977 100755 --- a/examples/tensorboard/example_tensorboard_gnn.py +++ b/examples/tensorboard/example_tensorboard_gnn.py @@ -39,8 +39,6 @@ from pyeeglab import * def build_data(dataset): - dataset.set_cache_manager(PickleCache('../../export')) - preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), @@ -97,7 +95,7 @@ def get_correlation_matrix(x, frame, N, F): input_0 = tf.keras.Input((frames, N, F + N)) - gans = [] + layers = [] for frame in range(frames): feature_matrix = tf.keras.layers.Lambda( get_feature_matrix, @@ -111,9 +109,9 @@ def get_correlation_matrix(x, frame, N, F): x = sp.layers.GraphConv(hparams['output_shape'])([feature_matrix, correlation_matrix]) x = tf.keras.layers.Flatten()(x) - gans.append(x) + layers.append(x) - combine = tf.keras.layers.Concatenate()(gans) + combine = tf.keras.layers.Concatenate()(layers) reshape = tf.keras.layers.Reshape((frames, N * hparams['output_shape']))(combine) lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape) dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm) @@ -130,6 +128,7 @@ def get_correlation_matrix(x, frame, N, F): ] ) model.summary() + model.save('logs/plot_gnn.h5') return model def run_trial(path, step, model, hparams, x_train, y_train, x_val, y_val, x_test, y_test, epochs): @@ -207,18 +206,18 @@ def tune_model(dataset_name, data): if __name__ == '__main__': dataset = {} - # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') + # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') - # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf') + # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') # dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) - dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf') + dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) - # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0') + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') # dataset['eegmmidb'].set_minimum_event_duration(4) - dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0') + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') dataset['chbmit'].set_minimum_event_duration(4) """ diff --git a/examples/tensorboard/example_validation.py b/examples/tensorboard/example_validation.py new file mode 100755 index 0000000..cc906f7 --- /dev/null +++ b/examples/tensorboard/example_validation.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python + +# Ignore MNE and TensorFlow warnings +import warnings +warnings.simplefilter(action='ignore') + +# Import TensorFlow with GPU memory settings +import tensorflow as tf +gpus = tf.config.experimental.list_physical_devices('GPU') +try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) +except RuntimeError as e: + print(e) + +# Import TensorBoard params and metrics +from tensorboard.plugins.hparams import api as hp +from tensorflow.keras.metrics import Accuracy, Precision, Recall + +# Import Spektral for GraphAttention +import spektral as sp + +# Others imports +import os +import pickle +import numpy as np +import pandas as pd +from random import shuffle +from itertools import product +from sklearn.model_selection import train_test_split, StratifiedKFold +from tensorflow.python.keras.utils.np_utils import to_categorical + +# Relative import PyEEGLab +import sys +from os.path import abspath, dirname, join + +sys.path.insert(0, abspath(join(dirname(__file__), '../..'))) +from pyeeglab import * + +def build_data(dataset): + preprocessing = Pipeline([ + CommonChannelSet(), + LowestFrequency(), + ToDataframe(), + MinMaxCentralizedNormalization(), + DynamicWindow(8), + ForkedPreprocessor( + inputs=[ + SpearmanCorrelation(), + Mean(), + Variance(), + Skewness(), + Kurtosis(), + ZeroCrossing(), + AbsoluteArea(), + PeakToPeak(), + Bandpower(['Delta', 'Theta', 'Alpha', 'Beta']) + ], + output=ToMergedDataframes() + ), + ToNumpy() + ]) + + return dataset.set_pipeline(preprocessing).load() + +def adapt_data(data, test_size=0.1): + if isinstance(data, str): + with open(data, 'rb') as f: + data = pickle.load(f) + samples, labels = data['data'], data['labels'] + x_train, x_test, y_train, y_test = train_test_split(samples, labels, test_size=test_size, shuffle=True, stratify=labels) + x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=test_size, shuffle=True, stratify=y_train) + classes = np.sort(np.unique(labels)) + y_train = to_categorical(y_train, num_classes=len(classes)) + y_test = to_categorical(y_test, num_classes=len(classes)) + y_val = to_categorical(y_val, num_classes=len(classes)) + return x_train, y_train, x_val, y_val, x_test, y_test + +def stratified_k_cross_validation_split(data, folds): + if isinstance(data, str): + with open(data, 'rb') as f: + data = pickle.load(f) + samples, labels = data['data'], data['labels'] + skf = StratifiedKFold(n_splits=folds, shuffle=True) + return samples, labels, skf.split(samples, labels) + +def build_model(shape, classes, hparams, metrics): + print(hparams) + N = shape[2] + F = shape[3] - N + frames = shape[1] + + def get_frame(x, frame, N, F): + x = tf.slice(x, [0, frame, 0, 0], [-1, 1, N, F]) + x = tf.squeeze(x, axis=[1]) + return x + + input_0 = tf.keras.Input((frames, N, F + N)) + + layers = [] + for frame in range(frames): + frame_matrix = tf.keras.layers.Lambda( + get_frame, + arguments={'frame': frame, 'N': N, 'F': F} + )(input_0) + + x = tf.keras.layers.Conv1D(hparams['filters'], hparams['kernel'], data_format='channels_first')(frame_matrix) + x = tf.keras.layers.MaxPooling1D(hparams['pool'])(x) + x = tf.keras.layers.Flatten()(x) + layers.append(x) + + combine = tf.keras.layers.Concatenate()(layers) + reshape = tf.keras.layers.Reshape((-1, frames))(combine) + lstm = tf.keras.layers.LSTM(hparams['hidden_units'])(reshape) + dropout = tf.keras.layers.Dropout(hparams['dropout'])(lstm) + out = tf.keras.layers.Dense(classes, activation='softmax')(dropout) + + model = tf.keras.Model(inputs=[input_0], outputs=out) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=hparams['learning_rate']), + loss='categorical_crossentropy', + metrics=[ + 'accuracy', + Recall(class_id=0, name='recall'), + Specificity(class_id=0, name='specificity'), + Precision(class_id=0, name='precision'), + F1Score(class_id=0, name='f1score'), + ] + ) + model.summary() + model.save('logs/plot_cnn.h5') + return model + +def run_trial(path, step, model, hparams, metrics, x_train, y_train, x_test, y_test, epochs): + with tf.summary.create_file_writer(path).as_default(): + hp.hparams(hparams) + model.fit(x_train, y_train, epochs=epochs, batch_size=32, shuffle=True) + metrics_output = model.evaluate(x_test, y_test) + metrics_output = dict(zip(model.metrics_names, metrics_output)) + for metric_name, metric_value in metrics_output.items(): + if metric_name in metrics: + tf.summary.scalar(metric_name, metric_value, step=step) + return metrics_output + +def hparams_combinations(hparams, metrics): + hp.hparams_config( + hparams=list(hparams.values()), + metrics=[ + hp.Metric(metric, display_name=metric.capitalize()) + for metric in metrics + ] + ) + hparams_keys = list(hparams.keys()) + hparams_values = list(product(*[ + h.domain.values + for h in hparams.values() + ])) + hparams = [ + dict(zip(hparams_keys, values)) + for values in hparams_values + ] + shuffle(hparams) + return hparams + +def tune_model(dataset_name, data, metrics, folds, val_size=0.1): + LOGS_DIR = join('./logs/cnn', dataset_name) + os.makedirs(LOGS_DIR, exist_ok=True) + # Prepare the data + data, labels, splitter = stratified_k_cross_validation_split(data, folds) + # Set tuning session and results + counter = 0 + results = [] + # Parameters to be tuned + hparam = { + 'learning_rate': 1e-4, + 'kernel': 5, + 'filters': 32, + 'pool': 5, + 'hidden_units': 64, + 'dropout': 0.10, + } + for train_idx, test_idx in splitter: + # Generate folds + x_train, y_train = data[train_idx], labels[train_idx] + x_test, y_test = data[test_idx], labels[test_idx] + classes = np.sort(np.unique(labels)) + y_train = to_categorical(y_train, num_classes=len(classes)) + y_test = to_categorical(y_test, num_classes=len(classes)) + # Build the model + model = build_model(data.shape, len(classes), hparam, metrics) + # Run session + run_name = f'run-{counter}' + print(f'--- Starting trial: {run_name}') + print(hparam) + result = run_trial( + join(LOGS_DIR, run_name), + counter, + model, + hparam, + metrics, + x_train, + y_train, + x_test, + y_test, + epochs=50 + ) + counter += 1 + results.append(result) + results = pd.DataFrame(results) + hparam = '-'.join([ + str(key) + '_' + str(value) + for key, value in hparam.items() + ]) + results.to_csv(join(LOGS_DIR, "validation-{}-{}-results.csv".format(dataset_name, hparam))) + print(results) + + +if __name__ == '__main__': + dataset = {} + + # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') + + # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') + # dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) + + # dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') + # dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) + + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') + # dataset['eegmmidb'].set_minimum_event_duration(4) + + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') + dataset['chbmit'].set_minimum_event_duration(4) + + """ + Note: You can just use paths as values in the dictionary + and comment-out the first line of the following for cycle ;) + """ + + for key, value in dataset.items(): + metrics = [ + 'accuracy', + 'recall', + 'specificity', + 'precision', + 'f1score', + ] + value = build_data(value) + tune_model(key, value, metrics, folds=10) diff --git a/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py b/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py index 392d04f..f813d4f 100755 --- a/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py +++ b/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py @@ -21,12 +21,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \ ToNumpy -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py index 3455490..59674b2 100755 --- a/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py +++ b/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py @@ -21,12 +21,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \ BinarizedSpearmanCorrelation, ToNumpy -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py index 38407f1..e553bd3 100755 --- a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py +++ b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py @@ -21,12 +21,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \ ToNumpy -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py index ea7dab1..faee56a 100755 --- a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py +++ b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py @@ -21,12 +21,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \ BinarizedSpearmanCorrelation, ToNumpy -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py b/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py index 3f85e02..cb6e7be 100755 --- a/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py +++ b/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py @@ -23,12 +23,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \ CorrelationToAdjacency, Bandpower, GraphWithFeatures, ForkedPreprocessor -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py index 5bd0c3f..859ff78 100755 --- a/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py +++ b/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py @@ -23,13 +23,12 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \ BinarizedSpearmanCorrelation, CorrelationToAdjacency, Bandpower, \ GraphWithFeatures, ForkedPreprocessor -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/pyeeglab/__init__.py b/pyeeglab/__init__.py index cce3be6..df80b73 100644 --- a/pyeeglab/__init__.py +++ b/pyeeglab/__init__.py @@ -1,15 +1,11 @@ import logging import warnings -from importlib.util import find_spec -from mne.utils import set_config +logging.basicConfig(format="%(asctime)s %(levelname)7s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S") from .dataset import * -from .io import Raw -from .cache import PickleCache -from .pipeline import Pipeline, ForkedPreprocessor +from .pipeline import * from .preprocess import * logging.getLogger().setLevel(logging.DEBUG) - warnings.filterwarnings("ignore", category=RuntimeWarning) diff --git a/pyeeglab/cache/__init__.py b/pyeeglab/cache/__init__.py deleted file mode 100644 index eb3bba2..0000000 --- a/pyeeglab/cache/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .cache import Cache, PickleCache diff --git a/pyeeglab/cache/cache.py b/pyeeglab/cache/cache.py deleted file mode 100644 index 202b9bd..0000000 --- a/pyeeglab/cache/cache.py +++ /dev/null @@ -1,65 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod -from typing import List, Dict - -from os.path import isfile, join -from pathlib import Path -from hashlib import md5 -from pickle import load, dump - -from ..io import DataLoader -from ..pipeline import Pipeline - -class Cache(ABC): - - def __init__(self) -> None: - logging.debug('Create cache manager') - - def _get_cache_key(self, dataset: str, loader: DataLoader, pipeline: Pipeline) -> str: - if dataset.endswith('dataset'): - dataset = dataset[:-7] - key = [loader, pipeline] - key = [hash(k) for k in key] - key = [str(k).encode() for k in key] - key = [md5(k).hexdigest()[:10] for k in key] - key = list(zip(['loader', 'pipeline'], key)) - key = ['_'.join(k) for k in key] - key = dataset + '_' + '_'.join(key) - return key - - @abstractmethod - def load(self, dataset: str, loader: DataLoader, pipeline: Pipeline) -> Dict: - pass - - -class PickleCache(Cache): - - def __init__(self, path: str): - super().__init__() - logging.debug('Create single pickle cache manager') - Path(path).mkdir(parents=True, exist_ok=True) - self.path = path - - def load(self, dataset: str, loader: DataLoader, pipeline: Pipeline): - logging.debug('Computing cache key') - key = self._get_cache_key(dataset, loader, pipeline) - logging.debug('Computed cache key: %s', key) - key = join(self.path, key + '.pkl') - if isfile(key): - logging.debug('Cache file found') - with open(key, 'rb') as file: - try: - logging.debug('Loading cache file') - data = load(file) - return data - except: - logging.debug('Loading cache file failed') - pass - logging.debug('Cache file not found, genereting new one') - data = loader.get_dataset() - data = pipeline.run(data) - with open(key, 'wb') as file: - logging.debug('Dumping cache file') - dump(data, file) - return data diff --git a/pyeeglab/database/__init__.py b/pyeeglab/database/__init__.py deleted file mode 100644 index fdcd6bb..0000000 --- a/pyeeglab/database/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .tables import * diff --git a/pyeeglab/database/tables.py b/pyeeglab/database/tables.py deleted file mode 100644 index a5fd22b..0000000 --- a/pyeeglab/database/tables.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging - -from dataclasses import dataclass - -from sqlalchemy import Column, Integer, Float, Text, ForeignKey -from sqlalchemy.ext.declarative import declarative_base - -from ..io.raw import Raw - - -BASE_TABLE = declarative_base() - - -@dataclass -class File(BASE_TABLE): - """File represents a single file contained in the dataset. - - This is an ORM class derived from the BASE_TABLE in a declarative base - used by SQLAlchemy. - - Attributes - ---------- - id : str - The primary key generated randomly using an UUID4 generator. - extension : str - A not null indexed string reporting the EEG recording format. - path : str - A not null string used to point to the relative path of the file respect - to the current sqlite database location. - """ - __tablename__ = 'file' - id: str = Column(Text, primary_key=True) - extension: str = Column(Text, nullable=False, index=True) - path: str = Column(Text, nullable=False) - - -@dataclass -class Metadata(BASE_TABLE): - """ Metadata represents a single metadata record associated with a single - file contained in the dataset. - - This is an ORM class derived from the BASE_TABLE in a declarative base - used by SQLAlchemy. - - Attributes - ---------- - file_id : str - The foreign key related to file_id, used also as a primary key for metadata - table since this is a one-to-one relationship. - duration : int - This is the EEG sample duration reported in seconds. This is not inteded as - a precise duration estimate, but only a reference for statistical analysis. - For more precise duration measurement, please, use the Raw record class methods. - channels_count : int - This field report the number of channels reported in the EEG header. - channels_reference : str - An indexed string describing the EEG channel reference system. - channels_set : str - The list of channels saved as a JSON string extracted from the EEG header. - sampling_frequency : int - The sampling fequency expressend in Hz extrated from the EEG header. - max_value : float - The max value sampled in this record across all channels. - min_value : float - The min value sampled in this record across all channels. - """ - __tablename__ = 'metadata' - file_id: str = Column(Text, ForeignKey('file.id'), primary_key=True) - duration: int = Column(Integer, nullable=False) - channels_count: int = Column(Integer, nullable=False) - channels_reference: str = Column(Text, nullable=True, index=True) - channels_set: str = Column(Text, nullable=False, index=True) - sampling_frequency: int = Column(Integer, nullable=False, index=True) - max_value: float = Column(Float, nullable=False) - min_value: float = Column(Float, nullable=False) - - -@dataclass -class Event(BASE_TABLE): - """Event represents a single event associated to a single file. - Multiple events can be associated to multiple files, according to - the record annotations. - - Attributes - ---------- - id : str - This is the primary key that is associated to each event. It is - created from a UUID4. - file_id : str - This the foreign key that link the event to the related file. - begin : float - The timestep expressed in seconds from which the event begins. - end : float - The timestep expressed in seconds at which the event ends. - duration : float - The entire duration of the event. - label : str - The label of this event as written inthe annotations. - """ - __tablename__ = 'event' - id: str = Column(Text, primary_key=True) - file_id: str = Column(Text, ForeignKey('file.id'), index=True) - begin: float = Column(Float, nullable=False) - end: float = Column(Float, nullable=False) - duration: float = Column(Float, nullable=False) - label: str = Column(Text, nullable=False, index=True) diff --git a/pyeeglab/dataset/__init__.py b/pyeeglab/dataset/__init__.py index 6b1dac2..3880a61 100644 --- a/pyeeglab/dataset/__init__.py +++ b/pyeeglab/dataset/__init__.py @@ -1,5 +1,2 @@ -from .tuh_eeg_abnormal import TUHEEGAbnormalLoader, TUHEEGAbnormalDataset -from .tuh_eeg_artifact import TUHEEGArtifactLoader, TUHEEGArtifactDataset -from .tuh_eeg_seizure import TUHEEGSeizureLoader, TUHEEGSeizureDataset -from .eegmmidb import EEGMMIDBLoader, EEGMMIDBDataset -from .chbmit import CHBMITLoader, CHBMITDataset +from .physionet import * +from .tuh_eeg import * diff --git a/pyeeglab/dataset/annotation.py b/pyeeglab/dataset/annotation.py new file mode 100644 index 0000000..db8aee8 --- /dev/null +++ b/pyeeglab/dataset/annotation.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from mne.io import Raw, read_raw +from sqlalchemy import Column, ForeignKey, Text, Float +from sqlalchemy.orm import relationship +from .declarative_base import Base + + +@dataclass +class Annotation(Base): + __tablename__ = "annotation" + uuid: str = Column(Text, primary_key=True) + file_uuid: str = Column(Text, ForeignKey("file.uuid"), nullable=False) + begin: float = Column(Float, nullable=False) + end: float = Column(Float, nullable=False) + label: str = Column(Text, nullable=False, index=True) + + file = relationship("File", lazy="subquery") + + @property + def duration(self) -> float: + return self.end - self.begin + + def __enter__(self) -> Raw: + self.reader = read_raw(self.file.path) + tmax = self.reader.n_times / self.reader.info["sfreq"] - 0.1 + tmax = tmax if self.end > tmax else self.end + self.reader.crop(self.begin, tmax) + return self.reader + + def __exit__(self, *args, **kwargs) -> None: + self.reader.close() + del self.reader + diff --git a/pyeeglab/dataset/chbmit/__init__.py b/pyeeglab/dataset/chbmit/__init__.py deleted file mode 100644 index a592069..0000000 --- a/pyeeglab/dataset/chbmit/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .chbmit_loader import CHBMITLoader -from .chbmit_dataset import CHBMITDataset diff --git a/pyeeglab/dataset/chbmit/chbmit_dataset.py b/pyeeglab/dataset/chbmit/chbmit_dataset.py deleted file mode 100644 index 4692ae1..0000000 --- a/pyeeglab/dataset/chbmit/chbmit_dataset.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Dict - -from .chbmit_loader import CHBMITLoader -from ..dataset import Dataset - - -class CHBMITDataset(Dataset): - - def __init__(self, path: str = './data/physionet.org/files/chbmit/1.0.0/') -> None: - super().__init__(CHBMITLoader(path)) - - def _get_dataset_env(self) -> Dict: - env = super()._get_dataset_env() - env['class_id'] = 'noseizure' - return env diff --git a/pyeeglab/dataset/chbmit/chbmit_index.py b/pyeeglab/dataset/chbmit/chbmit_index.py deleted file mode 100644 index 05a1df9..0000000 --- a/pyeeglab/dataset/chbmit/chbmit_index.py +++ /dev/null @@ -1,57 +0,0 @@ -import logging - -from uuid import uuid4, uuid5, NAMESPACE_X500 -from os.path import isfile, join, sep -from wfdb import rdann - -from ...io import Index -from ...database import File, Event - -from typing import List - - -class CHBMITIndex(Index): - def __init__(self, path: str) -> None: - logging.debug('Create CHB-MIT Scalp EEG Index') - super().__init__('sqlite:///' + join(path, 'index.db'), path) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:], - ) - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - path = join(self.path, file.path) - if isfile(path + '.seizures'): - events = rdann(path, 'seizures') - events = list(events.sample / events.fs) - events = [events[i:i+2] for i in range(0, len(events), 2)] - events = [ - Event( - id=str(uuid4()), - file_id=file.id, - begin=event[0], - end=event[1], - duration=(event[1] - event[0]), - label='seizure' - ) - for event in events - ] - else: - events = [ - Event( - id=str(uuid4()), - file_id=file.id, - begin=60, - end=120, - duration=60, - label='noseizure' - ) - ] - return events diff --git a/pyeeglab/dataset/chbmit/chbmit_loader.py b/pyeeglab/dataset/chbmit/chbmit_loader.py deleted file mode 100644 index 022c057..0000000 --- a/pyeeglab/dataset/chbmit/chbmit_loader.py +++ /dev/null @@ -1,20 +0,0 @@ -import logging - -from typing import List - -from ...io import DataLoader -from .chbmit_index import CHBMITIndex - - -class CHBMITLoader(DataLoader): - - def __init__(self, path: str) -> None: - exclude_files = [ - 'chb03/chb03_35.edf', # Corrupted data - 'chb12/chb12_27.edf', # Bad channel names - 'chb12/chb12_28.edf', # Bad channel names - 'chb12/chb12_29.edf', # Bad channel names - ] - super().__init__(path, exclude_files=exclude_files) - logging.debug('Create CHB-MIT Scalp EEG Loader') - self.index = CHBMITIndex(self.path) diff --git a/pyeeglab/dataset/dataset.py b/pyeeglab/dataset/dataset.py index 692b3a2..fc67639 100644 --- a/pyeeglab/dataset/dataset.py +++ b/pyeeglab/dataset/dataset.py @@ -1,45 +1,286 @@ +import os +import json import logging -from abc import ABC -from typing import Dict +import hashlib +import pickle + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import reduce +from multiprocessing import Pool, cpu_count +from operator import add, and_ +from uuid import uuid4, uuid5, NAMESPACE_X500 + +from typing import Dict, List, Tuple + +import mne +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker, Query + +from .declarative_base import Base +from .file import File +from .metadata import Metadata +from .annotation import Annotation -from ..io import DataLoader -from ..cache import Cache, PickleCache from ..pipeline import Pipeline + +@dataclass(init=False) class Dataset(ABC): + path: str + name: str + version: str - cache: Cache - loader: DataLoader - pipeline: Pipeline + extensions: List[str] + exclude_file: List[str] + exclude_channels_set: List[str] + exclude_channels_reference: List[str] + exclude_sampling_frequency: List[int] + minimum_annotation_duration: float - def __init__(self, loader: DataLoader) -> None: - logging.debug('Create dataset') - self.loader = loader - self.set_cache_manager(PickleCache('export')) - self.set_pipeline(Pipeline()) + session: Session + query: Query - def load(self) -> Dict: - dataset = self.__class__.__name__.lower() - return self.cache.load(dataset, self.loader, self.pipeline) + pipeline: Pipeline = None + + def __init__( + self, + path: str, + name: str, + version: str = None, + extensions: List[str] = [".edf"], + exclude_file: List[str] = None, + exclude_channels_set: List[str] = None, + exclude_channels_reference: List[str] = None, + exclude_sampling_frequency: List[str] = None, + minimum_annotation_duration: float = None + ) -> None: + # Set basic attributes + self.path = os.path.abspath(os.path.join(path, version)) + self.name = name + self.version = version - def _get_dataset_env(self) -> Dict: + # Set data set filter attributes + self.extensions = extensions if extensions else [] + self.exclude_file = exclude_file if exclude_file else [] + self.exclude_channels_set = exclude_channels_set if exclude_channels_set else [] + self.exclude_channels_reference = exclude_channels_reference if exclude_channels_reference else [] + self.exclude_sampling_frequency = exclude_sampling_frequency if exclude_sampling_frequency else [] + self.minimum_annotation_duration = minimum_annotation_duration if minimum_annotation_duration else 0 + + logging.info("Init dataset '%s'@'%s' at '%s'", self.name, self.version, self.path) + + # Make workspace directory + logging.debug("Make .pyeeglab directory") + workspace = os.path.join(self.path, ".pyeeglab") + os.makedirs(workspace, exist_ok=True) + logging.debug("Make .pyeeglab/cache directory") + os.makedirs(os.path.join(workspace, "cache"), exist_ok=True) + logging.debug("Set MNE log .pyeeglab/mne.log") + mne.set_log_file(os.path.join(workspace, "mne.log"), overwrite=False) + + # Index data set files + self.index() + + def __getstate__(self): + # Workaround for unpickable sqlalchemy.orm.session + # during multiprocess dataset loading + state = self.__dict__.copy() + for attribute in ["session", "query"]: + if hasattr(self, attribute): + del state[attribute] + return state + + @abstractmethod + def download(self, user: str = None, password: str = None) -> None: + pass + + def index(self) -> None: + # Init index session + logging.debug("Make index session") + connection = os.path.join(self.path, ".pyeeglab", "index.sqlite3") + connection = create_engine("sqlite:///" + connection) + Base.metadata.create_all(connection) + self.session = sessionmaker(bind=connection)() + # Open multiprocess pool + logging.info("Index data set directory") + pool = Pool(cpu_count()) + # Get files path from data set path + paths = [ + os.path.join(directory, filename) + for directory, _, filenames in os.walk(self.path) + for filename in filenames + ] + # Get Files instances form paths, filtering already indexed + files = self.session.query(File).all() + files = [file.uuid for file in files] + files = [ + file + for file in pool.map(self._get_file, paths) + if file.uuid not in files + ] + for file in files: + logging.debug("Add file %s to index", file.uuid) + # Filter raw data files by extension + raws = [ + file + for file in files + if os.path.splitext(file.path)[-1] in self.extensions + ] + # Get metadata and annotation for data files + metadatas = pool.map(self._get_metadata, raws) + annotations = pool.map(self._get_annotation, raws) + # Close multiprocess pool + pool.close() + pool.join() + # Commit insertions to index + commits = files + metadatas + reduce(add, annotations, []) + if commits: + logging.info("Commit insertions to index") + self.session.add_all(commits) + self.session.commit() + logging.info("Index data set completed") + # Init default query + logging.debug("Init default query") + self.query = self.session.query(File, Metadata, Annotation).\ + join(File.meta).\ + join(File.annotations).\ + filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)).\ + filter(~Metadata.sampling_frequency.in_(self.exclude_sampling_frequency)).\ + filter((Annotation.end - Annotation.begin) >= self.minimum_annotation_duration) + # Filter exclude file paths + for file in self.exclude_file: + self.query = self.query.filter(~File.path.like("%{}%".format(file))) + logging.debug("SQL query representation: '%s'", str(self.query).replace("\n", "")) + + def _get_file(self, path: str) -> File: + return File( + uuid=str(uuid5(NAMESPACE_X500, path)), + path=path, + extension=os.path.splitext(path)[-1] + ) + + def _get_metadata(self, file: File) -> Metadata: + logging.debug("Add file %s metadata to index", file.uuid) + with file as reader: + info = reader.info + metadata = Metadata( + file_uuid=file.uuid, + duration=reader.n_times/info["sfreq"], + channels_set=json.dumps(info["ch_names"]), + sampling_frequency=info["sfreq"], + max_value=reader.get_data().max(), + min_value=reader.get_data().min(), + ) + return metadata + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + with file as reader: + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=annotation[0], + end=annotation[0]+annotation[1], + label=annotation[2], + ) + for annotation in reader.annotations + ] + return annotations + + @property + def environment(self) -> Dict: + min_max = self.signal_min_max_range return { - 'channels_set': self.loader.get_channels_set(), - 'lowest_frequency': self.loader.get_lowest_frequency(), - 'max_value': self.loader.get_max_value(), - 'min_value': self.loader.get_min_value(), + "channels_set": self.maximal_channels_subset, + "lowest_frequency": self.lowest_frequency, + "min_value": min_max[0], + "max_value": min_max[1], } - def set_minimum_event_duration(self, minimum_event_duration: float) -> 'Dataset': - self.loader.set_minimum_event_duration(minimum_event_duration) - return self + @property + def lowest_frequency(self) -> float: + frequency = self.query.all() + frequency = min([ + f[1].sampling_frequency + for f in frequency + ], default=0) + return frequency - def set_cache_manager(self, cache: Cache) -> 'Dataset': - self.cache = cache - return self - - def set_pipeline(self, pipeline: Pipeline) -> 'Dataset': + @property + def maximal_channels_subset(self) -> List[str]: + channels = self.query.group_by(Metadata.channels_set).all() + channels = [ + frozenset(json.loads(channel[1].channels_set)) + for channel in channels + ] + channels = reduce(and_, channels) + channels = channels - frozenset(self.exclude_channels_set) + channels = sorted(channels) + return channels + + @property + def signal_min_max_range(self) -> Tuple[float]: + min_max = self.query.all() + min_max = [m[1] for m in min_max] + min_max = tuple([ + min([m.min_value for m in min_max], default=0), + max([m.max_value for m in min_max], default=0), + ]) + return min_max + + def set_pipeline(self, pipeline: Pipeline) -> "Dataset": self.pipeline = pipeline - environment = self._get_dataset_env() - self.pipeline.environment.update(environment) + self.pipeline.environment.update(self.environment) + return self + + def set_minimum_event_duration(self, duration: float) -> "Dataset": + logging.warning("This function will be deprecated in the near future") + self.minimum_annotation_duration = duration return self + + def load(self) -> Dict: + # Compute cache path + cache = os.path.join(self.path, ".pyeeglab", "cache") + # Compute cache key + logging.info("Compute cache key") + name = self.__class__.__name__.lower() + if name.endswith("dataset"): + name = name[:-len("dataset")] + key = [hash(self), hash(self.pipeline)] + key = [str(k).encode() for k in key] + key = [hashlib.md5(k).hexdigest()[:10] for k in key] + key = list(zip(["loader", "pipeline"], key)) + key = ["_".join(k) for k in key] + key = name + "_" + "_".join(key) + logging.info("Computed cache key: %s", key) + # Load file cache + cache = os.path.join(cache, key + ".pkl") + if os.path.exists(cache): + logging.info("Cache file found at %s", cache) + with open(cache, "rb") as reader: + try: + logging.info("Loading cache file") + return pickle.load(reader) + except: + logging.error("Loading cache file failed") + # Cache file not found, start preprocessing + logging.info("Cache file not found, genereting new one") + data = [row[2] for row in self.query.all()] + data = self.pipeline.run(data) + with open(cache, "wb") as file: + logging.info("Dumping cache file") + pickle.dump(data, file) + return data + + def __hash__(self) -> int: + key = [self.path, self.version, self.minimum_annotation_duration] + key += self.exclude_file + key += self.exclude_channels_set + key += self.exclude_channels_reference + key += self.exclude_sampling_frequency + key = json.dumps(key).encode() + key = hashlib.md5(key).hexdigest() + key = int(key, 16) + return key diff --git a/pyeeglab/dataset/declarative_base.py b/pyeeglab/dataset/declarative_base.py new file mode 100644 index 0000000..c64447d --- /dev/null +++ b/pyeeglab/dataset/declarative_base.py @@ -0,0 +1,2 @@ +from sqlalchemy.ext.declarative import declarative_base +Base = declarative_base() diff --git a/pyeeglab/dataset/eegmmidb/__init__.py b/pyeeglab/dataset/eegmmidb/__init__.py deleted file mode 100644 index 4cbd826..0000000 --- a/pyeeglab/dataset/eegmmidb/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .eegmmidb_loader import EEGMMIDBLoader -from .eegmmidb_dataset import EEGMMIDBDataset diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py b/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py deleted file mode 100644 index a5aa15d..0000000 --- a/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py +++ /dev/null @@ -1,8 +0,0 @@ -from .eegmmidb_loader import EEGMMIDBLoader -from ..dataset import Dataset - - -class EEGMMIDBDataset(Dataset): - - def __init__(self, path: str = './data/physionet.org/files/eegmmidb/1.0.0/') -> None: - super().__init__(EEGMMIDBLoader(path)) diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_index.py b/pyeeglab/dataset/eegmmidb/eegmmidb_index.py deleted file mode 100644 index 83af529..0000000 --- a/pyeeglab/dataset/eegmmidb/eegmmidb_index.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging - -from uuid import uuid5, NAMESPACE_X500 -from os.path import join, sep - -from ...io import Index -from ...database import File - - -class EEGMMIDBIndex(Index): - def __init__(self, path: str) -> None: - logging.debug('Create EEG Motor Movement/Imagery Index') - super().__init__('sqlite:///' + join(path, 'index.db'), path) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:] - ) diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py b/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py deleted file mode 100644 index b9bde7b..0000000 --- a/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py +++ /dev/null @@ -1,18 +0,0 @@ -import logging - -from typing import List - -from ...io import DataLoader -from .eegmmidb_index import EEGMMIDBIndex - - -class EEGMMIDBLoader(DataLoader): - - def __init__(self, path: str, exclude_frequency: List[int] = [128]) -> None: - exclude_files = [ - 'S021/S021R08.edf', # Corrupted data - 'S104/S104R04.edf', # Corrupted data - ] - super().__init__(path, exclude_files = exclude_files, exclude_frequency = exclude_frequency) - logging.debug('Create EEG Motor Movement/Imagery Loader') - self.index = EEGMMIDBIndex(self.path) diff --git a/pyeeglab/dataset/file.py b/pyeeglab/dataset/file.py new file mode 100644 index 0000000..e4cbbb1 --- /dev/null +++ b/pyeeglab/dataset/file.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from mne.io import Raw, read_raw +from sqlalchemy import Column, Text +from sqlalchemy.orm import relationship +from .declarative_base import Base + + +@dataclass +class File(Base): + __tablename__ = "file" + uuid: str = Column(Text, primary_key=True) + path: str = Column(Text, nullable=False) + extension: str = Column(Text, nullable=False) + + meta = relationship("Metadata", cascade="all,delete", backref="File") + annotations = relationship("Annotation", cascade="all,delete", backref="File") + + def __enter__(self) -> Raw: + self.reader = read_raw(self.path) + return self.reader + + def __exit__(self, *args, **kwargs) -> None: + self.reader.close() + del self.reader diff --git a/pyeeglab/dataset/metadata.py b/pyeeglab/dataset/metadata.py new file mode 100644 index 0000000..fb770e4 --- /dev/null +++ b/pyeeglab/dataset/metadata.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from sqlalchemy import Column, ForeignKey, Text, Float, Integer +from .declarative_base import Base + + +@dataclass +class Metadata(Base): + __tablename__ = "metadata" + file_uuid: str = Column(Text, ForeignKey("file.uuid"), primary_key=True) + duration: int = Column(Float, nullable=False) + channels_set: str = Column(Text, nullable=False, index=True) + channels_reference: str = Column(Text, nullable=True, index=True) + sampling_frequency: int = Column(Integer, nullable=False, index=True) + max_value: float = Column(Float, nullable=False) + min_value: float = Column(Float, nullable=False) diff --git a/pyeeglab/dataset/physionet/__init__.py b/pyeeglab/dataset/physionet/__init__.py new file mode 100644 index 0000000..3e3a51d --- /dev/null +++ b/pyeeglab/dataset/physionet/__init__.py @@ -0,0 +1,2 @@ +from .chbmit_dataset import * +from .eegmmidb_dataset import * diff --git a/pyeeglab/dataset/physionet/chbmit_dataset.py b/pyeeglab/dataset/physionet/chbmit_dataset.py new file mode 100644 index 0000000..8fe88ea --- /dev/null +++ b/pyeeglab/dataset/physionet/chbmit_dataset.py @@ -0,0 +1,72 @@ +import os +import logging + +from uuid import uuid4 + +from typing import List + +from .utils import wget + +from ..dataset import Dataset +from ..file import File +from ..annotation import Annotation + +import wfdb + + +class PhysioNetCHBMITDataset(Dataset): + + def __init__( + self, + path: str = "./data/physionet.org/files/chbmit/", + version: str = "1.0.0", + exclude_file: List[str] = [ + "chb03/chb03_35.edf", # Corrupted data + "chb12/chb12_27.edf", # Bad channel names + "chb12/chb12_28.edf", # Bad channel names + "chb12/chb12_29.edf", # Bad channel names + ], + exclude_channels_set: List[str] = [ + "ECG", + "LOC-ROC", + "LUE-RAE", + "VNS", + ], + ) -> None: + super().__init__( + path=path, + name="CHB-MIT Scalp EEG Database", + version=version, + exclude_file=exclude_file, + exclude_channels_set=exclude_channels_set, + ) + + def download(self, user: str = None, password: str = None) -> None: + wget(self.path, user, password, "chbmit", self.version) + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=60, + end=120, + label="noseizure" + ) + ] + if os.path.isfile(file.path + ".seizures"): + annotations = wfdb.rdann(file.path, "seizures") + annotations = list(annotations.sample / annotations.fs) + annotations = [annotations[i:i+2] for i in range(0, len(annotations), 2)] + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=annotation[0], + end=annotation[1], + label="seizure" + ) + for annotation in annotations + ] + return annotations diff --git a/pyeeglab/dataset/physionet/eegmmidb_dataset.py b/pyeeglab/dataset/physionet/eegmmidb_dataset.py new file mode 100644 index 0000000..f6da608 --- /dev/null +++ b/pyeeglab/dataset/physionet/eegmmidb_dataset.py @@ -0,0 +1,64 @@ +import os +import logging + +from uuid import uuid4 + +from typing import List + +from .utils import wget + +from ..dataset import Dataset +from ..file import File +from ..annotation import Annotation + + +class PhysioNetEEGMMIDBDataset(Dataset): + + def __init__( + self, + path: str = "./data/physionet.org/files/eegmmidb/", + version: str = "1.0.0", + exclude_file: List[str] = [ + "S021/S021R08.edf", # Corrupted data + "S104/S104R04.edf", # Corrupted data + ], + exclude_sampling_frequency: List[int] = [ 128 ], + ) -> None: + super().__init__( + path=path, + name="EEG Motor Movement/Imagery Dataset", + version=version, + exclude_file=exclude_file, + exclude_sampling_frequency=exclude_sampling_frequency, + ) + + def download(self, user: str = None, password: str = None) -> None: + wget(self.path, user, password, "eegmmidb", self.version) + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + with file as reader: + try: + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=annotation[0], + end=annotation[0]+annotation[1], + label=annotation[2], + ) + for annotation in reader.annotations + ] + except KeyError: + # Alternative annotation format + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=annotation["onset"], + end=annotation["onset"]+annotation["duration"], + label=annotation["description"], + ) + for annotation in reader.annotations + ] + return annotations diff --git a/pyeeglab/dataset/physionet/utils.py b/pyeeglab/dataset/physionet/utils.py new file mode 100644 index 0000000..c379903 --- /dev/null +++ b/pyeeglab/dataset/physionet/utils.py @@ -0,0 +1,30 @@ +import logging +import subprocess + + +def wget(path: str, user: str, password: str, slug: str, version: str) -> None: + logging.info("Download started, it will take some time") + url = "https://physionet.org/files/" + slug + "/" + version + "/" + process = subprocess.Popen( + [ + "wget", + "-r", + "-N", + "-c", + "-np", + "-nH", + "--cut-dirs=3", + url, + "-P", + path + ], + stdout=subprocess.PIPE, + universal_newlines=True, + ) + while True: + output = process.stdout.readline() + logging.info(output.strip()) + if process.poll() is not None: + for output in process.stdout.readlines(): + logging.info(output.strip()) + break diff --git a/pyeeglab/dataset/tuh_eeg/__init__.py b/pyeeglab/dataset/tuh_eeg/__init__.py new file mode 100644 index 0000000..d3db617 --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/__init__.py @@ -0,0 +1,3 @@ +from .abnormal_dataset import * +from .artifact_dataset import * +from .seizure_dataset import * diff --git a/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py b/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py new file mode 100644 index 0000000..62e885d --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py @@ -0,0 +1,59 @@ +import os +import logging + +from uuid import uuid4 + +from typing import List + +from .utils import rsync + +from ..dataset import Dataset +from ..file import File +from ..metadata import Metadata +from ..annotation import Annotation + + +class TUHEEGAbnormalDataset(Dataset): + + def __init__( + self, + path: str = "./data/tuh_eeg_abnormal/", + version: str = "2.0.0", + exclude_channels_set: List[str] = [ + "BURSTS", + "ECG EKG-REF", + "EMG-REF", + "IBI", + "PHOTIC-REF", + "PULSE RATE", + "STI 014", + "SUPPR" + ], + ) -> None: + super().__init__( + path=path, + name="Temple University Hospital EEG Abnormal Dataset", + version="v"+version, + exclude_channels_set=exclude_channels_set, + ) + + def download(self, user: str = None, password: str = None) -> None: + rsync(self.path, user, password, "tuh_eeg_abnormal", self.version) + + def _get_metadata(self, file: File) -> Metadata: + meta = file.path.split(os.path.sep) + metadata = super()._get_metadata(file) + metadata.channels_reference = meta[-5] + return metadata + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + return [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=60, + end=120, + label=file.path.split(os.path.sep)[-6], + ) + ] diff --git a/pyeeglab/dataset/tuh_eeg/artifact_dataset.py b/pyeeglab/dataset/tuh_eeg/artifact_dataset.py new file mode 100644 index 0000000..23ee37e --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/artifact_dataset.py @@ -0,0 +1,51 @@ +import os +import logging + +from typing import List + +from .utils import rsync, parse_tse + +from ..dataset import Dataset +from ..file import File +from ..metadata import Metadata +from ..annotation import Annotation + + +class TUHEEGArtifactDataset(Dataset): + + def __init__( + self, + path: str = "./data/tuh_eeg_artifact/", + version: str = "1.0.0", + exclude_channels_set: List[str] = [ + "BURSTS", + "EMG-REF", + "IBI", + "PHOTIC-REF", + "SUPPR" + ], + exclude_channels_reference: List[str] = [ + "02_tcp_le", + "03_tcp_ar_a" + ], + ) -> None: + super().__init__( + path=path, + name="Temple University Hospital EEG Artifact Dataset", + version="v"+version, + exclude_channels_set=exclude_channels_set, + exclude_channels_reference=exclude_channels_reference, + ) + + def download(self, user: str = None, password: str = None) -> None: + rsync(self.path, user, password, "tuh_eeg_artifact", self.version) + + def _get_metadata(self, file: File) -> Metadata: + meta = file.path.split(os.path.sep) + metadata = super()._get_metadata(file) + metadata.channels_reference = meta[-5] + return metadata + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + return parse_tse(file) diff --git a/pyeeglab/dataset/tuh_eeg/seizure_dataset.py b/pyeeglab/dataset/tuh_eeg/seizure_dataset.py new file mode 100644 index 0000000..593ef3d --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/seizure_dataset.py @@ -0,0 +1,57 @@ +import os +import re +import logging + +from uuid import uuid4 + +from typing import List + +from .utils import rsync, parse_lbl + +from ..dataset import Dataset +from ..file import File +from ..metadata import Metadata +from ..annotation import Annotation + + +class TUHEEGSeizureDataset(Dataset): + + def __init__( + self, + path: str = "./data/tuh_eeg_seizure/", + version: str = "1.5.2", + exclude_channels_set: List[str] = [ + "BURSTS", + "ECG EKG-REF", + "EMG-REF", + "IBI", + "PHOTIC-REF", + "PULSE RATE", + "RESP ABDOMEN-RE", + "SUPPR" + ], + exclude_channels_reference: List[str] = [ + "02_tcp_le", + "03_tcp_ar_a" + ], + ) -> None: + super().__init__( + path=path, + name="Temple University Hospital EEG Seizure Dataset", + version="v"+version, + exclude_channels_set=exclude_channels_set, + exclude_channels_reference=exclude_channels_reference, + ) + + def download(self, user: str = None, password: str = None) -> None: + rsync(self.path, user, password, "tuh_eeg_seizure", self.version) + + def _get_metadata(self, file: File) -> Metadata: + meta = file.path.split(os.path.sep) + metadata = super()._get_metadata(file) + metadata.channels_reference = meta[-5] + return metadata + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + return parse_lbl(file) diff --git a/pyeeglab/dataset/tuh_eeg/utils.py b/pyeeglab/dataset/tuh_eeg/utils.py new file mode 100644 index 0000000..d4a1d4c --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/utils.py @@ -0,0 +1,96 @@ +import re +import logging +import subprocess + +from uuid import uuid4 + +from typing import List + +from ..file import File +from ..annotation import Annotation + + +def rsync(path: str, user: str, password: str, slug: str, version: str) -> None: + if user is not None and password is not None: + logging.info("Download started, it will take some time") + url = user + "@" + "www.isip.piconepress.com:~/data/" + url = url + slug + "/v" + version + "/" + process = subprocess.Popen( + [ + "sshpass", + "-p", + password, + "rsync", + "-auxvL", + url, + path + ], + stdout=subprocess.PIPE, + universal_newlines=True, + ) + while True: + output = process.stdout.readline() + logging.info(output.strip()) + if process.poll() is not None: + for output in process.stdout.readlines(): + logging.info(output.strip()) + break + else: + logging.warn("Download disabled, add 'user' and 'password' as optional parameters during data set creation") + logging.warn("Request your login credentials at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php") + + +def parse_lbl(file: File) -> List[Annotation]: + path = file.path[:-4] + ".lbl" + with open(path, "r") as reader: + annotations = reader.read() + symbols = re.compile( + r"^symbols\[0\] = ({.*})$", + re.MULTILINE + ) + symbols = re.findall(symbols, annotations) + symbols = eval(symbols[0]) + pattern = re.compile( + r"^label = {([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^}]*)}$", + re.MULTILINE + ) + annotations = re.findall(pattern, annotations) + annotations = { + (annotation[2], annotation[3], symbols[index]) + for annotation in annotations + for index, value in enumerate(eval(annotation[5])) + if value > 0 + } + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=float(annotation[0]), + end=float(annotation[1]), + label=annotation[2] + ) + for annotation in annotations + ] + return annotations + + +def parse_tse(file: File) -> List[Annotation]: + path = file.path[:-4] + ".tse" + with open(path, "r") as reader: + annotations = reader.read() + pattern = re.compile( + r"^(\d+.\d+) (\d+.\d+) (\w+) (\d.\d+)$", + re.MULTILINE + ) + annotations = re.findall(pattern, annotations) + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=float(annotation[0]), + end=float(annotation[1]), + label=annotation[2] + ) + for annotation in annotations + ] + return annotations diff --git a/pyeeglab/dataset/tuh_eeg_abnormal/__init__.py b/pyeeglab/dataset/tuh_eeg_abnormal/__init__.py deleted file mode 100644 index 926acfd..0000000 --- a/pyeeglab/dataset/tuh_eeg_abnormal/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .tuh_eeg_abnormal_loader import TUHEEGAbnormalLoader -from .tuh_eeg_abnormal_dataset import TUHEEGAbnormalDataset diff --git a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_dataset.py b/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_dataset.py deleted file mode 100644 index dcad1d8..0000000 --- a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_dataset.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Dict - -from .tuh_eeg_abnormal_loader import TUHEEGAbnormalLoader -from ..dataset import Dataset - - -class TUHEEGAbnormalDataset(Dataset): - - def __init__(self, path: str = './data/tuh_eeg_abnormal/v2.0.0/edf/') -> None: - super().__init__(TUHEEGAbnormalLoader(path)) - - def _get_dataset_env(self) -> Dict: - env = super()._get_dataset_env() - blacklist = ['IBI', 'BURSTS', 'STI 014', 'SUPPR'] - env['channels_set'] = list(set(env['channels_set']) - set(blacklist)) - env['class_id'] = 'normal' - return env diff --git a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_index.py b/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_index.py deleted file mode 100644 index 58ee4ba..0000000 --- a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_index.py +++ /dev/null @@ -1,46 +0,0 @@ -import logging -import json - -from typing import List - -from uuid import uuid4, uuid5, NAMESPACE_X500 -from os.path import join, sep - -from ...io import Index, Raw -from ...database import File, Metadata, Event - - -class TUHEEGAbnormalIndex(Index): - def __init__(self, path: str) -> None: - logging.debug('Create TUH EEG Corpus Index') - super().__init__('sqlite:///' + join(path, 'index.db'), path) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:] - ) - - def _get_record_metadata(self, file: File) -> Metadata: - meta = file.path.split(sep) - metadata = super()._get_record_metadata(file) - metadata.channels_reference = meta[2] - return metadata - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - return [ - Event( - id=str(uuid4()), - file_id=raw.id, - begin=60, - end=120, - duration=60, - label=raw.path.split(sep)[-6] - ) - ] diff --git a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_loader.py b/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_loader.py deleted file mode 100644 index 2adabcc..0000000 --- a/pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_loader.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging - -from ...io import DataLoader -from .tuh_eeg_abnormal_index import TUHEEGAbnormalIndex - -from typing import List - - -class TUHEEGAbnormalLoader(DataLoader): - - def __init__(self, path: str) -> None: - super().__init__(path) - logging.debug('Create TUH EEG Corpus Loader') - self.index = TUHEEGAbnormalIndex(self.path) diff --git a/pyeeglab/dataset/tuh_eeg_artifact/__init__.py b/pyeeglab/dataset/tuh_eeg_artifact/__init__.py deleted file mode 100644 index e61f493..0000000 --- a/pyeeglab/dataset/tuh_eeg_artifact/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .tuh_eeg_artifact_loader import TUHEEGArtifactLoader -from .tuh_eeg_artifact_dataset import TUHEEGArtifactDataset diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py deleted file mode 100644 index b32af2f..0000000 --- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Dict - -from .tuh_eeg_artifact_loader import TUHEEGArtifactLoader -from ..dataset import Dataset - - -class TUHEEGArtifactDataset(Dataset): - - def __init__(self, path: str = './data/tuh_eeg_artifact/v1.0.0/edf/') -> None: - super().__init__(TUHEEGArtifactLoader(path)) - - def _get_dataset_env(self) -> Dict: - env = super()._get_dataset_env() - env['class_id'] = 'null' - return env diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py deleted file mode 100644 index b9c74d3..0000000 --- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -import re -import json - -from typing import List - -from uuid import uuid4, uuid5, NAMESPACE_X500 -from os.path import join, sep - -from ...io import Index, Raw -from ...database import File, Metadata, Event - - -class TUHEEGArtifactIndex(Index): - - def __init__(self, path: str) -> None: - logging.debug('Create TUH EEG Corpus Index') - super().__init__('sqlite:///' + join(path, 'index.db'), path) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:] - ) - - def _get_record_metadata(self, file: File) -> Metadata: - meta = file.path.split(sep) - metadata = super()._get_record_metadata(file) - metadata.channels_reference = meta[0] - return metadata - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - path = raw.path[:-4] + '.tse' - with open(path, 'r') as file: - annotations = file.read() - pattern = re.compile( - r'^(\d+.\d+) (\d+.\d+) (\w+) (\d.\d+)$', re.MULTILINE) - events = re.findall(pattern, annotations) - events = [ - Event( - id=str(uuid4()), - file_id=raw.id, - begin=float(event[0]), - end=float(event[1]), - duration=(float(event[1])-float(event[0])), - label=event[2] - ) - for event in events - ] - return events diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py deleted file mode 100644 index e09e7ac..0000000 --- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging - -from typing import List - -from ...io import DataLoader -from .tuh_eeg_artifact_index import TUHEEGArtifactIndex - - -class TUHEEGArtifactLoader(DataLoader): - - def __init__(self, path: str, exclude_channels_reference: List[str] = ['02_tcp_le', '02_tcp_ar_a']) -> None: - super().__init__(path, exclude_channels_reference=exclude_channels_reference) - logging.debug('Create TUH EEG Corpus Loader') - self.index = TUHEEGArtifactIndex(self.path) diff --git a/pyeeglab/dataset/tuh_eeg_seizure/__init__.py b/pyeeglab/dataset/tuh_eeg_seizure/__init__.py deleted file mode 100644 index abf1a6c..0000000 --- a/pyeeglab/dataset/tuh_eeg_seizure/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .tuh_eeg_seizure_loader import TUHEEGSeizureLoader -from .tuh_eeg_seizure_dataset import TUHEEGSeizureDataset diff --git a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py deleted file mode 100644 index 8f391cd..0000000 --- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Dict - -from .tuh_eeg_seizure_loader import TUHEEGSeizureLoader -from ..dataset import Dataset - - -class TUHEEGSeizureDataset(Dataset): - - def __init__(self, path: str = './data/tuh_eeg_seizure/v1.5.2/edf/') -> None: - super().__init__(TUHEEGSeizureLoader(path)) - - def _get_dataset_env(self) -> Dict: - env = super()._get_dataset_env() - env['class_id'] = 'fnsz' - return env diff --git a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py deleted file mode 100644 index 1d243d2..0000000 --- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -import re -import json -import numpy as np - -from typing import List - -from uuid import uuid4, uuid5, NAMESPACE_X500 -from os.path import join, sep - -from ...io import Index, Raw -from ...database import File, Metadata, Event - - -class TUHEEGSeizureIndex(Index): - - def __init__(self, path: str, exclude_events: List = ['bckg', 'spsz']) -> None: - logging.debug('Create TUH EEG Corpus Index') - super().__init__('sqlite:///' + join(path, 'index.db'), path, exclude_events=exclude_events) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:] - ) - - def _get_record_metadata(self, file: File) -> Metadata: - meta = file.path.split(sep) - metadata = super()._get_record_metadata(file) - metadata.channels_reference = meta[1] - return metadata - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - path = raw.path[:-4] + '.lbl' - with open(path, 'r') as file: - annotations = file.read() - pattern = re.compile( - r'^symbols\[0\] = ({.*})$', - re.MULTILINE - ) - mapper = re.findall(pattern, annotations) - mapper = eval(mapper[0]) - events = re.findall(pattern, annotations) - pattern = re.compile( - r'^label = {([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^}]*)}$', - re.MULTILINE - ) - events = re.findall(pattern, annotations) - labels = {} - for event in events: - intervall = (event[2], event[3]) - if intervall not in labels: - labels[intervall] = {} - channel = event[4] - if channel not in labels[intervall]: - labels[intervall][channel] = [] - label = np.array(json.loads(event[5])) - for i in list(np.nonzero(label)[0]): - labels[intervall][channel].append(mapper[i]) - labels = { - (intervall[0], intervall[1], label) - for intervall, channels in labels.items() - for channel, labels in channels.items() - for label in labels - } - events = [ - Event( - id=str(uuid4()), - file_id=raw.id, - begin=float(label[0]), - end=float(label[1]), - duration=(float(label[1])-float(label[0])), - label=label[2] - ) - for label in labels - ] - return events diff --git a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py deleted file mode 100644 index 5f2627a..0000000 --- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging - -from typing import List - -from ...io import DataLoader -from .tuh_eeg_seizure_index import TUHEEGSeizureIndex - - -class TUHEEGSeizureLoader(DataLoader): - - def __init__(self, path: str, exclude_channels_reference: List[str] = ['02_tcp_le', '02_tcp_ar_a']) -> None: - super().__init__(path, exclude_channels_reference=exclude_channels_reference) - logging.debug('Create TUH EEG Corpus Loader') - self.index = TUHEEGSeizureIndex(self.path) diff --git a/pyeeglab/io/__init__.py b/pyeeglab/io/__init__.py deleted file mode 100644 index 12695ed..0000000 --- a/pyeeglab/io/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .index import Index -from .loader import DataLoader -from .raw import Raw diff --git a/pyeeglab/io/index.py b/pyeeglab/io/index.py deleted file mode 100644 index 9b13638..0000000 --- a/pyeeglab/io/index.py +++ /dev/null @@ -1,139 +0,0 @@ -import logging -import json - -from abc import ABC, abstractmethod - -from os import walk -from os.path import join, splitext -from uuid import uuid4 -from multiprocessing import Pool, cpu_count -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from mne import set_log_file - -from ..database import BASE_TABLE, File, Metadata, Event -from .raw import Raw - -from typing import List - - -class Index(ABC): - """ An abstract class representing the Index mechanism that is used to - discover the dataset structure. - - It is the first component that must be implemented in order to provide - full support to a specific dataset. If the dataset structure is regular, - the only method that must be implemented id "_get_file(path)". - """ - - def __init__(self, db: str, path: str, include_extensions: List[str] = ['edf'], exclude_events: List[str] = None) -> None: - """ - Parameters - ---------- - db : str - The database connection handle expressed as string. This is usually - configured by the Loader class as a sqlite handle, but in theory it - could be used with any type of database connection. - - """ - logging.debug('Create index at %s', db) - logging.debug('Load index at %s', db) - engine = create_engine(db) - BASE_TABLE.metadata.create_all(engine) - self.db = sessionmaker(bind=engine)() - self.path = path - self.include_extensions = include_extensions - self.exclude_events = exclude_events - logging.debug('Redirect MNE logging interface to file') - set_log_file(join(path, 'mne.log'), overwrite=False) - - def __getstate__(self): - # Workaround for unpickable sqlalchemy.orm.session - # during multiprocess dataset loading - state = self.__dict__.copy() - del state['db'] - return state - - def _get_paths(self, exclude_extensions: List[str] = ['.db', '.gz', '.log']) -> List[str]: - logging.debug('Get files from path') - paths = [ - join(dirpath, filename) - for dirpath, _, filenames in walk(self.path) - for filename in filenames - if splitext(filename)[1] not in exclude_extensions - ] - return paths - - @abstractmethod - def _get_file(self, path: str) -> File: - pass - - def _get_files(self, paths: List[str]) -> List[File]: - pool = Pool(cpu_count()) - files = pool.map(self._get_file, paths) - pool.close() - pool.join() - return files - - def _get_record_metadata(self, file: File) -> Metadata: - logging.debug('Add file %s raw metadata to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - return Metadata( - file_id=raw.id, - duration=raw.open().n_times/raw.open().info['sfreq'], - channels_count=raw.open().info['nchan'], - channels_set=json.dumps(raw.open().info['ch_names']), - sampling_frequency=raw.open().info['sfreq'], - max_value=raw.open().get_data().max(), - min_value=raw.open().get_data().min(), - ) - - def _parallel_record_metadata(self, files: List[File]) -> List[Metadata]: - pool = Pool(cpu_count()) - metadata = pool.map(self._get_record_metadata, files) - pool.close() - pool.join() - return metadata - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - events = raw.get_events() - for event in events: - event['id'] = str(uuid4()) - event['file_id'] = raw.id - events = [Event(**event) for event in events] - return events - - def _parallel_record_events(self, files: List[File]) -> List[Event]: - pool = Pool(cpu_count()) - events = pool.map(self._get_record_events, files) - pool.close() - pool.join() - events = [e for event in events for e in event] - if self.exclude_events: - events = [ - e for e in events - if e.label not in self.exclude_events - ] - return events - - def index(self) -> None: - logging.debug('Index files') - files = self._get_paths() - files = self._get_files(files) - files = [ - file - for file in files - if not self.db.query(File).filter(File.id == file.id).all() - ] - for file in files: - logging.debug('Add file %s raw to index', file.id) - if self.include_extensions: - raws = [ - file for file in files if file.extension in self.include_extensions] - metadata = self._parallel_record_metadata(raws) - events = self._parallel_record_events(raws) - self.db.add_all(files + metadata + events) - self.db.commit() - logging.debug('Index files completed') diff --git a/pyeeglab/io/loader.py b/pyeeglab/io/loader.py deleted file mode 100644 index 73e732a..0000000 --- a/pyeeglab/io/loader.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging -import json - -from abc import ABC -from os.path import isfile, join, sep -from hashlib import md5 -from multiprocessing import Pool, cpu_count - -from ..database import File, Metadata, Event -from .raw import Raw -from .index import Index - -from typing import List, Dict - - -class DataLoader(ABC): - - index: Index - - def __init__(self, path: str, exclude_channels_reference: List[str] = None, exclude_frequency: List[int] = None, - exclude_files: List[str] = None, minimum_event_duration: float = -1) -> None: - logging.debug('Create data loader') - if path[-1] != sep: - path = path + sep - self.path = path - self.exclude_channels_reference = exclude_channels_reference - self.exclude_frequency = exclude_frequency - self.exclude_files = exclude_files - self.minimum_event_duration = minimum_event_duration - - def __getstate__(self): - # Workaround for unpickable sqlalchemy.orm.session - # during multiprocess dataset loading - state = self.__dict__.copy() - del state['index'] - return state - - def _get_data_by_event(self, f: File, e: Event) -> Raw: - path_edf = join(self.path, f.path) - path_fif = path_edf + '_' + e.id + '_raw.fif.gz' - if not isfile(path_fif): - edf = Raw(f.id, path_edf, e.label) - edf.crop(e.begin, e.end-e.begin) - edf.open().save(path_fif) - fif = Raw(f.id, path_fif, e.label) - return fif - - def set_minimum_event_duration(self, minimum_event_duration: float) -> 'DataLoader': - self.minimum_event_duration = minimum_event_duration - - def get_dataset(self) -> List[Raw]: - files = self.index.db.query(File, Metadata, Event) - files = files.filter(File.id == Metadata.file_id) - files = files.filter(File.id == Event.file_id) - if self.index.include_extensions: - files = files.filter(File.extension.in_(self.index.include_extensions)) - if self.exclude_channels_reference: - files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)) - if self.exclude_frequency: - files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - if self.exclude_files: - files = files.filter(~File.path.in_(self.exclude_files)) - if self.minimum_event_duration > 0: - files = files.filter(Event.duration >= self.minimum_event_duration) - files = files.all() - files = [(file[0], file[2]) for file in files] - pool = Pool(cpu_count()) - fifs = pool.starmap(self._get_data_by_event, files) - pool.close() - pool.join() - return fifs - - def get_dataset_text(self) -> Dict: - txts = self.index.db.query(File, Event) - txts = txts.filter(File.id == Event.file_id) - txts = txts.filter(File.extension == 'txt') - txts = txts.all() - txts = {f.id: (join(self.index.path, f.path), e.label) for f, e in txts} - return txts - - def get_channels_set(self) -> List[str]: - files = self.index.db.query(File, Metadata) - files = files.filter(File.id == Metadata.file_id) - if self.exclude_channels_reference: - files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)) - if self.exclude_frequency: - files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - if self.exclude_files: - files = files.filter(~File.path.in_(self.exclude_files)) - if self.minimum_event_duration > 0: - files = files.filter(Event.duration >= self.minimum_event_duration) - files = files.group_by(Metadata.channels_set) - files = files.all() - files = [file[1] for file in files] - files = [set(json.loads(file.channels_set)) for file in files] - channels_set = files[0] - for file in files[1:]: - channels_set = channels_set.intersection(file) - return sorted(channels_set) - - def get_lowest_frequency(self) -> float: - frequency = self.index.db.query(Metadata) - if self.exclude_frequency: - frequency = frequency.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - frequency = frequency.all() - frequency = min([f.sampling_frequency for f in frequency], default=0) - return frequency - - def get_max_value(self) -> float: - files = self.index.db.query(File, Metadata) - files = files.filter(File.id == Metadata.file_id) - if self.exclude_channels_reference: - files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)) - if self.exclude_frequency: - files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - if self.exclude_files: - files = files.filter(~File.path.in_(self.exclude_files)) - if self.minimum_event_duration > 0: - files = files.filter(Event.duration >= self.minimum_event_duration) - files = files.group_by(Metadata.channels_set) - files = files.all() - max_value = max([f.max_value for _, f in files], default=0) - return max_value - - def get_min_value(self) -> float: - files = self.index.db.query(File, Metadata) - files = files.filter(File.id == Metadata.file_id) - if self.exclude_channels_reference: - files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)) - if self.exclude_frequency: - files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - if self.exclude_files: - files = files.filter(~File.path.in_(self.exclude_files)) - if self.minimum_event_duration > 0: - files = files.filter(Event.duration >= self.minimum_event_duration) - files = files.group_by(Metadata.channels_set) - files = files.all() - min_value = min([f.min_value for _, f in files], default=0) - return min_value - - def __eq__(self, other): - return hash(self) == hash(other) - - def __hash__(self): - value = [self.path] + [self.minimum_event_duration] - if self.exclude_channels_reference: - value += self.exclude_channels_reference - if self.exclude_frequency: - value += self.exclude_frequency - value = json.dumps(value).encode() - value = md5(value).hexdigest() - value = int(value, 16) - return value diff --git a/pyeeglab/io/raw.py b/pyeeglab/io/raw.py deleted file mode 100644 index 75a41ff..0000000 --- a/pyeeglab/io/raw.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -from typing import List -from importlib.util import find_spec - -from mne.io import Raw as Reader -from mne.io import read_raw_edf, read_raw_fif - -class Raw(): - - reader: Reader = None - - def __init__(self, fid: str, path: str, label: str = None) -> None: - self.id = fid - self.path = path - self.label = label - n_jobs = 1 - if find_spec('cupy') is not None: - n_jobs = 'cuda' - self.n_jobs = n_jobs - - def close(self) -> 'Raw': - if self.reader is not None: - logging.debug('Close Raw %s reader', self.id) - self.reader.close() - self.reader = None - return self - - def crop(self, offset: int, length: int) -> 'Raw': - logging.debug('Crop %s data to %s seconds from %s', self.id, length, offset) - tmax = self.open().n_times / self.open().info['sfreq'] - 0.1 - if offset + length < tmax: - tmax = offset + length - self.reader = self.open().crop(offset, tmax) - return self - - def open(self) -> Reader: - if self.reader is None: - if self.path.endswith('.edf'): - logging.debug('Open RawEDF %s reader', self.id) - try: - self.reader = read_raw_edf(self.path) - except RuntimeError: - logging.debug('Preload RawEDF %s reader', self.id) - self.reader = read_raw_edf(self.path, preload=True) - if self.path.endswith('.fif.gz'): - logging.debug('Open RawFIF %s reader', self.id) - self.reader = read_raw_fif(self.path) - return self.reader - - def get_events(self) -> List: - events = self.open().annotations - events = list(zip(events.onset, events.duration, events.description)) - events = [(event[0], event[0] + event[1], event[1], event[2]) for event in events] - keys = ['begin', 'end', 'duration', 'label'] - events = [dict(zip(keys, event)) for event in events] - return events - - def set_channels(self, channels: List[str]) -> 'Raw': - channels = set(self.open().ch_names) - set(channels) - channels = list(channels) - if len(channels) > 0: - logging.debug('Drop %s channels %s', self.id, '|'.join(channels)) - self.reader = self.open().drop_channels(channels) - return self - - def set_frequency(self, frequency: float) -> 'Raw': - sfreq = self.open().info['sfreq'] - if sfreq > frequency: - logging.debug('Downsample %s from %s Hz to %s Hz', self.id, sfreq, frequency) - self.reader = self.open().resample(frequency, n_jobs=self.n_jobs) - return self - - def set_filter(self, low_freq: float = None, high_freq: float = None) -> 'Raw': - if low_freq is not None or high_freq is not None: - logging.debug('Filter %s with low_req: %s Hz and high_freq: %s Hz', self.id, low_freq, high_freq) - self.reader = self.open().filter(low_freq, high_freq, n_jobs=self.n_jobs) - return self - - def notch_filter(self, freq: float) -> 'Raw': - if freq is not None: - logging.debug('Notch filter %s with freq: %s Hz ', self.id, freq) - self.reader = self.open().notch_filter(freq, n_jobs=self.n_jobs) - return self diff --git a/pyeeglab/pipeline/pipeline.py b/pyeeglab/pipeline/pipeline.py index f778b2e..ec082ff 100644 --- a/pyeeglab/pipeline/pipeline.py +++ b/pyeeglab/pipeline/pipeline.py @@ -1,18 +1,17 @@ -import logging import json - -import numpy as np -import pandas as pd +import logging from hashlib import md5 from os.path import join from multiprocessing import Pool, cpu_count -from ..io import Raw -from .preprocessor import Preprocessor - from typing import Dict, List +import numpy as np +import pandas as pd + +from .preprocessor import Preprocessor + class Pipeline(): @@ -32,11 +31,13 @@ def _check_nans(self, data): nans = data.isnull().values.any() return nans - def _trigger_pipeline(self, data: Raw, kwargs): - file_id = data.id - data.open().load_data() - for preprocessor in self.pipeline: - data = preprocessor.run(data, **kwargs) + def _trigger_pipeline(self, annotation, kwargs): + data = None + + with annotation as reader: + data = reader.load_data() + for preprocessor in self.pipeline: + data = preprocessor.run(data, **kwargs) nans = False if isinstance(data, list): @@ -44,11 +45,11 @@ def _trigger_pipeline(self, data: Raw, kwargs): else: nans = self._check_nans(data) if nans: - raise ValueError('Nans found in file with id {}'.format(file_id)) + raise ValueError('Nans found in file with id {}'.format(annotation.file_uuid)) return data - def run(self, data: List[Raw]) -> Dict: + def run(self, data) -> Dict: logging.debug('Environment variables: {}'.format( str(self.environment) )) @@ -62,7 +63,7 @@ def run(self, data: List[Raw]) -> Dict: labels = [self.labels_mapping[label] for label in labels] onehot_encoder = sorted(set(labels)) class_id = self.environment.get('class_id', None) - if class_id: + if class_id in onehot_encoder: onehot_encoder.remove(class_id) onehot_encoder = [class_id] + onehot_encoder labels = np.array([onehot_encoder.index(label) for label in labels]) diff --git a/pyeeglab/preprocess/features/brain_connectivity.py b/pyeeglab/preprocess/features/brain_connectivity.py index 96b96ac..ff3af4c 100644 --- a/pyeeglab/preprocess/features/brain_connectivity.py +++ b/pyeeglab/preprocess/features/brain_connectivity.py @@ -1,16 +1,15 @@ -import logging import json +import logging + +from itertools import product +from typing import List import numpy as np import pandas as pd - from yasa import bandpower -from itertools import product from ...pipeline import Preprocessor -from typing import List - class SpearmanCorrelation(Preprocessor): diff --git a/pyeeglab/preprocess/features/stat_features.py b/pyeeglab/preprocess/features/stat_features.py index 234ca03..5121fec 100644 --- a/pyeeglab/preprocess/features/stat_features.py +++ b/pyeeglab/preprocess/features/stat_features.py @@ -1,14 +1,11 @@ -import logging +from typing import List import numpy as np import pandas as pd - from scipy.integrate import simps from ...pipeline import Preprocessor -from typing import List - class Mean(Preprocessor): def run(self, data: List[pd.DataFrame], **kwargs) -> List[pd.DataFrame]: diff --git a/pyeeglab/preprocess/signal/channel_selector.py b/pyeeglab/preprocess/signal/channel_selector.py index 124ffaf..9f1fb4a 100644 --- a/pyeeglab/preprocess/signal/channel_selector.py +++ b/pyeeglab/preprocess/signal/channel_selector.py @@ -1,18 +1,18 @@ -import logging import json - -from ...io import Raw -from ...pipeline import Preprocessor +import logging from typing import List +from mne.io import Raw + +from ...pipeline import Preprocessor class CommonChannelSet(Preprocessor): - def __init__(self, blacklist: List[str] = []) -> None: + def __init__(self, blacklist: List[str] = None) -> None: super().__init__() logging.debug('Create common channels_set preprocessor') - self.blacklist = blacklist + self.blacklist = blacklist if blacklist else [] def to_json(self) -> str: out = { @@ -24,7 +24,9 @@ def to_json(self) -> str: return out def run(self, data: Raw, **kwargs) -> Raw: - channels = set(kwargs['channels_set']) - set(self.blacklist) - channels = list(channels) - data.set_channels(channels) + channels = set(data.ch_names) + channels = channels.difference(set(kwargs['channels_set'])) + data = data.drop_channels(channels) + data = data.drop_channels(self.blacklist) + data = data.reorder_channels(kwargs['channels_set']) return data diff --git a/pyeeglab/preprocess/signal/filter_selector.py b/pyeeglab/preprocess/signal/filter_selector.py index 33c7b15..2007338 100644 --- a/pyeeglab/preprocess/signal/filter_selector.py +++ b/pyeeglab/preprocess/signal/filter_selector.py @@ -1,7 +1,8 @@ -import logging import json +import logging + +from mne.io import Raw -from ...io import Raw from ...pipeline import Preprocessor @@ -24,8 +25,7 @@ def to_json(self) -> str: return out def run(self, data: Raw, **kwargs) -> Raw: - data.set_filter(self.low_freq, self.high_freq) - return data + return data.filter(self.low_freq, self.high_freq) class NotchFrequency(Preprocessor): @@ -45,5 +45,4 @@ def to_json(self) -> str: return out def run(self, data: Raw, **kwargs) -> Raw: - data.notch_filter(self.freq) - return data + return data.notch_filter(self.freq) diff --git a/pyeeglab/preprocess/signal/frequency_selector.py b/pyeeglab/preprocess/signal/frequency_selector.py index d3aca97..fc606c1 100644 --- a/pyeeglab/preprocess/signal/frequency_selector.py +++ b/pyeeglab/preprocess/signal/frequency_selector.py @@ -1,6 +1,7 @@ import logging -from ...io import Raw +from mne.io import Raw + from ...pipeline import Preprocessor @@ -11,5 +12,7 @@ def __init__(self) -> None: logging.debug('Create lowest_frequency preprocessor') def run(self, data: Raw, **kwargs) -> Raw: - data.set_frequency(kwargs['lowest_frequency']) + lowest_frequency = kwargs['lowest_frequency'] + if data.info['sfreq'] > lowest_frequency: + data = data.resample(lowest_frequency) return data diff --git a/pyeeglab/preprocess/signal/normalization.py b/pyeeglab/preprocess/signal/normalization.py index 7670ca2..11b07d4 100644 --- a/pyeeglab/preprocess/signal/normalization.py +++ b/pyeeglab/preprocess/signal/normalization.py @@ -1,12 +1,12 @@ import logging +from typing import List + import numpy as np import pandas as pd from ...pipeline import Preprocessor -from typing import List - class MinMaxNormalization(Preprocessor): def run(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame: diff --git a/pyeeglab/preprocess/transform/__init__.py b/pyeeglab/preprocess/transform/__init__.py index ec636cc..d28947a 100644 --- a/pyeeglab/preprocess/transform/__init__.py +++ b/pyeeglab/preprocess/transform/__init__.py @@ -1,3 +1,2 @@ from .data_converter import ToDataframe, ToNumpy, ToMergedDataframes, CorrelationToAdjacency from .frame_generator import StaticWindow, DynamicWindow, StaticWindowOverlap, DynamicWindowOverlap -from .graph_generator import GraphGenerator, GraphWithFeatures diff --git a/pyeeglab/preprocess/transform/data_converter.py b/pyeeglab/preprocess/transform/data_converter.py index 3e2124f..06f747a 100644 --- a/pyeeglab/preprocess/transform/data_converter.py +++ b/pyeeglab/preprocess/transform/data_converter.py @@ -1,13 +1,14 @@ -import logging import json +import logging + +from typing import List import numpy as np import pandas as pd -from ...io import Raw -from ...pipeline import Preprocessor +from mne.io import Raw -from typing import List +from ...pipeline import Preprocessor class ToDataframe(Preprocessor): @@ -17,8 +18,7 @@ def __init__(self) -> None: logging.debug('Create DataFrame converter preprocessor') def run(self, data: Raw, **kwargs) -> pd.DataFrame: - dataframe = data.open().to_data_frame().drop('time', axis=1) - data.close() + dataframe = data.to_data_frame().drop('time', axis=1) return dataframe diff --git a/pyeeglab/preprocess/transform/frame_generator.py b/pyeeglab/preprocess/transform/frame_generator.py index 9c10a20..9fd9ee1 100644 --- a/pyeeglab/preprocess/transform/frame_generator.py +++ b/pyeeglab/preprocess/transform/frame_generator.py @@ -1,13 +1,12 @@ -import logging import json - -import pandas as pd +import logging from math import floor +from typing import List -from ...pipeline import Preprocessor +import pandas as pd -from typing import List +from ...pipeline import Preprocessor class StaticWindow(Preprocessor): diff --git a/pyeeglab/preprocess/transform/graph_generator.py b/pyeeglab/preprocess/transform/graph_generator.py deleted file mode 100644 index 59339f7..0000000 --- a/pyeeglab/preprocess/transform/graph_generator.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging -import json - -import pandas as pd -import networkx as nx -from networkx.convert_matrix import from_pandas_edgelist - -from ...pipeline import Preprocessor - -from typing import List - - -class GraphGenerator(Preprocessor): - - def __init__(self) -> None: - super().__init__() - logging.debug('Create new graph generator preprocessor') - - def run(self, data: List[pd.DataFrame], **kwargs) -> List[nx.Graph]: - nodes = [set(d.From.to_list() + d.To.to_list()) for d in data] - edges = [d.where(d.Weight != 0).dropna().reset_index(drop=True) for d in data] - graphs = [] - for i in range(len(data)): - graph = from_pandas_edgelist(edges[i], 'From', 'To') - graph.add_nodes_from(nodes[i]) - graphs.append(graph) - return graphs - - -class GraphWithFeatures(GraphGenerator): - - def __init__(self): - super().__init__() - logging.debug('Create new graph with features preprocessor') - - def _run(self, adjacency: List[pd.DataFrame], features: List[pd.DataFrame], **kwargs) -> List[nx.Graph]: - graphs = super().run(adjacency, **kwargs) - for i, graph in enumerate(graphs): - feature = { - node: {'features': features[i].loc[node, :].to_numpy()} - for node in graph.nodes - } - nx.set_node_attributes(graph, feature) - return graphs - - def run(self, data, **kwargs) -> List[nx.Graph]: - return self._run(*data, **kwargs) diff --git a/requirements.txt b/requirements.txt index a02381a..b2628f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -mne==0.20.0 -networkx>=2.2 numpy scipy pandas -sqlalchemy>=1.2 wfdb +sqlalchemy +dataclasses +mne>=0.21.0 yasa>=0.1.6 diff --git a/setup.py b/setup.py index a12c589..d6153c7 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='PyEEGLab', - version='0.9.4', + version='0.10.0', author='Alessio Zanga', author_email='alessio.zanga@outlook.it', license='GNU GENERAL PUBLIC LICENSE - Version 3, 29 June 2007', @@ -32,13 +32,13 @@ url='https://github.com/AlessioZanga/PyEEGLab', packages=setuptools.find_packages(), install_requires=[ - 'mne==0.20.0', - 'networkx>=2.2', 'numpy', 'scipy', 'pandas', - 'sqlalchemy>=1.2', 'wfdb', + 'sqlalchemy', + 'dataclasses', + 'mne>=0.21.0', 'yasa>=0.1.6', ], ) diff --git a/tests/test_chbmit.py b/tests/test_chbmit.py index b4b0f48..69ec621 100644 --- a/tests/test_chbmit.py +++ b/tests/test_chbmit.py @@ -3,38 +3,29 @@ import unittest sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from pyeeglab import CHBMITLoader, CHBMITDataset, \ - Pipeline, CommonChannelSet, LowestFrequency, BandPassFrequency, ToDataframe, \ - DynamicWindow, ForkedPreprocessor, BinarizedSpearmanCorrelation, \ - CorrelationToAdjacency, Bandpower, GraphWithFeatures +from pyeeglab import * class TestCHBMIT(unittest.TestCase): - PATH = './tests/samples/physionet.org/files/chbmit/1.0.0' + PATH = './tests/samples/physionet.org/files/chbmit/' def test_index(self): - CHBMITLoader(self.PATH) + PhysioNetCHBMITDataset(self.PATH) def test_loader(self): - loader = CHBMITLoader(self.PATH) - loader.get_dataset() - loader.get_dataset_text() - loader.get_channels_set() - loader.get_lowest_frequency() + loader = PhysioNetCHBMITDataset(self.PATH) + loader.maximal_channels_subset + loader.lowest_frequency + loader.signal_min_max_range def test_dataset(self): - dataset = CHBMITDataset(self.PATH) + dataset = PhysioNetCHBMITDataset(self.PATH) preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), BandPassFrequency(0.1, 47), ToDataframe(), DynamicWindow(4), - ForkedPreprocessor( - inputs=[ - [BinarizedSpearmanCorrelation(), CorrelationToAdjacency()], - Bandpower() - ], - output=GraphWithFeatures() - ) + Skewness(), + ToNumpy() ]) dataset = dataset.set_pipeline(preprocessing).load() diff --git a/tests/test_eegmmidb.py b/tests/test_eegmmidb.py index 8b29b66..4d9248e 100644 --- a/tests/test_eegmmidb.py +++ b/tests/test_eegmmidb.py @@ -3,38 +3,29 @@ import unittest sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from pyeeglab import EEGMMIDBLoader, EEGMMIDBDataset, \ - Pipeline, CommonChannelSet, LowestFrequency, BandPassFrequency, ToDataframe, \ - DynamicWindow, ForkedPreprocessor, BinarizedSpearmanCorrelation, \ - CorrelationToAdjacency, Bandpower, GraphWithFeatures +from pyeeglab import * class TestEEGMMIDB(unittest.TestCase): - PATH = './tests/samples/physionet.org/files/eegmmidb/1.0.0' + PATH = './tests/samples/physionet.org/files/eegmmidb/' def test_index(self): - EEGMMIDBLoader(self.PATH) + PhysioNetEEGMMIDBDataset(self.PATH) def test_loader(self): - loader = EEGMMIDBLoader(self.PATH) - loader.get_dataset() - loader.get_dataset_text() - loader.get_channels_set() - loader.get_lowest_frequency() + loader = PhysioNetEEGMMIDBDataset(self.PATH) + loader.maximal_channels_subset + loader.lowest_frequency + loader.signal_min_max_range def test_dataset(self): - dataset = EEGMMIDBDataset(self.PATH) + dataset = PhysioNetEEGMMIDBDataset(self.PATH) preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), BandPassFrequency(0.1, 47), ToDataframe(), DynamicWindow(4), - ForkedPreprocessor( - inputs=[ - [BinarizedSpearmanCorrelation(), CorrelationToAdjacency()], - Bandpower() - ], - output=GraphWithFeatures() - ) + Skewness(), + ToNumpy() ]) dataset = dataset.set_pipeline(preprocessing).load() diff --git a/tests/test_tuh_eeg_abnormal.py b/tests/test_tuh_eeg_abnormal.py index 159f72a..a9e8bcb 100644 --- a/tests/test_tuh_eeg_abnormal.py +++ b/tests/test_tuh_eeg_abnormal.py @@ -3,23 +3,19 @@ import unittest sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from pyeeglab import TUHEEGAbnormalLoader, TUHEEGAbnormalDataset, \ - Pipeline, CommonChannelSet, LowestFrequency, BandPassFrequency, ToDataframe, \ - DynamicWindow, ForkedPreprocessor, BinarizedSpearmanCorrelation, \ - CorrelationToAdjacency, Bandpower, GraphWithFeatures +from pyeeglab import * class TestTUHEEGAbnormal(unittest.TestCase): - PATH = './tests/samples/tuh_eeg_abnormal/v2.0.0/edf' + PATH = './tests/samples/tuh_eeg_abnormal/' def test_index(self): - TUHEEGAbnormalLoader(self.PATH) + TUHEEGAbnormalDataset(self.PATH) def test_loader(self): - loader = TUHEEGAbnormalLoader(self.PATH) - loader.get_dataset() - loader.get_dataset_text() - loader.get_channels_set() - loader.get_lowest_frequency() + loader = TUHEEGAbnormalDataset(self.PATH) + loader.maximal_channels_subset + loader.lowest_frequency + loader.signal_min_max_range def test_dataset(self): dataset = TUHEEGAbnormalDataset(self.PATH) @@ -29,12 +25,7 @@ def test_dataset(self): BandPassFrequency(0.1, 47), ToDataframe(), DynamicWindow(4), - ForkedPreprocessor( - inputs=[ - [BinarizedSpearmanCorrelation(), CorrelationToAdjacency()], - Bandpower() - ], - output=GraphWithFeatures() - ) + Skewness(), + ToNumpy() ]) dataset = dataset.set_pipeline(preprocessing).load() diff --git a/tests/test_tuh_eeg_artifact.py b/tests/test_tuh_eeg_artifact.py index b03ecc1..595ca91 100644 --- a/tests/test_tuh_eeg_artifact.py +++ b/tests/test_tuh_eeg_artifact.py @@ -3,40 +3,29 @@ import unittest sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from pyeeglab import TUHEEGArtifactLoader, TUHEEGArtifactDataset, \ - Pipeline, CommonChannelSet, LowestFrequency, BandPassFrequency, ToDataframe, \ - DynamicWindow, ForkedPreprocessor, BinarizedSpearmanCorrelation, \ - CorrelationToAdjacency, Bandpower, GraphWithFeatures +from pyeeglab import * class TestTUHEEGArtifact(unittest.TestCase): - PATH = './tests/samples/tuh_eeg_artifact/v1.0.0/edf' + PATH = './tests/samples/tuh_eeg_artifact/' def test_index(self): - TUHEEGArtifactLoader(self.PATH) + TUHEEGArtifactDataset(self.PATH) def test_loader(self): - loader = TUHEEGArtifactLoader(self.PATH) - loader.get_dataset() - loader.get_dataset_text() - loader.get_channels_set() - loader.get_lowest_frequency() + loader = TUHEEGArtifactDataset(self.PATH) + loader.maximal_channels_subset + loader.lowest_frequency + loader.signal_min_max_range def test_dataset(self): dataset = TUHEEGArtifactDataset(self.PATH) - """ preprocessing = Pipeline([ - CommonChannelSet(['EEG T1-REF', 'EEG T2-REF']), + CommonChannelSet(), LowestFrequency(), BandPassFrequency(0.1, 47), ToDataframe(), DynamicWindow(4), - ForkedPreprocessor( - inputs=[ - [BinarizedSpearmanCorrelation(), CorrelationToAdjacency()], - Bandpower() - ], - output=GraphWithFeatures() - ) + Skewness(), + ToNumpy() ]) dataset = dataset.set_pipeline(preprocessing).load() - """