From 36177810c33c4ef5fad895d12318c719725663ac Mon Sep 17 00:00:00 2001 From: AlessioZanga Date: Mon, 28 Sep 2020 16:48:21 +0200 Subject: [PATCH] Refactor index, cache, log and preprocessing systems --- CHANGELOG.md | 25 ++ Makefile | 21 -- README.md | 17 +- examples/tensorboard/example_tensorboard.py | 12 +- .../tensorboard/example_tensorboard_cnn.py | 12 +- .../tensorboard/example_tensorboard_gat.py | 12 +- .../tensorboard/example_tensorboard_gnn.py | 12 +- examples/tensorboard/example_validation.py | 12 +- .../example_cnn_dense_classification.py | 5 +- ...le_cnn_dense_classification_with_filter.py | 5 +- .../example_cnn_lstm_classification.py | 5 +- ...ple_cnn_lstm_classification_with_filter.py | 5 +- .../example_gat_lstm_classification.py | 5 +- ...ple_gat_lstm_classification_with_filter.py | 5 +- pyeeglab/__init__.py | 8 +- pyeeglab/cache/__init__.py | 1 - pyeeglab/cache/cache.py | 65 ---- pyeeglab/database/__init__.py | 1 - pyeeglab/database/tables.py | 106 ------ pyeeglab/dataset/__init__.py | 7 +- pyeeglab/dataset/annotation.py | 33 ++ pyeeglab/dataset/chbmit/__init__.py | 2 - pyeeglab/dataset/chbmit/chbmit_dataset.py | 15 - pyeeglab/dataset/chbmit/chbmit_index.py | 57 ---- pyeeglab/dataset/chbmit/chbmit_loader.py | 20 -- pyeeglab/dataset/dataset.py | 301 ++++++++++++++++-- pyeeglab/dataset/declarative_base.py | 2 + pyeeglab/dataset/eegmmidb/__init__.py | 2 - pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py | 8 - pyeeglab/dataset/eegmmidb/eegmmidb_index.py | 23 -- pyeeglab/dataset/eegmmidb/eegmmidb_loader.py | 18 -- pyeeglab/dataset/file.py | 24 ++ pyeeglab/dataset/metadata.py | 15 + pyeeglab/dataset/physionet/__init__.py | 2 + pyeeglab/dataset/physionet/chbmit_dataset.py | 72 +++++ .../dataset/physionet/eegmmidb_dataset.py | 64 ++++ pyeeglab/dataset/physionet/utils.py | 30 ++ pyeeglab/dataset/tuh_eeg/__init__.py | 3 + pyeeglab/dataset/tuh_eeg/abnormal_dataset.py | 59 ++++ pyeeglab/dataset/tuh_eeg/artifact_dataset.py | 51 +++ pyeeglab/dataset/tuh_eeg/seizure_dataset.py | 57 ++++ pyeeglab/dataset/tuh_eeg/utils.py | 96 ++++++ pyeeglab/dataset/tuh_eeg_abnormal/__init__.py | 2 - .../tuh_eeg_abnormal_dataset.py | 17 - .../tuh_eeg_abnormal_index.py | 46 --- .../tuh_eeg_abnormal_loader.py | 14 - pyeeglab/dataset/tuh_eeg_artifact/__init__.py | 2 - .../tuh_eeg_artifact_dataset.py | 15 - .../tuh_eeg_artifact_index.py | 56 ---- .../tuh_eeg_artifact_loader.py | 17 - pyeeglab/dataset/tuh_eeg_seizure/__init__.py | 2 - .../tuh_eeg_seizure_dataset.py | 15 - .../tuh_eeg_seizure/tuh_eeg_seizure_index.py | 83 ----- .../tuh_eeg_seizure/tuh_eeg_seizure_loader.py | 14 - pyeeglab/io/__init__.py | 3 - pyeeglab/io/index.py | 139 -------- pyeeglab/io/loader.py | 153 --------- pyeeglab/io/raw.py | 83 ----- pyeeglab/pipeline/pipeline.py | 29 +- .../preprocess/features/brain_connectivity.py | 9 +- pyeeglab/preprocess/features/stat_features.py | 5 +- .../preprocess/signal/channel_selector.py | 20 +- pyeeglab/preprocess/signal/filter_selector.py | 11 +- .../preprocess/signal/frequency_selector.py | 7 +- pyeeglab/preprocess/signal/normalization.py | 4 +- .../preprocess/transform/data_converter.py | 12 +- .../preprocess/transform/frame_generator.py | 9 +- requirements.txt | 3 +- setup.py | 5 +- tests/test_chbmit.py | 15 +- tests/test_eegmmidb.py | 15 +- tests/test_tuh_eeg_abnormal.py | 13 +- tests/test_tuh_eeg_artifact.py | 15 +- 73 files changed, 940 insertions(+), 1188 deletions(-) delete mode 100755 Makefile delete mode 100644 pyeeglab/cache/__init__.py delete mode 100644 pyeeglab/cache/cache.py delete mode 100644 
pyeeglab/database/__init__.py delete mode 100644 pyeeglab/database/tables.py create mode 100644 pyeeglab/dataset/annotation.py delete mode 100644 pyeeglab/dataset/chbmit/__init__.py delete mode 100644 pyeeglab/dataset/chbmit/chbmit_dataset.py delete mode 100644 pyeeglab/dataset/chbmit/chbmit_index.py delete mode 100644 pyeeglab/dataset/chbmit/chbmit_loader.py create mode 100644 pyeeglab/dataset/declarative_base.py delete mode 100644 pyeeglab/dataset/eegmmidb/__init__.py delete mode 100644 pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py delete mode 100644 pyeeglab/dataset/eegmmidb/eegmmidb_index.py delete mode 100644 pyeeglab/dataset/eegmmidb/eegmmidb_loader.py create mode 100644 pyeeglab/dataset/file.py create mode 100644 pyeeglab/dataset/metadata.py create mode 100644 pyeeglab/dataset/physionet/__init__.py create mode 100644 pyeeglab/dataset/physionet/chbmit_dataset.py create mode 100644 pyeeglab/dataset/physionet/eegmmidb_dataset.py create mode 100644 pyeeglab/dataset/physionet/utils.py create mode 100644 pyeeglab/dataset/tuh_eeg/__init__.py create mode 100644 pyeeglab/dataset/tuh_eeg/abnormal_dataset.py create mode 100644 pyeeglab/dataset/tuh_eeg/artifact_dataset.py create mode 100644 pyeeglab/dataset/tuh_eeg/seizure_dataset.py create mode 100644 pyeeglab/dataset/tuh_eeg/utils.py delete mode 100644 pyeeglab/dataset/tuh_eeg_abnormal/__init__.py delete mode 100644 pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_dataset.py delete mode 100644 pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_index.py delete mode 100644 pyeeglab/dataset/tuh_eeg_abnormal/tuh_eeg_abnormal_loader.py delete mode 100644 pyeeglab/dataset/tuh_eeg_artifact/__init__.py delete mode 100644 pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py delete mode 100644 pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py delete mode 100644 pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py delete mode 100644 pyeeglab/dataset/tuh_eeg_seizure/__init__.py delete mode 100644 pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py delete mode 100644 pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py delete mode 100644 pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py delete mode 100644 pyeeglab/io/__init__.py delete mode 100644 pyeeglab/io/index.py delete mode 100644 pyeeglab/io/loader.py delete mode 100644 pyeeglab/io/raw.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cfbe08d..de57125 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.10.0] - 2020-09-28 +### Added +* Added new dataset versioning system: + you can work with multiple versions + of the same dataset during creation +* Added new in-Python download system: + you can download a specific version + of the dataset during initialization + +### Changed +* Refactored dataset indexing system: + a dedicated directory will be created + for the index database during initialization +* Refactored preprocessing cache system: + a dedicated directory will be created + and a faster cache coherence check + has been implemented +* Refactored logging system +* Updated package requirements +* Updated README + +### Removed +* Deprecated Makefile + + ## [0.9.3] - 2020-07-05 ### Change diff --git a/Makefile b/Makefile deleted file mode 100755 index 31a95c1..0000000 --- a/Makefile +++ /dev/null @@ -1,21 +0,0 @@ -.PHONY: tuh_eeg_abnormal tuh_eeg_artifact eegmmidb clean - -tuh_eeg_abnormal: - echo "Request your access password 
at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php" - rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_abnormal/ data/tuh_eeg_abnormal - -tuh_eeg_artifact: - echo "Request your access password at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php" - rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_artifact/v2.0.0 data/tuh_eeg_artifact/v2.0.0 - -tuh_eeg_seizure: - echo "Request your access password at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php" - rsync -auxvL nedc_tuh_eeg@www.isip.piconepress.com:~/data/tuh_eeg_seizure/v1.5.2 data/tuh_eeg_seizure - -eegmmidb: - wget -r -N -c -np https://physionet.org/files/eegmmidb/1.0.0/ -P data - -clean: - find -L data -iname "*.db" -type f -delete - find -L data -iname "*.log" -type f -delete - find -L data -iname "*.fif.gz" -type f -delete diff --git a/README.md b/README.md index 905d16e..205ed9f 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,6 @@ PyEEGLab is distributed using the pip repository: pip install PyEEGLab -If you use Python 3.6, the dataclasses package must be installed as backport of Python 3.7 dataclasses: - - pip install dataclasses - If you need a bleeding edge version, you can install it directly from GitHub: pip install git+https://github.com/AlessioZanga/PyEEGLab@develop @@ -99,11 +95,18 @@ The following datasets will work upon downloading: > **WARNING**: Retrieving the TUH EEG datasets requires valid credentials; you can get your own at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php. -In the root directory of this project there is a Makefile, by typing: +Given the dataset instance, trigger the download using the "download" method: + + from pyeeglab import * + dataset = TUHEEGAbnormalDataset() + dataset.download(user='USER', password='PASSWORD') + dataset.index() + +then index the newly downloaded files. - make tuh_eeg_abnormal +It should be noted that the download mechanism works on Unix-like systems, provided that the following packages are installed: -you will trigger the dataset download. 
+ sudo apt install sshpass rsync wget ## Documentation diff --git a/examples/tensorboard/example_tensorboard.py b/examples/tensorboard/example_tensorboard.py index 500b222..de8ae71 100755 --- a/examples/tensorboard/example_tensorboard.py +++ b/examples/tensorboard/example_tensorboard.py @@ -37,8 +37,6 @@ from pyeeglab import * def build_data(dataset): - dataset.set_cache_manager(PickleCache('../../export')) - preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), @@ -212,18 +210,18 @@ def tune_model(dataset_name, data): if __name__ == '__main__': dataset = {} - dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') + dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') - dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf') + dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) - dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf') + dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) - # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0') + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') # dataset['eegmmidb'].set_minimum_event_duration(4) - dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0') + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') dataset['chbmit'].set_minimum_event_duration(4) """ diff --git a/examples/tensorboard/example_tensorboard_cnn.py b/examples/tensorboard/example_tensorboard_cnn.py index fac908f..b345afc 100755 --- a/examples/tensorboard/example_tensorboard_cnn.py +++ b/examples/tensorboard/example_tensorboard_cnn.py @@ -39,8 +39,6 @@ from pyeeglab import * def build_data(dataset): - dataset.set_cache_manager(PickleCache('../../export')) - preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), @@ -201,18 +199,18 @@ def tune_model(dataset_name, data): if __name__ == '__main__': dataset = {} - # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') + # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') - # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf') + # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') # dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) - dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf') + dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) - # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0') + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') # dataset['eegmmidb'].set_minimum_event_duration(4) - dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0') + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') dataset['chbmit'].set_minimum_event_duration(4) """ diff --git a/examples/tensorboard/example_tensorboard_gat.py b/examples/tensorboard/example_tensorboard_gat.py index 704ade2..f74c1dc 100755 --- a/examples/tensorboard/example_tensorboard_gat.py +++ 
b/examples/tensorboard/example_tensorboard_gat.py @@ -37,8 +37,6 @@ from pyeeglab import * def build_data(dataset): - dataset.set_cache_manager(PickleCache('../../export')) - preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), @@ -206,18 +204,18 @@ def tune_model(dataset_name, data): if __name__ == '__main__': dataset = {} - dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') + dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') - dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf') + dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) - dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf') + dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) - # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0') + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') # dataset['eegmmidb'].set_minimum_event_duration(4) - dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0') + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') dataset['chbmit'].set_minimum_event_duration(4) """ diff --git a/examples/tensorboard/example_tensorboard_gnn.py b/examples/tensorboard/example_tensorboard_gnn.py index e38034e..40c9977 100755 --- a/examples/tensorboard/example_tensorboard_gnn.py +++ b/examples/tensorboard/example_tensorboard_gnn.py @@ -39,8 +39,6 @@ from pyeeglab import * def build_data(dataset): - dataset.set_cache_manager(PickleCache('../../export')) - preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), @@ -208,18 +206,18 @@ def tune_model(dataset_name, data): if __name__ == '__main__': dataset = {} - # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') + # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') - # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf') + # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') # dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) - dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf') + dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) - # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0') + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') # dataset['eegmmidb'].set_minimum_event_duration(4) - dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0') + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') dataset['chbmit'].set_minimum_event_duration(4) """ diff --git a/examples/tensorboard/example_validation.py b/examples/tensorboard/example_validation.py index 242f55b..cc906f7 100755 --- a/examples/tensorboard/example_validation.py +++ b/examples/tensorboard/example_validation.py @@ -38,8 +38,6 @@ from pyeeglab import * def build_data(dataset): - dataset.set_cache_manager(PickleCache('../../export')) - preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), @@ -220,18 
+218,18 @@ def tune_model(dataset_name, data, metrics, folds, val_size=0.1): if __name__ == '__main__': dataset = {} - # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') + # dataset['tuh_eeg_abnormal'] = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') - # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/v1.0.0/edf') + # dataset['tuh_eeg_artifact'] = TUHEEGArtifactDataset('../../data/tuh_eeg_artifact/') # dataset['tuh_eeg_artifact'].set_minimum_event_duration(4) - # dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/v1.5.2/edf') + # dataset['tuh_eeg_seizure'] = TUHEEGSeizureDataset('../../data/tuh_eeg_seizure/') # dataset['tuh_eeg_seizure'].set_minimum_event_duration(4) - # dataset['eegmmidb'] = EEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/1.0.0') + # dataset['eegmmidb'] = PhysioNetEEGMMIDBDataset('../../data/physionet.org/files/eegmmidb/') # dataset['eegmmidb'].set_minimum_event_duration(4) - dataset['chbmit'] = CHBMITDataset('../../data/physionet.org/files/chbmit/1.0.0') + dataset['chbmit'] = PhysioNetCHBMITDataset('../../data/physionet.org/files/chbmit/') dataset['chbmit'].set_minimum_event_duration(4) """ diff --git a/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py b/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py index 392d04f..f813d4f 100755 --- a/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py +++ b/examples/tuh_eeg_abnormal/example_cnn_dense_classification.py @@ -21,12 +21,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \ ToNumpy -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py index 3455490..59674b2 100755 --- a/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py +++ b/examples/tuh_eeg_abnormal/example_cnn_dense_classification_with_filter.py @@ -21,12 +21,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \ BinarizedSpearmanCorrelation, ToNumpy -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py index 38407f1..e553bd3 100755 --- a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py +++ b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification.py @@ -21,12 +21,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 
'../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \ ToNumpy -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py index ea7dab1..faee56a 100755 --- a/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py +++ b/examples/tuh_eeg_abnormal/example_cnn_lstm_classification_with_filter.py @@ -21,12 +21,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \ BinarizedSpearmanCorrelation, ToNumpy -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py b/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py index 3f85e02..cb6e7be 100755 --- a/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py +++ b/examples/tuh_eeg_abnormal/example_gat_lstm_classification.py @@ -23,12 +23,11 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, ToDataframe, DynamicWindow, BinarizedSpearmanCorrelation, \ CorrelationToAdjacency, Bandpower, GraphWithFeatures, ForkedPreprocessor -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py b/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py index 5bd0c3f..859ff78 100755 --- a/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py +++ b/examples/tuh_eeg_abnormal/example_gat_lstm_classification_with_filter.py @@ -23,13 +23,12 @@ import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) -from pyeeglab import TUHEEGAbnormalDataset, PickleCache, Pipeline, CommonChannelSet, \ +from pyeeglab import TUHEEGAbnormalDataset, Pipeline, CommonChannelSet, \ LowestFrequency, BandPassFrequency, ToDataframe, DynamicWindow, \ BinarizedSpearmanCorrelation, CorrelationToAdjacency, Bandpower, \ GraphWithFeatures, ForkedPreprocessor -dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/v2.0.0/edf') -dataset.set_cache_manager(PickleCache('../../export')) +dataset = TUHEEGAbnormalDataset('../../data/tuh_eeg_abnormal/') preprocessing = Pipeline([ CommonChannelSet(), diff --git a/pyeeglab/__init__.py b/pyeeglab/__init__.py 
index cce3be6..df80b73 100644 --- a/pyeeglab/__init__.py +++ b/pyeeglab/__init__.py @@ -1,15 +1,11 @@ import logging import warnings -from importlib.util import find_spec -from mne.utils import set_config +logging.basicConfig(format="%(asctime)s %(levelname)7s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S") from .dataset import * -from .io import Raw -from .cache import PickleCache -from .pipeline import Pipeline, ForkedPreprocessor +from .pipeline import * from .preprocess import * logging.getLogger().setLevel(logging.DEBUG) - warnings.filterwarnings("ignore", category=RuntimeWarning) diff --git a/pyeeglab/cache/__init__.py b/pyeeglab/cache/__init__.py deleted file mode 100644 index eb3bba2..0000000 --- a/pyeeglab/cache/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .cache import Cache, PickleCache diff --git a/pyeeglab/cache/cache.py b/pyeeglab/cache/cache.py deleted file mode 100644 index 202b9bd..0000000 --- a/pyeeglab/cache/cache.py +++ /dev/null @@ -1,65 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod -from typing import List, Dict - -from os.path import isfile, join -from pathlib import Path -from hashlib import md5 -from pickle import load, dump - -from ..io import DataLoader -from ..pipeline import Pipeline - -class Cache(ABC): - - def __init__(self) -> None: - logging.debug('Create cache manager') - - def _get_cache_key(self, dataset: str, loader: DataLoader, pipeline: Pipeline) -> str: - if dataset.endswith('dataset'): - dataset = dataset[:-7] - key = [loader, pipeline] - key = [hash(k) for k in key] - key = [str(k).encode() for k in key] - key = [md5(k).hexdigest()[:10] for k in key] - key = list(zip(['loader', 'pipeline'], key)) - key = ['_'.join(k) for k in key] - key = dataset + '_' + '_'.join(key) - return key - - @abstractmethod - def load(self, dataset: str, loader: DataLoader, pipeline: Pipeline) -> Dict: - pass - - -class PickleCache(Cache): - - def __init__(self, path: str): - super().__init__() - logging.debug('Create single pickle cache manager') - Path(path).mkdir(parents=True, exist_ok=True) - self.path = path - - def load(self, dataset: str, loader: DataLoader, pipeline: Pipeline): - logging.debug('Computing cache key') - key = self._get_cache_key(dataset, loader, pipeline) - logging.debug('Computed cache key: %s', key) - key = join(self.path, key + '.pkl') - if isfile(key): - logging.debug('Cache file found') - with open(key, 'rb') as file: - try: - logging.debug('Loading cache file') - data = load(file) - return data - except: - logging.debug('Loading cache file failed') - pass - logging.debug('Cache file not found, genereting new one') - data = loader.get_dataset() - data = pipeline.run(data) - with open(key, 'wb') as file: - logging.debug('Dumping cache file') - dump(data, file) - return data diff --git a/pyeeglab/database/__init__.py b/pyeeglab/database/__init__.py deleted file mode 100644 index fdcd6bb..0000000 --- a/pyeeglab/database/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .tables import * diff --git a/pyeeglab/database/tables.py b/pyeeglab/database/tables.py deleted file mode 100644 index a5fd22b..0000000 --- a/pyeeglab/database/tables.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging - -from dataclasses import dataclass - -from sqlalchemy import Column, Integer, Float, Text, ForeignKey -from sqlalchemy.ext.declarative import declarative_base - -from ..io.raw import Raw - - -BASE_TABLE = declarative_base() - - -@dataclass -class File(BASE_TABLE): - """File represents a single file contained in the dataset. 
- - This is an ORM class derived from the BASE_TABLE in a declarative base - used by SQLAlchemy. - - Attributes - ---------- - id : str - The primary key generated randomly using an UUID4 generator. - extension : str - A not null indexed string reporting the EEG recording format. - path : str - A not null string used to point to the relative path of the file respect - to the current sqlite database location. - """ - __tablename__ = 'file' - id: str = Column(Text, primary_key=True) - extension: str = Column(Text, nullable=False, index=True) - path: str = Column(Text, nullable=False) - - -@dataclass -class Metadata(BASE_TABLE): - """ Metadata represents a single metadata record associated with a single - file contained in the dataset. - - This is an ORM class derived from the BASE_TABLE in a declarative base - used by SQLAlchemy. - - Attributes - ---------- - file_id : str - The foreign key related to file_id, used also as a primary key for metadata - table since this is a one-to-one relationship. - duration : int - This is the EEG sample duration reported in seconds. This is not inteded as - a precise duration estimate, but only a reference for statistical analysis. - For more precise duration measurement, please, use the Raw record class methods. - channels_count : int - This field report the number of channels reported in the EEG header. - channels_reference : str - An indexed string describing the EEG channel reference system. - channels_set : str - The list of channels saved as a JSON string extracted from the EEG header. - sampling_frequency : int - The sampling fequency expressend in Hz extrated from the EEG header. - max_value : float - The max value sampled in this record across all channels. - min_value : float - The min value sampled in this record across all channels. - """ - __tablename__ = 'metadata' - file_id: str = Column(Text, ForeignKey('file.id'), primary_key=True) - duration: int = Column(Integer, nullable=False) - channels_count: int = Column(Integer, nullable=False) - channels_reference: str = Column(Text, nullable=True, index=True) - channels_set: str = Column(Text, nullable=False, index=True) - sampling_frequency: int = Column(Integer, nullable=False, index=True) - max_value: float = Column(Float, nullable=False) - min_value: float = Column(Float, nullable=False) - - -@dataclass -class Event(BASE_TABLE): - """Event represents a single event associated to a single file. - Multiple events can be associated to multiple files, according to - the record annotations. - - Attributes - ---------- - id : str - This is the primary key that is associated to each event. It is - created from a UUID4. - file_id : str - This the foreign key that link the event to the related file. - begin : float - The timestep expressed in seconds from which the event begins. - end : float - The timestep expressed in seconds at which the event ends. - duration : float - The entire duration of the event. - label : str - The label of this event as written inthe annotations. 
- """ - __tablename__ = 'event' - id: str = Column(Text, primary_key=True) - file_id: str = Column(Text, ForeignKey('file.id'), index=True) - begin: float = Column(Float, nullable=False) - end: float = Column(Float, nullable=False) - duration: float = Column(Float, nullable=False) - label: str = Column(Text, nullable=False, index=True) diff --git a/pyeeglab/dataset/__init__.py b/pyeeglab/dataset/__init__.py index 6b1dac2..3880a61 100644 --- a/pyeeglab/dataset/__init__.py +++ b/pyeeglab/dataset/__init__.py @@ -1,5 +1,2 @@ -from .tuh_eeg_abnormal import TUHEEGAbnormalLoader, TUHEEGAbnormalDataset -from .tuh_eeg_artifact import TUHEEGArtifactLoader, TUHEEGArtifactDataset -from .tuh_eeg_seizure import TUHEEGSeizureLoader, TUHEEGSeizureDataset -from .eegmmidb import EEGMMIDBLoader, EEGMMIDBDataset -from .chbmit import CHBMITLoader, CHBMITDataset +from .physionet import * +from .tuh_eeg import * diff --git a/pyeeglab/dataset/annotation.py b/pyeeglab/dataset/annotation.py new file mode 100644 index 0000000..db8aee8 --- /dev/null +++ b/pyeeglab/dataset/annotation.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from mne.io import Raw, read_raw +from sqlalchemy import Column, ForeignKey, Text, Float +from sqlalchemy.orm import relationship +from .declarative_base import Base + + +@dataclass +class Annotation(Base): + __tablename__ = "annotation" + uuid: str = Column(Text, primary_key=True) + file_uuid: str = Column(Text, ForeignKey("file.uuid"), nullable=False) + begin: float = Column(Float, nullable=False) + end: float = Column(Float, nullable=False) + label: str = Column(Text, nullable=False, index=True) + + file = relationship("File", lazy="subquery") + + @property + def duration(self) -> float: + return self.end - self.begin + + def __enter__(self) -> Raw: + self.reader = read_raw(self.file.path) + tmax = self.reader.n_times / self.reader.info["sfreq"] - 0.1 + tmax = tmax if self.end > tmax else self.end + self.reader.crop(self.begin, tmax) + return self.reader + + def __exit__(self, *args, **kwargs) -> None: + self.reader.close() + del self.reader + diff --git a/pyeeglab/dataset/chbmit/__init__.py b/pyeeglab/dataset/chbmit/__init__.py deleted file mode 100644 index a592069..0000000 --- a/pyeeglab/dataset/chbmit/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .chbmit_loader import CHBMITLoader -from .chbmit_dataset import CHBMITDataset diff --git a/pyeeglab/dataset/chbmit/chbmit_dataset.py b/pyeeglab/dataset/chbmit/chbmit_dataset.py deleted file mode 100644 index 4692ae1..0000000 --- a/pyeeglab/dataset/chbmit/chbmit_dataset.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Dict - -from .chbmit_loader import CHBMITLoader -from ..dataset import Dataset - - -class CHBMITDataset(Dataset): - - def __init__(self, path: str = './data/physionet.org/files/chbmit/1.0.0/') -> None: - super().__init__(CHBMITLoader(path)) - - def _get_dataset_env(self) -> Dict: - env = super()._get_dataset_env() - env['class_id'] = 'noseizure' - return env diff --git a/pyeeglab/dataset/chbmit/chbmit_index.py b/pyeeglab/dataset/chbmit/chbmit_index.py deleted file mode 100644 index 05a1df9..0000000 --- a/pyeeglab/dataset/chbmit/chbmit_index.py +++ /dev/null @@ -1,57 +0,0 @@ -import logging - -from uuid import uuid4, uuid5, NAMESPACE_X500 -from os.path import isfile, join, sep -from wfdb import rdann - -from ...io import Index -from ...database import File, Event - -from typing import List - - -class CHBMITIndex(Index): - def __init__(self, path: str) -> None: - logging.debug('Create CHB-MIT Scalp EEG Index') - 
super().__init__('sqlite:///' + join(path, 'index.db'), path) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:], - ) - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - path = join(self.path, file.path) - if isfile(path + '.seizures'): - events = rdann(path, 'seizures') - events = list(events.sample / events.fs) - events = [events[i:i+2] for i in range(0, len(events), 2)] - events = [ - Event( - id=str(uuid4()), - file_id=file.id, - begin=event[0], - end=event[1], - duration=(event[1] - event[0]), - label='seizure' - ) - for event in events - ] - else: - events = [ - Event( - id=str(uuid4()), - file_id=file.id, - begin=60, - end=120, - duration=60, - label='noseizure' - ) - ] - return events diff --git a/pyeeglab/dataset/chbmit/chbmit_loader.py b/pyeeglab/dataset/chbmit/chbmit_loader.py deleted file mode 100644 index 022c057..0000000 --- a/pyeeglab/dataset/chbmit/chbmit_loader.py +++ /dev/null @@ -1,20 +0,0 @@ -import logging - -from typing import List - -from ...io import DataLoader -from .chbmit_index import CHBMITIndex - - -class CHBMITLoader(DataLoader): - - def __init__(self, path: str) -> None: - exclude_files = [ - 'chb03/chb03_35.edf', # Corrupted data - 'chb12/chb12_27.edf', # Bad channel names - 'chb12/chb12_28.edf', # Bad channel names - 'chb12/chb12_29.edf', # Bad channel names - ] - super().__init__(path, exclude_files=exclude_files) - logging.debug('Create CHB-MIT Scalp EEG Loader') - self.index = CHBMITIndex(self.path) diff --git a/pyeeglab/dataset/dataset.py b/pyeeglab/dataset/dataset.py index 692b3a2..fc67639 100644 --- a/pyeeglab/dataset/dataset.py +++ b/pyeeglab/dataset/dataset.py @@ -1,45 +1,286 @@ +import os +import json import logging -from abc import ABC -from typing import Dict +import hashlib +import pickle + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import reduce +from multiprocessing import Pool, cpu_count +from operator import add, and_ +from uuid import uuid4, uuid5, NAMESPACE_X500 + +from typing import Dict, List, Tuple + +import mne +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker, Query + +from .declarative_base import Base +from .file import File +from .metadata import Metadata +from .annotation import Annotation -from ..io import DataLoader -from ..cache import Cache, PickleCache from ..pipeline import Pipeline + +@dataclass(init=False) class Dataset(ABC): + path: str + name: str + version: str - cache: Cache - loader: DataLoader - pipeline: Pipeline + extensions: List[str] + exclude_file: List[str] + exclude_channels_set: List[str] + exclude_channels_reference: List[str] + exclude_sampling_frequency: List[int] + minimum_annotation_duration: float - def __init__(self, loader: DataLoader) -> None: - logging.debug('Create dataset') - self.loader = loader - self.set_cache_manager(PickleCache('export')) - self.set_pipeline(Pipeline()) + session: Session + query: Query - def load(self) -> Dict: - dataset = self.__class__.__name__.lower() - return self.cache.load(dataset, self.loader, self.pipeline) + pipeline: Pipeline = None + + def __init__( + self, + path: str, + name: str, + version: str = None, + extensions: List[str] = [".edf"], + exclude_file: List[str] = None, + exclude_channels_set: List[str] = None, + 
exclude_channels_reference: List[str] = None, + exclude_sampling_frequency: List[int] = None, + minimum_annotation_duration: float = None + ) -> None: + # Set basic attributes + self.path = os.path.abspath(os.path.join(path, version)) + self.name = name + self.version = version - def _get_dataset_env(self) -> Dict: + # Set data set filter attributes + self.extensions = extensions if extensions else [] + self.exclude_file = exclude_file if exclude_file else [] + self.exclude_channels_set = exclude_channels_set if exclude_channels_set else [] + self.exclude_channels_reference = exclude_channels_reference if exclude_channels_reference else [] + self.exclude_sampling_frequency = exclude_sampling_frequency if exclude_sampling_frequency else [] + self.minimum_annotation_duration = minimum_annotation_duration if minimum_annotation_duration else 0 + + logging.info("Init dataset '%s'@'%s' at '%s'", self.name, self.version, self.path) + + # Make workspace directory + logging.debug("Make .pyeeglab directory") + workspace = os.path.join(self.path, ".pyeeglab") + os.makedirs(workspace, exist_ok=True) + logging.debug("Make .pyeeglab/cache directory") + os.makedirs(os.path.join(workspace, "cache"), exist_ok=True) + logging.debug("Set MNE log .pyeeglab/mne.log") + mne.set_log_file(os.path.join(workspace, "mne.log"), overwrite=False) + + # Index data set files + self.index() + + def __getstate__(self): + # Workaround for unpicklable sqlalchemy.orm.session + # during multiprocess dataset loading + state = self.__dict__.copy() + for attribute in ["session", "query"]: + if hasattr(self, attribute): + del state[attribute] + return state + + @abstractmethod + def download(self, user: str = None, password: str = None) -> None: + pass + + def index(self) -> None: + # Init index session + logging.debug("Make index session") + connection = os.path.join(self.path, ".pyeeglab", "index.sqlite3") + connection = create_engine("sqlite:///" + connection) + Base.metadata.create_all(connection) + self.session = sessionmaker(bind=connection)() + # Open multiprocess pool + logging.info("Indexing data set directory") + pool = Pool(cpu_count()) + # Get file paths from data set path + paths = [ + os.path.join(directory, filename) + for directory, _, filenames in os.walk(self.path) + for filename in filenames + ] + # Get File instances from paths, filtering already indexed ones + files = self.session.query(File).all() + files = [file.uuid for file in files] + files = [ + file + for file in pool.map(self._get_file, paths) + if file.uuid not in files + ] + for file in files: + logging.debug("Add file %s to index", file.uuid) + # Filter raw data files by extension + raws = [ + file + for file in files + if os.path.splitext(file.path)[-1] in self.extensions + ] + # Get metadata and annotations for data files + metadatas = pool.map(self._get_metadata, raws) + annotations = pool.map(self._get_annotation, raws) + # Close multiprocess pool + pool.close() + pool.join() + # Commit insertions to index + commits = files + metadatas + reduce(add, annotations, []) + if commits: + logging.info("Commit insertions to index") + self.session.add_all(commits) + self.session.commit() + logging.info("Data set indexing completed") + # Init default query + logging.debug("Init default query") + self.query = self.session.query(File, Metadata, Annotation).\ + join(File.meta).\ + join(File.annotations).\ + filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)).\ + filter(~Metadata.sampling_frequency.in_(self.exclude_sampling_frequency)).\ + filter((Annotation.end - Annotation.begin) >= self.minimum_annotation_duration) + # Filter excluded file paths + for file in self.exclude_file: + self.query = self.query.filter(~File.path.like("%{}%".format(file))) + logging.debug("SQL query representation: '%s'", str(self.query).replace("\n", "")) + + def _get_file(self, path: str) -> File: + return File( + uuid=str(uuid5(NAMESPACE_X500, path)), + path=path, + extension=os.path.splitext(path)[-1] + ) + + def _get_metadata(self, file: File) -> Metadata: + logging.debug("Add file %s metadata to index", file.uuid) + with file as reader: + info = reader.info + metadata = Metadata( + file_uuid=file.uuid, + duration=reader.n_times/info["sfreq"], + channels_set=json.dumps(info["ch_names"]), + sampling_frequency=info["sfreq"], + max_value=reader.get_data().max(), + min_value=reader.get_data().min(), + ) + return metadata + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + with file as reader: + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=annotation[0], + end=annotation[0]+annotation[1], + label=annotation[2], + ) + for annotation in reader.annotations + ] + return annotations + + @property + def environment(self) -> Dict: + min_max = self.signal_min_max_range return { - 'channels_set': self.loader.get_channels_set(), - 'lowest_frequency': self.loader.get_lowest_frequency(), - 'max_value': self.loader.get_max_value(), - 'min_value': self.loader.get_min_value(), + "channels_set": self.maximal_channels_subset, + "lowest_frequency": self.lowest_frequency, + "min_value": min_max[0], + "max_value": min_max[1], } - def set_minimum_event_duration(self, minimum_event_duration: float) -> 'Dataset': - self.loader.set_minimum_event_duration(minimum_event_duration) - return self + @property + def lowest_frequency(self) -> float: + frequency = self.query.all() + frequency = min([ + f[1].sampling_frequency + for f in frequency + ], default=0) + return frequency - def set_cache_manager(self, cache: Cache) -> 'Dataset': - self.cache = cache - return self - - def set_pipeline(self, pipeline: Pipeline) -> 'Dataset': + @property + def maximal_channels_subset(self) -> List[str]: + channels = self.query.group_by(Metadata.channels_set).all() + channels = [ + frozenset(json.loads(channel[1].channels_set)) + for channel in channels + ] + channels = reduce(and_, channels) + channels = channels - frozenset(self.exclude_channels_set) + channels = sorted(channels) + return channels + + @property + def signal_min_max_range(self) -> Tuple[float, float]: + min_max = self.query.all() + min_max = [m[1] for m in min_max] + min_max = tuple([ + min([m.min_value for m in min_max], default=0), + max([m.max_value for m in min_max], default=0), + ]) + return min_max + + def set_pipeline(self, pipeline: Pipeline) -> "Dataset": self.pipeline = pipeline - environment = self._get_dataset_env() - self.pipeline.environment.update(environment) + self.pipeline.environment.update(self.environment) + return self + + def set_minimum_event_duration(self, duration: float) -> "Dataset": + logging.warning("This function will be deprecated in the near future") + self.minimum_annotation_duration = duration return self + + def load(self) -> Dict: + # Compute cache path + cache = os.path.join(self.path, ".pyeeglab", "cache") + # Compute cache key + logging.info("Computing cache key") + name = self.__class__.__name__.lower() + if name.endswith("dataset"): + name = name[:-len("dataset")] + key = [hash(self), hash(self.pipeline)] + key = [str(k).encode() for k in key] + key = [hashlib.md5(k).hexdigest()[:10] for k in key] + key = list(zip(["loader", "pipeline"], key)) + key = ["_".join(k) for k in key] + key = name + "_" + "_".join(key) + logging.info("Computed cache key: %s", key) + # Load file cache + cache = os.path.join(cache, key + ".pkl") + if os.path.exists(cache): + logging.info("Cache file found at %s", cache) + with open(cache, "rb") as reader: + try: + logging.info("Loading cache file") + return pickle.load(reader) + except Exception: + logging.error("Loading cache file failed") + # Cache file not found, start preprocessing + logging.info("Cache file not found, generating a new one") + data = [row[2] for row in self.query.all()] + data = self.pipeline.run(data) + with open(cache, "wb") as file: + logging.info("Dumping cache file") + pickle.dump(data, file) + return data + + def __hash__(self) -> int: + key = [self.path, self.version, self.minimum_annotation_duration] + key += self.exclude_file + key += self.exclude_channels_set + key += self.exclude_channels_reference + key += self.exclude_sampling_frequency + key = json.dumps(key).encode() + key = hashlib.md5(key).hexdigest() + key = int(key, 16) + return key diff --git a/pyeeglab/dataset/declarative_base.py b/pyeeglab/dataset/declarative_base.py new file mode 100644 index 0000000..c64447d --- /dev/null +++ b/pyeeglab/dataset/declarative_base.py @@ -0,0 +1,2 @@ +from sqlalchemy.ext.declarative import declarative_base +Base = declarative_base() diff --git a/pyeeglab/dataset/eegmmidb/__init__.py b/pyeeglab/dataset/eegmmidb/__init__.py deleted file mode 100644 index 4cbd826..0000000 --- a/pyeeglab/dataset/eegmmidb/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .eegmmidb_loader import EEGMMIDBLoader -from .eegmmidb_dataset import EEGMMIDBDataset diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py b/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py deleted file mode 100644 index a5aa15d..0000000 --- a/pyeeglab/dataset/eegmmidb/eegmmidb_dataset.py +++ /dev/null @@ -1,8 +0,0 @@ -from .eegmmidb_loader import EEGMMIDBLoader -from ..dataset import Dataset - - -class EEGMMIDBDataset(Dataset): - - def __init__(self, path: str = './data/physionet.org/files/eegmmidb/1.0.0/') -> None: - super().__init__(EEGMMIDBLoader(path)) diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_index.py b/pyeeglab/dataset/eegmmidb/eegmmidb_index.py deleted file mode 100644 index 83af529..0000000 --- a/pyeeglab/dataset/eegmmidb/eegmmidb_index.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging - -from uuid import uuid5, NAMESPACE_X500 -from os.path import join, sep - -from ...io import Index -from ...database import File - - -class EEGMMIDBIndex(Index): - def __init__(self, path: str) -> None: - logging.debug('Create EEG Motor Movement/Imagery Index') - super().__init__('sqlite:///' + join(path, 'index.db'), path) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:] - ) diff --git a/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py b/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py deleted file mode 100644 index b9bde7b..0000000 --- a/pyeeglab/dataset/eegmmidb/eegmmidb_loader.py +++ /dev/null @@ -1,18 +0,0 @@ -import logging - -from typing import List - -from ...io import DataLoader -from .eegmmidb_index import EEGMMIDBIndex - - -class EEGMMIDBLoader(DataLoader): - - def __init__(self, path: 
str, exclude_frequency: List[int] = [128]) -> None: - exclude_files = [ - 'S021/S021R08.edf', # Corrupted data - 'S104/S104R04.edf', # Corrupted data - ] - super().__init__(path, exclude_files = exclude_files, exclude_frequency = exclude_frequency) - logging.debug('Create EEG Motor Movement/Imagery Loader') - self.index = EEGMMIDBIndex(self.path) diff --git a/pyeeglab/dataset/file.py b/pyeeglab/dataset/file.py new file mode 100644 index 0000000..e4cbbb1 --- /dev/null +++ b/pyeeglab/dataset/file.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from mne.io import Raw, read_raw +from sqlalchemy import Column, Text +from sqlalchemy.orm import relationship +from .declarative_base import Base + + +@dataclass +class File(Base): + __tablename__ = "file" + uuid: str = Column(Text, primary_key=True) + path: str = Column(Text, nullable=False) + extension: str = Column(Text, nullable=False) + + meta = relationship("Metadata", cascade="all,delete", backref="File") + annotations = relationship("Annotation", cascade="all,delete", backref="File") + + def __enter__(self) -> Raw: + self.reader = read_raw(self.path) + return self.reader + + def __exit__(self, *args, **kwargs) -> None: + self.reader.close() + del self.reader diff --git a/pyeeglab/dataset/metadata.py b/pyeeglab/dataset/metadata.py new file mode 100644 index 0000000..fb770e4 --- /dev/null +++ b/pyeeglab/dataset/metadata.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from sqlalchemy import Column, ForeignKey, Text, Float, Integer +from .declarative_base import Base + + +@dataclass +class Metadata(Base): + __tablename__ = "metadata" + file_uuid: str = Column(Text, ForeignKey("file.uuid"), primary_key=True) + duration: int = Column(Float, nullable=False) + channels_set: str = Column(Text, nullable=False, index=True) + channels_reference: str = Column(Text, nullable=True, index=True) + sampling_frequency: int = Column(Integer, nullable=False, index=True) + max_value: float = Column(Float, nullable=False) + min_value: float = Column(Float, nullable=False) diff --git a/pyeeglab/dataset/physionet/__init__.py b/pyeeglab/dataset/physionet/__init__.py new file mode 100644 index 0000000..3e3a51d --- /dev/null +++ b/pyeeglab/dataset/physionet/__init__.py @@ -0,0 +1,2 @@ +from .chbmit_dataset import * +from .eegmmidb_dataset import * diff --git a/pyeeglab/dataset/physionet/chbmit_dataset.py b/pyeeglab/dataset/physionet/chbmit_dataset.py new file mode 100644 index 0000000..8fe88ea --- /dev/null +++ b/pyeeglab/dataset/physionet/chbmit_dataset.py @@ -0,0 +1,72 @@ +import os +import logging + +from uuid import uuid4 + +from typing import List + +from .utils import wget + +from ..dataset import Dataset +from ..file import File +from ..annotation import Annotation + +import wfdb + + +class PhysioNetCHBMITDataset(Dataset): + + def __init__( + self, + path: str = "./data/physionet.org/files/chbmit/", + version: str = "1.0.0", + exclude_file: List[str] = [ + "chb03/chb03_35.edf", # Corrupted data + "chb12/chb12_27.edf", # Bad channel names + "chb12/chb12_28.edf", # Bad channel names + "chb12/chb12_29.edf", # Bad channel names + ], + exclude_channels_set: List[str] = [ + "ECG", + "LOC-ROC", + "LUE-RAE", + "VNS", + ], + ) -> None: + super().__init__( + path=path, + name="CHB-MIT Scalp EEG Database", + version=version, + exclude_file=exclude_file, + exclude_channels_set=exclude_channels_set, + ) + + def download(self, user: str = None, password: str = None) -> None: + wget(self.path, user, password, "chbmit", self.version) + + def _get_annotation(self, 
file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=60, + end=120, + label="noseizure" + ) + ] + if os.path.isfile(file.path + ".seizures"): + annotations = wfdb.rdann(file.path, "seizures") + annotations = list(annotations.sample / annotations.fs) + annotations = [annotations[i:i+2] for i in range(0, len(annotations), 2)] + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=annotation[0], + end=annotation[1], + label="seizure" + ) + for annotation in annotations + ] + return annotations diff --git a/pyeeglab/dataset/physionet/eegmmidb_dataset.py b/pyeeglab/dataset/physionet/eegmmidb_dataset.py new file mode 100644 index 0000000..f6da608 --- /dev/null +++ b/pyeeglab/dataset/physionet/eegmmidb_dataset.py @@ -0,0 +1,64 @@ +import os +import logging + +from uuid import uuid4 + +from typing import List + +from .utils import wget + +from ..dataset import Dataset +from ..file import File +from ..annotation import Annotation + + +class PhysioNetEEGMMIDBDataset(Dataset): + + def __init__( + self, + path: str = "./data/physionet.org/files/eegmmidb/", + version: str = "1.0.0", + exclude_file: List[str] = [ + "S021/S021R08.edf", # Corrupted data + "S104/S104R04.edf", # Corrupted data + ], + exclude_sampling_frequency: List[int] = [ 128 ], + ) -> None: + super().__init__( + path=path, + name="EEG Motor Movement/Imagery Dataset", + version=version, + exclude_file=exclude_file, + exclude_sampling_frequency=exclude_sampling_frequency, + ) + + def download(self, user: str = None, password: str = None) -> None: + wget(self.path, user, password, "eegmmidb", self.version) + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + with file as reader: + try: + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=annotation[0], + end=annotation[0]+annotation[1], + label=annotation[2], + ) + for annotation in reader.annotations + ] + except KeyError: + # Alternative annotation format + annotations = [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=annotation["onset"], + end=annotation["onset"]+annotation["duration"], + label=annotation["description"], + ) + for annotation in reader.annotations + ] + return annotations diff --git a/pyeeglab/dataset/physionet/utils.py b/pyeeglab/dataset/physionet/utils.py new file mode 100644 index 0000000..c379903 --- /dev/null +++ b/pyeeglab/dataset/physionet/utils.py @@ -0,0 +1,30 @@ +import logging +import subprocess + + +def wget(path: str, user: str, password: str, slug: str, version: str) -> None: + logging.info("Download started, it will take some time") + url = "https://physionet.org/files/" + slug + "/" + version + "/" + process = subprocess.Popen( + [ + "wget", + "-r", + "-N", + "-c", + "-np", + "-nH", + "--cut-dirs=3", + url, + "-P", + path + ], + stdout=subprocess.PIPE, + universal_newlines=True, + ) + while True: + output = process.stdout.readline() + logging.info(output.strip()) + if process.poll() is not None: + for output in process.stdout.readlines(): + logging.info(output.strip()) + break diff --git a/pyeeglab/dataset/tuh_eeg/__init__.py b/pyeeglab/dataset/tuh_eeg/__init__.py new file mode 100644 index 0000000..d3db617 --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/__init__.py @@ -0,0 +1,3 @@ +from .abnormal_dataset import * +from .artifact_dataset import * +from .seizure_dataset import * 
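
The three TUH EEG dataset classes that follow all specialize the new Dataset base class with the rsync-based download and corpus-specific annotation parsing. As a minimal end-to-end sketch of the resulting API (the path, the 'USER'/'PASSWORD' placeholders, and the pipeline stages are illustrative, taken from the defaults and examples elsewhere in this patch; further stages such as ToDataframe, DynamicWindow, or ToNumpy can follow, as in the examples):

    from pyeeglab import *

    # Construction also triggers indexing of whatever is on disk
    dataset = TUHEEGAbnormalDataset('./data/tuh_eeg_abnormal/')
    # Download requires TUH EEG credentials, then re-index the new files
    dataset.download(user='USER', password='PASSWORD')
    dataset.index()

    # Attach a preprocessing pipeline and load the (cached) result
    dataset.set_pipeline(Pipeline([
        CommonChannelSet(),
        LowestFrequency(),
    ]))
    data = dataset.load()
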
diff --git a/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py b/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py new file mode 100644 index 0000000..62e885d --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/abnormal_dataset.py @@ -0,0 +1,59 @@ +import os +import logging + +from uuid import uuid4 + +from typing import List + +from .utils import rsync + +from ..dataset import Dataset +from ..file import File +from ..metadata import Metadata +from ..annotation import Annotation + + +class TUHEEGAbnormalDataset(Dataset): + + def __init__( + self, + path: str = "./data/tuh_eeg_abnormal/", + version: str = "2.0.0", + exclude_channels_set: List[str] = [ + "BURSTS", + "ECG EKG-REF", + "EMG-REF", + "IBI", + "PHOTIC-REF", + "PULSE RATE", + "STI 014", + "SUPPR" + ], + ) -> None: + super().__init__( + path=path, + name="Temple University Hospital EEG Abnormal Dataset", + version="v"+version, + exclude_channels_set=exclude_channels_set, + ) + + def download(self, user: str = None, password: str = None) -> None: + rsync(self.path, user, password, "tuh_eeg_abnormal", self.version) + + def _get_metadata(self, file: File) -> Metadata: + meta = file.path.split(os.path.sep) + metadata = super()._get_metadata(file) + metadata.channels_reference = meta[-5] + return metadata + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + return [ + Annotation( + uuid=str(uuid4()), + file_uuid=file.uuid, + begin=60, + end=120, + label=file.path.split(os.path.sep)[-6], + ) + ] diff --git a/pyeeglab/dataset/tuh_eeg/artifact_dataset.py b/pyeeglab/dataset/tuh_eeg/artifact_dataset.py new file mode 100644 index 0000000..23ee37e --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/artifact_dataset.py @@ -0,0 +1,51 @@ +import os +import logging + +from typing import List + +from .utils import rsync, parse_tse + +from ..dataset import Dataset +from ..file import File +from ..metadata import Metadata +from ..annotation import Annotation + + +class TUHEEGArtifactDataset(Dataset): + + def __init__( + self, + path: str = "./data/tuh_eeg_artifact/", + version: str = "1.0.0", + exclude_channels_set: List[str] = [ + "BURSTS", + "EMG-REF", + "IBI", + "PHOTIC-REF", + "SUPPR" + ], + exclude_channels_reference: List[str] = [ + "02_tcp_le", + "03_tcp_ar_a" + ], + ) -> None: + super().__init__( + path=path, + name="Temple University Hospital EEG Artifact Dataset", + version="v"+version, + exclude_channels_set=exclude_channels_set, + exclude_channels_reference=exclude_channels_reference, + ) + + def download(self, user: str = None, password: str = None) -> None: + rsync(self.path, user, password, "tuh_eeg_artifact", self.version) + + def _get_metadata(self, file: File) -> Metadata: + meta = file.path.split(os.path.sep) + metadata = super()._get_metadata(file) + metadata.channels_reference = meta[-5] + return metadata + + def _get_annotation(self, file: File) -> List[Annotation]: + logging.debug("Add file %s annotations to index", file.uuid) + return parse_tse(file) diff --git a/pyeeglab/dataset/tuh_eeg/seizure_dataset.py b/pyeeglab/dataset/tuh_eeg/seizure_dataset.py new file mode 100644 index 0000000..593ef3d --- /dev/null +++ b/pyeeglab/dataset/tuh_eeg/seizure_dataset.py @@ -0,0 +1,57 @@ +import os +import re +import logging + +from uuid import uuid4 + +from typing import List + +from .utils import rsync, parse_lbl + +from ..dataset import Dataset +from ..file import File +from ..metadata import Metadata +from ..annotation import Annotation + + +class TUHEEGSeizureDataset(Dataset): + + def 
diff --git a/pyeeglab/dataset/tuh_eeg/seizure_dataset.py b/pyeeglab/dataset/tuh_eeg/seizure_dataset.py
new file mode 100644
index 0000000..593ef3d
--- /dev/null
+++ b/pyeeglab/dataset/tuh_eeg/seizure_dataset.py
@@ -0,0 +1,57 @@
+import os
+import re
+import logging
+
+from uuid import uuid4
+
+from typing import List
+
+from .utils import rsync, parse_lbl
+
+from ..dataset import Dataset
+from ..file import File
+from ..metadata import Metadata
+from ..annotation import Annotation
+
+
+class TUHEEGSeizureDataset(Dataset):
+
+    def __init__(
+        self,
+        path: str = "./data/tuh_eeg_seizure/",
+        version: str = "1.5.2",
+        exclude_channels_set: List[str] = [
+            "BURSTS",
+            "ECG EKG-REF",
+            "EMG-REF",
+            "IBI",
+            "PHOTIC-REF",
+            "PULSE RATE",
+            "RESP ABDOMEN-RE",
+            "SUPPR"
+        ],
+        exclude_channels_reference: List[str] = [
+            "02_tcp_le",
+            "03_tcp_ar_a"
+        ],
+    ) -> None:
+        super().__init__(
+            path=path,
+            name="Temple University Hospital EEG Seizure Dataset",
+            version="v"+version,
+            exclude_channels_set=exclude_channels_set,
+            exclude_channels_reference=exclude_channels_reference,
+        )
+
+    def download(self, user: str = None, password: str = None) -> None:
+        rsync(self.path, user, password, "tuh_eeg_seizure", self.version)
+
+    def _get_metadata(self, file: File) -> Metadata:
+        meta = file.path.split(os.path.sep)
+        metadata = super()._get_metadata(file)
+        metadata.channels_reference = meta[-5]
+        return metadata
+
+    def _get_annotation(self, file: File) -> List[Annotation]:
+        logging.debug("Add file %s annotations to index", file.uuid)
+        return parse_lbl(file)
diff --git a/pyeeglab/dataset/tuh_eeg/utils.py b/pyeeglab/dataset/tuh_eeg/utils.py
new file mode 100644
index 0000000..d4a1d4c
--- /dev/null
+++ b/pyeeglab/dataset/tuh_eeg/utils.py
@@ -0,0 +1,96 @@
+import re
+import logging
+import subprocess
+
+from uuid import uuid4
+
+from typing import List
+
+from ..file import File
+from ..annotation import Annotation
+
+
+def rsync(path: str, user: str, password: str, slug: str, version: str) -> None:
+    if user is not None and password is not None:
+        logging.info("Download started, it will take some time")
+        url = user + "@" + "www.isip.piconepress.com:~/data/"
+        url = url + slug + "/" + version + "/"
+        process = subprocess.Popen(
+            [
+                "sshpass",
+                "-p",
+                password,
+                "rsync",
+                "-auxvL",
+                url,
+                path
+            ],
+            stdout=subprocess.PIPE,
+            universal_newlines=True,
+        )
+        while True:
+            output = process.stdout.readline()
+            logging.info(output.strip())
+            if process.poll() is not None:
+                for output in process.stdout.readlines():
+                    logging.info(output.strip())
+                break
+    else:
+        logging.warning("Download disabled, add 'user' and 'password' as optional parameters during data set creation")
+        logging.warning("Request your login credentials at: https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php")
+
+
+def parse_lbl(file: File) -> List[Annotation]:
+    path = file.path[:-4] + ".lbl"
+    with open(path, "r") as reader:
+        annotations = reader.read()
+    symbols = re.compile(
+        r"^symbols\[0\] = ({.*})$",
+        re.MULTILINE
+    )
+    symbols = re.findall(symbols, annotations)
+    symbols = eval(symbols[0])
+    pattern = re.compile(
+        r"^label = {([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^}]*)}$",
+        re.MULTILINE
+    )
+    annotations = re.findall(pattern, annotations)
+    annotations = {
+        (annotation[2], annotation[3], symbols[index])
+        for annotation in annotations
+        for index, value in enumerate(eval(annotation[5]))
+        if value > 0
+    }
+    annotations = [
+        Annotation(
+            uuid=str(uuid4()),
+            file_uuid=file.uuid,
+            begin=float(annotation[0]),
+            end=float(annotation[1]),
+            label=annotation[2]
+        )
+        for annotation in annotations
+    ]
+    return annotations
+
+
+def parse_tse(file: File) -> List[Annotation]:
+    path = file.path[:-4] + ".tse"
+    with open(path, "r") as reader:
+        annotations = reader.read()
+    pattern = re.compile(
+        r"^(\d+\.\d+) (\d+\.\d+) (\w+) (\d\.\d+)$",
+        re.MULTILINE
+    )
+    annotations = re.findall(pattern, annotations)
+    annotations = [
+        Annotation(
+            uuid=str(uuid4()),
+            file_uuid=file.uuid,
+            begin=float(annotation[0]),
+            end=float(annotation[1]),
+            label=annotation[2]
+        )
+        for annotation in annotations
+    ]
+    return annotations
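To make the annotation formats concrete: a `.tse` file carries one record-level label per line as `<begin> <end> <label> <probability>`, which is what the `parse_tse` pattern above captures, while `.lbl` files carry the channel-wise variant that `parse_lbl` flattens into (begin, end, label) triples. A self-contained sketch with fabricated `.tse` content:

    import re

    # Fabricated .tse content; real files open with a version header,
    # which the anchored pattern simply skips.
    sample = "version = tse_v1.0.0\n\n0.0000 36.8868 bckg 1.0000\n36.8868 183.3055 seiz 1.0000\n"

    pattern = re.compile(r"^(\d+\.\d+) (\d+\.\d+) (\w+) (\d\.\d+)$", re.MULTILINE)
    for begin, end, label, _ in re.findall(pattern, sample):
        print(float(begin), float(end), label)
    # 0.0 36.8868 bckg
    # 36.8868 183.3055 seiz
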
/dev/null @@ -1,2 +0,0 @@ -from .tuh_eeg_artifact_loader import TUHEEGArtifactLoader -from .tuh_eeg_artifact_dataset import TUHEEGArtifactDataset diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py deleted file mode 100644 index 455b945..0000000 --- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_dataset.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Dict - -from .tuh_eeg_artifact_loader import TUHEEGArtifactLoader -from ..dataset import Dataset - - -class TUHEEGArtifactDataset(Dataset): - - def __init__(self, path: str = './data/tuh_eeg_artifact/v2.0.0/edf/') -> None: - super().__init__(TUHEEGArtifactLoader(path)) - - def _get_dataset_env(self) -> Dict: - env = super()._get_dataset_env() - env['class_id'] = 'bckg' - return env diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py deleted file mode 100644 index b9c74d3..0000000 --- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_index.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -import re -import json - -from typing import List - -from uuid import uuid4, uuid5, NAMESPACE_X500 -from os.path import join, sep - -from ...io import Index, Raw -from ...database import File, Metadata, Event - - -class TUHEEGArtifactIndex(Index): - - def __init__(self, path: str) -> None: - logging.debug('Create TUH EEG Corpus Index') - super().__init__('sqlite:///' + join(path, 'index.db'), path) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:] - ) - - def _get_record_metadata(self, file: File) -> Metadata: - meta = file.path.split(sep) - metadata = super()._get_record_metadata(file) - metadata.channels_reference = meta[0] - return metadata - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - path = raw.path[:-4] + '.tse' - with open(path, 'r') as file: - annotations = file.read() - pattern = re.compile( - r'^(\d+.\d+) (\d+.\d+) (\w+) (\d.\d+)$', re.MULTILINE) - events = re.findall(pattern, annotations) - events = [ - Event( - id=str(uuid4()), - file_id=raw.id, - begin=float(event[0]), - end=float(event[1]), - duration=(float(event[1])-float(event[0])), - label=event[2] - ) - for event in events - ] - return events diff --git a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py b/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py deleted file mode 100644 index 491742c..0000000 --- a/pyeeglab/dataset/tuh_eeg_artifact/tuh_eeg_artifact_loader.py +++ /dev/null @@ -1,17 +0,0 @@ -import logging - -from typing import List - -from ...io import DataLoader -from .tuh_eeg_artifact_index import TUHEEGArtifactIndex - - -class TUHEEGArtifactLoader(DataLoader): - - def __init__(self, path: str, exclude_channels_reference: List[str] = ['02_tcp_le', '03_tcp_ar_a']) -> None: - exclude_files = [ - '01_tcp_ar/101/00010158/s001_2013_01_14/00010158_s001_t001.edf', # Corrupted data - ] - super().__init__(path, exclude_channels_reference=exclude_channels_reference, exclude_files=exclude_files) - logging.debug('Create TUH EEG Corpus Loader') - self.index = TUHEEGArtifactIndex(self.path) diff --git a/pyeeglab/dataset/tuh_eeg_seizure/__init__.py 
b/pyeeglab/dataset/tuh_eeg_seizure/__init__.py deleted file mode 100644 index abf1a6c..0000000 --- a/pyeeglab/dataset/tuh_eeg_seizure/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .tuh_eeg_seizure_loader import TUHEEGSeizureLoader -from .tuh_eeg_seizure_dataset import TUHEEGSeizureDataset diff --git a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py deleted file mode 100644 index 8f391cd..0000000 --- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_dataset.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Dict - -from .tuh_eeg_seizure_loader import TUHEEGSeizureLoader -from ..dataset import Dataset - - -class TUHEEGSeizureDataset(Dataset): - - def __init__(self, path: str = './data/tuh_eeg_seizure/v1.5.2/edf/') -> None: - super().__init__(TUHEEGSeizureLoader(path)) - - def _get_dataset_env(self) -> Dict: - env = super()._get_dataset_env() - env['class_id'] = 'fnsz' - return env diff --git a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py deleted file mode 100644 index 1d243d2..0000000 --- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_index.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -import re -import json -import numpy as np - -from typing import List - -from uuid import uuid4, uuid5, NAMESPACE_X500 -from os.path import join, sep - -from ...io import Index, Raw -from ...database import File, Metadata, Event - - -class TUHEEGSeizureIndex(Index): - - def __init__(self, path: str, exclude_events: List = ['bckg', 'spsz']) -> None: - logging.debug('Create TUH EEG Corpus Index') - super().__init__('sqlite:///' + join(path, 'index.db'), path, exclude_events=exclude_events) - self.index() - - def _get_file(self, path: str) -> File: - length = len(self.path) - meta = path[length:].split(sep) - return File( - id=str(uuid5(NAMESPACE_X500, path[length:])), - extension=meta[-1].split('.')[-1], - path=path[length:] - ) - - def _get_record_metadata(self, file: File) -> Metadata: - meta = file.path.split(sep) - metadata = super()._get_record_metadata(file) - metadata.channels_reference = meta[1] - return metadata - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - path = raw.path[:-4] + '.lbl' - with open(path, 'r') as file: - annotations = file.read() - pattern = re.compile( - r'^symbols\[0\] = ({.*})$', - re.MULTILINE - ) - mapper = re.findall(pattern, annotations) - mapper = eval(mapper[0]) - events = re.findall(pattern, annotations) - pattern = re.compile( - r'^label = {([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^,]*), ([^}]*)}$', - re.MULTILINE - ) - events = re.findall(pattern, annotations) - labels = {} - for event in events: - intervall = (event[2], event[3]) - if intervall not in labels: - labels[intervall] = {} - channel = event[4] - if channel not in labels[intervall]: - labels[intervall][channel] = [] - label = np.array(json.loads(event[5])) - for i in list(np.nonzero(label)[0]): - labels[intervall][channel].append(mapper[i]) - labels = { - (intervall[0], intervall[1], label) - for intervall, channels in labels.items() - for channel, labels in channels.items() - for label in labels - } - events = [ - Event( - id=str(uuid4()), - file_id=raw.id, - begin=float(label[0]), - end=float(label[1]), - duration=(float(label[1])-float(label[0])), - label=label[2] - ) - for label in labels - ] - return events diff --git 
a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py b/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py deleted file mode 100644 index 5f2627a..0000000 --- a/pyeeglab/dataset/tuh_eeg_seizure/tuh_eeg_seizure_loader.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging - -from typing import List - -from ...io import DataLoader -from .tuh_eeg_seizure_index import TUHEEGSeizureIndex - - -class TUHEEGSeizureLoader(DataLoader): - - def __init__(self, path: str, exclude_channels_reference: List[str] = ['02_tcp_le', '02_tcp_ar_a']) -> None: - super().__init__(path, exclude_channels_reference=exclude_channels_reference) - logging.debug('Create TUH EEG Corpus Loader') - self.index = TUHEEGSeizureIndex(self.path) diff --git a/pyeeglab/io/__init__.py b/pyeeglab/io/__init__.py deleted file mode 100644 index 12695ed..0000000 --- a/pyeeglab/io/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .index import Index -from .loader import DataLoader -from .raw import Raw diff --git a/pyeeglab/io/index.py b/pyeeglab/io/index.py deleted file mode 100644 index 9b13638..0000000 --- a/pyeeglab/io/index.py +++ /dev/null @@ -1,139 +0,0 @@ -import logging -import json - -from abc import ABC, abstractmethod - -from os import walk -from os.path import join, splitext -from uuid import uuid4 -from multiprocessing import Pool, cpu_count -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from mne import set_log_file - -from ..database import BASE_TABLE, File, Metadata, Event -from .raw import Raw - -from typing import List - - -class Index(ABC): - """ An abstract class representing the Index mechanism that is used to - discover the dataset structure. - - It is the first component that must be implemented in order to provide - full support to a specific dataset. If the dataset structure is regular, - the only method that must be implemented id "_get_file(path)". - """ - - def __init__(self, db: str, path: str, include_extensions: List[str] = ['edf'], exclude_events: List[str] = None) -> None: - """ - Parameters - ---------- - db : str - The database connection handle expressed as string. This is usually - configured by the Loader class as a sqlite handle, but in theory it - could be used with any type of database connection. 
- - """ - logging.debug('Create index at %s', db) - logging.debug('Load index at %s', db) - engine = create_engine(db) - BASE_TABLE.metadata.create_all(engine) - self.db = sessionmaker(bind=engine)() - self.path = path - self.include_extensions = include_extensions - self.exclude_events = exclude_events - logging.debug('Redirect MNE logging interface to file') - set_log_file(join(path, 'mne.log'), overwrite=False) - - def __getstate__(self): - # Workaround for unpickable sqlalchemy.orm.session - # during multiprocess dataset loading - state = self.__dict__.copy() - del state['db'] - return state - - def _get_paths(self, exclude_extensions: List[str] = ['.db', '.gz', '.log']) -> List[str]: - logging.debug('Get files from path') - paths = [ - join(dirpath, filename) - for dirpath, _, filenames in walk(self.path) - for filename in filenames - if splitext(filename)[1] not in exclude_extensions - ] - return paths - - @abstractmethod - def _get_file(self, path: str) -> File: - pass - - def _get_files(self, paths: List[str]) -> List[File]: - pool = Pool(cpu_count()) - files = pool.map(self._get_file, paths) - pool.close() - pool.join() - return files - - def _get_record_metadata(self, file: File) -> Metadata: - logging.debug('Add file %s raw metadata to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - return Metadata( - file_id=raw.id, - duration=raw.open().n_times/raw.open().info['sfreq'], - channels_count=raw.open().info['nchan'], - channels_set=json.dumps(raw.open().info['ch_names']), - sampling_frequency=raw.open().info['sfreq'], - max_value=raw.open().get_data().max(), - min_value=raw.open().get_data().min(), - ) - - def _parallel_record_metadata(self, files: List[File]) -> List[Metadata]: - pool = Pool(cpu_count()) - metadata = pool.map(self._get_record_metadata, files) - pool.close() - pool.join() - return metadata - - def _get_record_events(self, file: File) -> List[Event]: - logging.debug('Add file %s raw events to index', file.id) - raw = Raw(file.id, join(self.path, file.path)) - events = raw.get_events() - for event in events: - event['id'] = str(uuid4()) - event['file_id'] = raw.id - events = [Event(**event) for event in events] - return events - - def _parallel_record_events(self, files: List[File]) -> List[Event]: - pool = Pool(cpu_count()) - events = pool.map(self._get_record_events, files) - pool.close() - pool.join() - events = [e for event in events for e in event] - if self.exclude_events: - events = [ - e for e in events - if e.label not in self.exclude_events - ] - return events - - def index(self) -> None: - logging.debug('Index files') - files = self._get_paths() - files = self._get_files(files) - files = [ - file - for file in files - if not self.db.query(File).filter(File.id == file.id).all() - ] - for file in files: - logging.debug('Add file %s raw to index', file.id) - if self.include_extensions: - raws = [ - file for file in files if file.extension in self.include_extensions] - metadata = self._parallel_record_metadata(raws) - events = self._parallel_record_events(raws) - self.db.add_all(files + metadata + events) - self.db.commit() - logging.debug('Index files completed') diff --git a/pyeeglab/io/loader.py b/pyeeglab/io/loader.py deleted file mode 100644 index 73e732a..0000000 --- a/pyeeglab/io/loader.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging -import json - -from abc import ABC -from os.path import isfile, join, sep -from hashlib import md5 -from multiprocessing import Pool, cpu_count - -from ..database import File, Metadata, Event -from 
.raw import Raw -from .index import Index - -from typing import List, Dict - - -class DataLoader(ABC): - - index: Index - - def __init__(self, path: str, exclude_channels_reference: List[str] = None, exclude_frequency: List[int] = None, - exclude_files: List[str] = None, minimum_event_duration: float = -1) -> None: - logging.debug('Create data loader') - if path[-1] != sep: - path = path + sep - self.path = path - self.exclude_channels_reference = exclude_channels_reference - self.exclude_frequency = exclude_frequency - self.exclude_files = exclude_files - self.minimum_event_duration = minimum_event_duration - - def __getstate__(self): - # Workaround for unpickable sqlalchemy.orm.session - # during multiprocess dataset loading - state = self.__dict__.copy() - del state['index'] - return state - - def _get_data_by_event(self, f: File, e: Event) -> Raw: - path_edf = join(self.path, f.path) - path_fif = path_edf + '_' + e.id + '_raw.fif.gz' - if not isfile(path_fif): - edf = Raw(f.id, path_edf, e.label) - edf.crop(e.begin, e.end-e.begin) - edf.open().save(path_fif) - fif = Raw(f.id, path_fif, e.label) - return fif - - def set_minimum_event_duration(self, minimum_event_duration: float) -> 'DataLoader': - self.minimum_event_duration = minimum_event_duration - - def get_dataset(self) -> List[Raw]: - files = self.index.db.query(File, Metadata, Event) - files = files.filter(File.id == Metadata.file_id) - files = files.filter(File.id == Event.file_id) - if self.index.include_extensions: - files = files.filter(File.extension.in_(self.index.include_extensions)) - if self.exclude_channels_reference: - files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)) - if self.exclude_frequency: - files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - if self.exclude_files: - files = files.filter(~File.path.in_(self.exclude_files)) - if self.minimum_event_duration > 0: - files = files.filter(Event.duration >= self.minimum_event_duration) - files = files.all() - files = [(file[0], file[2]) for file in files] - pool = Pool(cpu_count()) - fifs = pool.starmap(self._get_data_by_event, files) - pool.close() - pool.join() - return fifs - - def get_dataset_text(self) -> Dict: - txts = self.index.db.query(File, Event) - txts = txts.filter(File.id == Event.file_id) - txts = txts.filter(File.extension == 'txt') - txts = txts.all() - txts = {f.id: (join(self.index.path, f.path), e.label) for f, e in txts} - return txts - - def get_channels_set(self) -> List[str]: - files = self.index.db.query(File, Metadata) - files = files.filter(File.id == Metadata.file_id) - if self.exclude_channels_reference: - files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)) - if self.exclude_frequency: - files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - if self.exclude_files: - files = files.filter(~File.path.in_(self.exclude_files)) - if self.minimum_event_duration > 0: - files = files.filter(Event.duration >= self.minimum_event_duration) - files = files.group_by(Metadata.channels_set) - files = files.all() - files = [file[1] for file in files] - files = [set(json.loads(file.channels_set)) for file in files] - channels_set = files[0] - for file in files[1:]: - channels_set = channels_set.intersection(file) - return sorted(channels_set) - - def get_lowest_frequency(self) -> float: - frequency = self.index.db.query(Metadata) - if self.exclude_frequency: - frequency = 
frequency.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - frequency = frequency.all() - frequency = min([f.sampling_frequency for f in frequency], default=0) - return frequency - - def get_max_value(self) -> float: - files = self.index.db.query(File, Metadata) - files = files.filter(File.id == Metadata.file_id) - if self.exclude_channels_reference: - files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)) - if self.exclude_frequency: - files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - if self.exclude_files: - files = files.filter(~File.path.in_(self.exclude_files)) - if self.minimum_event_duration > 0: - files = files.filter(Event.duration >= self.minimum_event_duration) - files = files.group_by(Metadata.channels_set) - files = files.all() - max_value = max([f.max_value for _, f in files], default=0) - return max_value - - def get_min_value(self) -> float: - files = self.index.db.query(File, Metadata) - files = files.filter(File.id == Metadata.file_id) - if self.exclude_channels_reference: - files = files.filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)) - if self.exclude_frequency: - files = files.filter(~Metadata.sampling_frequency.in_(self.exclude_frequency)) - if self.exclude_files: - files = files.filter(~File.path.in_(self.exclude_files)) - if self.minimum_event_duration > 0: - files = files.filter(Event.duration >= self.minimum_event_duration) - files = files.group_by(Metadata.channels_set) - files = files.all() - min_value = min([f.min_value for _, f in files], default=0) - return min_value - - def __eq__(self, other): - return hash(self) == hash(other) - - def __hash__(self): - value = [self.path] + [self.minimum_event_duration] - if self.exclude_channels_reference: - value += self.exclude_channels_reference - if self.exclude_frequency: - value += self.exclude_frequency - value = json.dumps(value).encode() - value = md5(value).hexdigest() - value = int(value, 16) - return value diff --git a/pyeeglab/io/raw.py b/pyeeglab/io/raw.py deleted file mode 100644 index 75a41ff..0000000 --- a/pyeeglab/io/raw.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -from typing import List -from importlib.util import find_spec - -from mne.io import Raw as Reader -from mne.io import read_raw_edf, read_raw_fif - -class Raw(): - - reader: Reader = None - - def __init__(self, fid: str, path: str, label: str = None) -> None: - self.id = fid - self.path = path - self.label = label - n_jobs = 1 - if find_spec('cupy') is not None: - n_jobs = 'cuda' - self.n_jobs = n_jobs - - def close(self) -> 'Raw': - if self.reader is not None: - logging.debug('Close Raw %s reader', self.id) - self.reader.close() - self.reader = None - return self - - def crop(self, offset: int, length: int) -> 'Raw': - logging.debug('Crop %s data to %s seconds from %s', self.id, length, offset) - tmax = self.open().n_times / self.open().info['sfreq'] - 0.1 - if offset + length < tmax: - tmax = offset + length - self.reader = self.open().crop(offset, tmax) - return self - - def open(self) -> Reader: - if self.reader is None: - if self.path.endswith('.edf'): - logging.debug('Open RawEDF %s reader', self.id) - try: - self.reader = read_raw_edf(self.path) - except RuntimeError: - logging.debug('Preload RawEDF %s reader', self.id) - self.reader = read_raw_edf(self.path, preload=True) - if self.path.endswith('.fif.gz'): - logging.debug('Open RawFIF %s reader', self.id) - self.reader = read_raw_fif(self.path) - return self.reader - - def 
get_events(self) -> List: - events = self.open().annotations - events = list(zip(events.onset, events.duration, events.description)) - events = [(event[0], event[0] + event[1], event[1], event[2]) for event in events] - keys = ['begin', 'end', 'duration', 'label'] - events = [dict(zip(keys, event)) for event in events] - return events - - def set_channels(self, channels: List[str]) -> 'Raw': - channels = set(self.open().ch_names) - set(channels) - channels = list(channels) - if len(channels) > 0: - logging.debug('Drop %s channels %s', self.id, '|'.join(channels)) - self.reader = self.open().drop_channels(channels) - return self - - def set_frequency(self, frequency: float) -> 'Raw': - sfreq = self.open().info['sfreq'] - if sfreq > frequency: - logging.debug('Downsample %s from %s Hz to %s Hz', self.id, sfreq, frequency) - self.reader = self.open().resample(frequency, n_jobs=self.n_jobs) - return self - - def set_filter(self, low_freq: float = None, high_freq: float = None) -> 'Raw': - if low_freq is not None or high_freq is not None: - logging.debug('Filter %s with low_req: %s Hz and high_freq: %s Hz', self.id, low_freq, high_freq) - self.reader = self.open().filter(low_freq, high_freq, n_jobs=self.n_jobs) - return self - - def notch_filter(self, freq: float) -> 'Raw': - if freq is not None: - logging.debug('Notch filter %s with freq: %s Hz ', self.id, freq) - self.reader = self.open().notch_filter(freq, n_jobs=self.n_jobs) - return self diff --git a/pyeeglab/pipeline/pipeline.py b/pyeeglab/pipeline/pipeline.py index e270611..ec082ff 100644 --- a/pyeeglab/pipeline/pipeline.py +++ b/pyeeglab/pipeline/pipeline.py @@ -1,18 +1,17 @@ -import logging import json - -import numpy as np -import pandas as pd +import logging from hashlib import md5 from os.path import join from multiprocessing import Pool, cpu_count -from ..io import Raw -from .preprocessor import Preprocessor - from typing import Dict, List +import numpy as np +import pandas as pd + +from .preprocessor import Preprocessor + class Pipeline(): @@ -32,11 +31,13 @@ def _check_nans(self, data): nans = data.isnull().values.any() return nans - def _trigger_pipeline(self, data: Raw, kwargs): - file_id = data.id - data.open().load_data() - for preprocessor in self.pipeline: - data = preprocessor.run(data, **kwargs) + def _trigger_pipeline(self, annotation, kwargs): + data = None + + with annotation as reader: + data = reader.load_data() + for preprocessor in self.pipeline: + data = preprocessor.run(data, **kwargs) nans = False if isinstance(data, list): @@ -44,11 +45,11 @@ def _trigger_pipeline(self, data: Raw, kwargs): else: nans = self._check_nans(data) if nans: - raise ValueError('Nans found in file with id {}'.format(file_id)) + raise ValueError('Nans found in file with id {}'.format(annotation.file_uuid)) return data - def run(self, data: List[Raw]) -> Dict: + def run(self, data) -> Dict: logging.debug('Environment variables: {}'.format( str(self.environment) )) diff --git a/pyeeglab/preprocess/features/brain_connectivity.py b/pyeeglab/preprocess/features/brain_connectivity.py index 96b96ac..ff3af4c 100644 --- a/pyeeglab/preprocess/features/brain_connectivity.py +++ b/pyeeglab/preprocess/features/brain_connectivity.py @@ -1,16 +1,15 @@ -import logging import json +import logging + +from itertools import product +from typing import List import numpy as np import pandas as pd - from yasa import bandpower -from itertools import product from ...pipeline import Preprocessor -from typing import List - class SpearmanCorrelation(Preprocessor): 
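All preprocessors in this patch share one contract: `run(data, **kwargs)` receives either an `mne.io.Raw` (no longer the removed `pyeeglab.io.Raw` wrapper) or the output of the previous stage, plus pipeline environment values such as `channels_set` and `lowest_frequency`. A minimal custom stage under that contract, illustrative only and not part of the patch; it assumes `Preprocessor` is importable from `pyeeglab.pipeline`, as the built-ins' relative imports suggest:

    import json
    import logging

    from mne.io import Raw

    from pyeeglab.pipeline import Preprocessor


    class ClipAmplitude(Preprocessor):

        def __init__(self, limit: float = 1e-4) -> None:
            super().__init__()
            logging.debug("Create clip_amplitude preprocessor")
            self.limit = limit

        def to_json(self) -> str:
            # Serialized like the built-in stages, so the parameters
            # can take part in the pipeline cache key.
            return json.dumps({self.__class__.__name__: {"limit": self.limit}})

        def run(self, data: Raw, **kwargs) -> Raw:
            # apply_function runs channel-wise over the preloaded signal.
            return data.load_data().apply_function(
                lambda channel: channel.clip(-self.limit, self.limit)
            )
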
diff --git a/pyeeglab/preprocess/features/stat_features.py b/pyeeglab/preprocess/features/stat_features.py index 234ca03..5121fec 100644 --- a/pyeeglab/preprocess/features/stat_features.py +++ b/pyeeglab/preprocess/features/stat_features.py @@ -1,14 +1,11 @@ -import logging +from typing import List import numpy as np import pandas as pd - from scipy.integrate import simps from ...pipeline import Preprocessor -from typing import List - class Mean(Preprocessor): def run(self, data: List[pd.DataFrame], **kwargs) -> List[pd.DataFrame]: diff --git a/pyeeglab/preprocess/signal/channel_selector.py b/pyeeglab/preprocess/signal/channel_selector.py index 124ffaf..9f1fb4a 100644 --- a/pyeeglab/preprocess/signal/channel_selector.py +++ b/pyeeglab/preprocess/signal/channel_selector.py @@ -1,18 +1,18 @@ -import logging import json - -from ...io import Raw -from ...pipeline import Preprocessor +import logging from typing import List +from mne.io import Raw + +from ...pipeline import Preprocessor class CommonChannelSet(Preprocessor): - def __init__(self, blacklist: List[str] = []) -> None: + def __init__(self, blacklist: List[str] = None) -> None: super().__init__() logging.debug('Create common channels_set preprocessor') - self.blacklist = blacklist + self.blacklist = blacklist if blacklist else [] def to_json(self) -> str: out = { @@ -24,7 +24,9 @@ def to_json(self) -> str: return out def run(self, data: Raw, **kwargs) -> Raw: - channels = set(kwargs['channels_set']) - set(self.blacklist) - channels = list(channels) - data.set_channels(channels) + channels = set(data.ch_names) + channels = channels.difference(set(kwargs['channels_set'])) + data = data.drop_channels(channels) + data = data.drop_channels(self.blacklist) + data = data.reorder_channels(kwargs['channels_set']) return data diff --git a/pyeeglab/preprocess/signal/filter_selector.py b/pyeeglab/preprocess/signal/filter_selector.py index 33c7b15..2007338 100644 --- a/pyeeglab/preprocess/signal/filter_selector.py +++ b/pyeeglab/preprocess/signal/filter_selector.py @@ -1,7 +1,8 @@ -import logging import json +import logging + +from mne.io import Raw -from ...io import Raw from ...pipeline import Preprocessor @@ -24,8 +25,7 @@ def to_json(self) -> str: return out def run(self, data: Raw, **kwargs) -> Raw: - data.set_filter(self.low_freq, self.high_freq) - return data + return data.filter(self.low_freq, self.high_freq) class NotchFrequency(Preprocessor): @@ -45,5 +45,4 @@ def to_json(self) -> str: return out def run(self, data: Raw, **kwargs) -> Raw: - data.notch_filter(self.freq) - return data + return data.notch_filter(self.freq) diff --git a/pyeeglab/preprocess/signal/frequency_selector.py b/pyeeglab/preprocess/signal/frequency_selector.py index d3aca97..fc606c1 100644 --- a/pyeeglab/preprocess/signal/frequency_selector.py +++ b/pyeeglab/preprocess/signal/frequency_selector.py @@ -1,6 +1,7 @@ import logging -from ...io import Raw +from mne.io import Raw + from ...pipeline import Preprocessor @@ -11,5 +12,7 @@ def __init__(self) -> None: logging.debug('Create lowest_frequency preprocessor') def run(self, data: Raw, **kwargs) -> Raw: - data.set_frequency(kwargs['lowest_frequency']) + lowest_frequency = kwargs['lowest_frequency'] + if data.info['sfreq'] > lowest_frequency: + data = data.resample(lowest_frequency) return data diff --git a/pyeeglab/preprocess/signal/normalization.py b/pyeeglab/preprocess/signal/normalization.py index 7670ca2..11b07d4 100644 --- a/pyeeglab/preprocess/signal/normalization.py +++ 
b/pyeeglab/preprocess/signal/normalization.py @@ -1,12 +1,12 @@ import logging +from typing import List + import numpy as np import pandas as pd from ...pipeline import Preprocessor -from typing import List - class MinMaxNormalization(Preprocessor): def run(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame: diff --git a/pyeeglab/preprocess/transform/data_converter.py b/pyeeglab/preprocess/transform/data_converter.py index 3e2124f..06f747a 100644 --- a/pyeeglab/preprocess/transform/data_converter.py +++ b/pyeeglab/preprocess/transform/data_converter.py @@ -1,13 +1,14 @@ -import logging import json +import logging + +from typing import List import numpy as np import pandas as pd -from ...io import Raw -from ...pipeline import Preprocessor +from mne.io import Raw -from typing import List +from ...pipeline import Preprocessor class ToDataframe(Preprocessor): @@ -17,8 +18,7 @@ def __init__(self) -> None: logging.debug('Create DataFrame converter preprocessor') def run(self, data: Raw, **kwargs) -> pd.DataFrame: - dataframe = data.open().to_data_frame().drop('time', axis=1) - data.close() + dataframe = data.to_data_frame().drop('time', axis=1) return dataframe diff --git a/pyeeglab/preprocess/transform/frame_generator.py b/pyeeglab/preprocess/transform/frame_generator.py index 9c10a20..9fd9ee1 100644 --- a/pyeeglab/preprocess/transform/frame_generator.py +++ b/pyeeglab/preprocess/transform/frame_generator.py @@ -1,13 +1,12 @@ -import logging import json - -import pandas as pd +import logging from math import floor +from typing import List -from ...pipeline import Preprocessor +import pandas as pd -from typing import List +from ...pipeline import Preprocessor class StaticWindow(Preprocessor): diff --git a/requirements.txt b/requirements.txt index 8aa5a94..b2628f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ scipy pandas wfdb sqlalchemy -mne==0.20.0 +dataclasses +mne>=0.21.0 yasa>=0.1.6 diff --git a/setup.py b/setup.py index 910c293..d6153c7 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='PyEEGLab', - version='0.9.4', + version='0.10.0', author='Alessio Zanga', author_email='alessio.zanga@outlook.it', license='GNU GENERAL PUBLIC LICENSE - Version 3, 29 June 2007', @@ -37,7 +37,8 @@ 'pandas', 'wfdb', 'sqlalchemy', - 'mne==0.20.0', + 'dataclasses', + 'mne>=0.21.0', 'yasa>=0.1.6', ], ) diff --git a/tests/test_chbmit.py b/tests/test_chbmit.py index 334296c..69ec621 100644 --- a/tests/test_chbmit.py +++ b/tests/test_chbmit.py @@ -6,20 +6,19 @@ from pyeeglab import * class TestCHBMIT(unittest.TestCase): - PATH = './tests/samples/physionet.org/files/chbmit/1.0.0' + PATH = './tests/samples/physionet.org/files/chbmit/' def test_index(self): - CHBMITLoader(self.PATH) + PhysioNetCHBMITDataset(self.PATH) def test_loader(self): - loader = CHBMITLoader(self.PATH) - loader.get_dataset() - loader.get_dataset_text() - loader.get_channels_set() - loader.get_lowest_frequency() + loader = PhysioNetCHBMITDataset(self.PATH) + loader.maximal_channels_subset + loader.lowest_frequency + loader.signal_min_max_range def test_dataset(self): - dataset = CHBMITDataset(self.PATH) + dataset = PhysioNetCHBMITDataset(self.PATH) preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), diff --git a/tests/test_eegmmidb.py b/tests/test_eegmmidb.py index 0f22e13..4d9248e 100644 --- a/tests/test_eegmmidb.py +++ b/tests/test_eegmmidb.py @@ -6,20 +6,19 @@ from pyeeglab import * class TestEEGMMIDB(unittest.TestCase): - PATH = 
'./tests/samples/physionet.org/files/eegmmidb/1.0.0' + PATH = './tests/samples/physionet.org/files/eegmmidb/' def test_index(self): - EEGMMIDBLoader(self.PATH) + PhysioNetEEGMMIDBDataset(self.PATH) def test_loader(self): - loader = EEGMMIDBLoader(self.PATH) - loader.get_dataset() - loader.get_dataset_text() - loader.get_channels_set() - loader.get_lowest_frequency() + loader = PhysioNetEEGMMIDBDataset(self.PATH) + loader.maximal_channels_subset + loader.lowest_frequency + loader.signal_min_max_range def test_dataset(self): - dataset = EEGMMIDBDataset(self.PATH) + dataset = PhysioNetEEGMMIDBDataset(self.PATH) preprocessing = Pipeline([ CommonChannelSet(), LowestFrequency(), diff --git a/tests/test_tuh_eeg_abnormal.py b/tests/test_tuh_eeg_abnormal.py index 6e29f7a..a9e8bcb 100644 --- a/tests/test_tuh_eeg_abnormal.py +++ b/tests/test_tuh_eeg_abnormal.py @@ -6,17 +6,16 @@ from pyeeglab import * class TestTUHEEGAbnormal(unittest.TestCase): - PATH = './tests/samples/tuh_eeg_abnormal/v2.0.0/edf' + PATH = './tests/samples/tuh_eeg_abnormal/' def test_index(self): - TUHEEGAbnormalLoader(self.PATH) + TUHEEGAbnormalDataset(self.PATH) def test_loader(self): - loader = TUHEEGAbnormalLoader(self.PATH) - loader.get_dataset() - loader.get_dataset_text() - loader.get_channels_set() - loader.get_lowest_frequency() + loader = TUHEEGAbnormalDataset(self.PATH) + loader.maximal_channels_subset + loader.lowest_frequency + loader.signal_min_max_range def test_dataset(self): dataset = TUHEEGAbnormalDataset(self.PATH) diff --git a/tests/test_tuh_eeg_artifact.py b/tests/test_tuh_eeg_artifact.py index 6cdb5f2..595ca91 100644 --- a/tests/test_tuh_eeg_artifact.py +++ b/tests/test_tuh_eeg_artifact.py @@ -6,22 +6,21 @@ from pyeeglab import * class TestTUHEEGArtifact(unittest.TestCase): - PATH = './tests/samples/tuh_eeg_artifact/v1.0.0/edf' + PATH = './tests/samples/tuh_eeg_artifact/' def test_index(self): - TUHEEGArtifactLoader(self.PATH) + TUHEEGArtifactDataset(self.PATH) def test_loader(self): - loader = TUHEEGArtifactLoader(self.PATH) - loader.get_dataset() - loader.get_dataset_text() - loader.get_channels_set() - loader.get_lowest_frequency() + loader = TUHEEGArtifactDataset(self.PATH) + loader.maximal_channels_subset + loader.lowest_frequency + loader.signal_min_max_range def test_dataset(self): dataset = TUHEEGArtifactDataset(self.PATH) preprocessing = Pipeline([ - CommonChannelSet(['EEG T1-REF', 'EEG T2-REF']), + CommonChannelSet(), LowestFrequency(), BandPassFrequency(0.1, 47), ToDataframe(),