From b899a8913868afcc469f9e9c72e8ff30fef6b5ad Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 21 Aug 2024 17:35:32 -0400 Subject: [PATCH 01/62] Updating to MEDS v0.3.2 by correcting the subject ID field name. --- MIMIC-IV_Example/README.md | 16 +-- MIMIC-IV_Example/configs/event_configs.yaml | 14 +- MIMIC-IV_Example/joint_script.sh | 8 +- MIMIC-IV_Example/joint_script_slurm.sh | 6 +- README.md | 72 +++++----- eICU_Example/README.md | 32 ++--- eICU_Example/configs/event_configs.yaml | 6 +- eICU_Example/joint_script.sh | 10 +- eICU_Example/joint_script_slurm.sh | 6 +- pyproject.toml | 6 +- src/MEDS_transforms/__init__.py | 11 +- .../aggregate_code_metadata.py | 75 +++++----- src/MEDS_transforms/configs/extract.yaml | 4 +- src/MEDS_transforms/configs/preprocess.yaml | 4 +- .../stage_configs/count_code_occurrences.yaml | 2 +- .../stage_configs/filter_measurements.yaml | 2 +- .../stage_configs/filter_patients.yaml | 3 - .../stage_configs/filter_subjects.yaml | 3 + .../stage_configs/fit_normalization.yaml | 2 +- .../stage_configs/reshard_to_split.yaml | 2 +- ...nts.yaml => split_and_shard_subjects.yaml} | 4 +- src/MEDS_transforms/extract/README.md | 22 +-- .../extract/convert_to_sharded_events.py | 80 +++++------ .../extract/extract_code_metadata.py | 8 +- .../extract/finalize_MEDS_data.py | 14 +- .../extract/finalize_MEDS_metadata.py | 29 ++-- .../extract/merge_to_MEDS_cohort.py | 36 ++--- src/MEDS_transforms/extract/shard_events.py | 23 +-- ...atients.py => split_and_shard_subjects.py} | 134 +++++++++--------- src/MEDS_transforms/filters/README.md | 2 +- .../filters/filter_measurements.py | 36 ++--- ...{filter_patients.py => filter_subjects.py} | 96 ++++++------- src/MEDS_transforms/mapreduce/mapper.py | 22 +-- src/MEDS_transforms/mapreduce/utils.py | 8 +- src/MEDS_transforms/reshard_to_split.py | 20 +-- .../add_time_derived_measurements.py | 46 +++--- .../transforms/normalization.py | 16 +-- .../transforms/occlude_outliers.py | 10 +- .../transforms/reorder_measurements.py | 22 +-- .../transforms/tensorization.py | 4 +- .../transforms/tokenization.py | 54 +++---- src/MEDS_transforms/utils.py | 22 +-- tests/test_add_time_derived_measurements.py | 19 +-- tests/test_aggregate_code_metadata.py | 10 +- tests/test_extract.py | 76 +++++----- tests/test_extract_no_metadata.py | 76 +++++----- tests/test_filter_measurements.py | 38 ++--- ...er_patients.py => test_filter_subjects.py} | 33 ++--- tests/test_fit_vocabulary_indices.py | 2 +- tests/test_multi_stage_preprocess_pipeline.py | 108 +++++++------- tests/test_normalization.py | 10 +- tests/test_occlude_outliers.py | 10 +- tests/test_reorder_measurements.py | 8 +- tests/test_reshard_to_split.py | 32 +++-- tests/test_tokenization.py | 26 ++-- tests/transform_tester_base.py | 27 ++-- tests/utils.py | 2 +- 57 files changed, 735 insertions(+), 734 deletions(-) delete mode 100644 src/MEDS_transforms/configs/stage_configs/filter_patients.yaml create mode 100644 src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml rename src/MEDS_transforms/configs/stage_configs/{split_and_shard_patients.yaml => split_and_shard_subjects.yaml} (73%) rename src/MEDS_transforms/extract/{split_and_shard_patients.py => split_and_shard_subjects.py} (71%) rename src/MEDS_transforms/filters/{filter_patients.py => filter_subjects.py} (67%) rename tests/{test_filter_patients.py => test_filter_subjects.py} (75%) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index c0038603..6bf348d0 100644 --- a/MIMIC-IV_Example/README.md +++ 
b/MIMIC-IV_Example/README.md @@ -76,10 +76,10 @@ This is a step in a few parts: - the `hosp/diagnoses_icd` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`. - the `hosp/drgcodes` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`. -2. Convert the patient's static data to a more parseable form. This entails: - - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and +2. Convert the subject's static data to a more parseable form. This entails: + - Get the subject's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and `anchor_offset` fields. - - Merge the patient's `dod` with the `deathtime` from the `admissions` table. + - Merge the subject's `dod` with the `deathtime` from the `admissions` table. After these steps, modified files or symlinks to the original files will be written in a new directory which will be used as the input to the actual MEDS extraction ETL. We'll use `$MIMICIV_PREMEDS_DIR` to denote this @@ -109,14 +109,14 @@ This is a step in 4 parts: This step uses the `./scripts/extraction/shard_events.py` script. See `joint_script*.sh` for the expected format of the command. -2. Extract and form the patient splits and sub-shards. The `./scripts/extraction/split_and_shard_patients.py` +2. Extract and form the subject splits and sub-shards. The `./scripts/extraction/split_and_shard_subjects.py` script is used for this step. See `joint_script*.sh` for the expected format of the command. -3. Extract patient sub-shards and convert to MEDS events. The +3. Extract subject sub-shards and convert to MEDS events. The `./scripts/extraction/convert_to_sharded_events.py` script is used for this step. See `joint_script*.sh` for the expected format of the command. -4. Merge the MEDS events into a single file per patient sub-shard. The +4. Merge the MEDS events into a single file per subject sub-shard. The `./scripts/extraction/merge_to_MEDS_cohort.py` script is used for this step. See `joint_script*.sh` for the expected format of the command. @@ -139,7 +139,7 @@ timeline which is otherwise stored at the _datetime_ resolution? Other questions: -1. How to handle merging the deathtimes between the hosp table and the patients table? +1. How to handle merging the deathtimes between the hosp table and the subjects table? 2. How to handle the dob nonsense MIMIC has? ## Notes @@ -153,4 +153,4 @@ may need to run `unset SLURM_CPU_BIND` in your terminal first to avoid errors. If you wanted, some other processing could also be done here, such as: -1. Converting the patient's dynamically recorded race into a static, most commonly recorded race field. +1. Converting the subject's dynamically recorded race into a static, most commonly recorded race field. diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml index 0cd0381d..619d5a2e 100644 --- a/MIMIC-IV_Example/configs/event_configs.yaml +++ b/MIMIC-IV_Example/configs/event_configs.yaml @@ -1,4 +1,4 @@ -patient_id_col: subject_id +subject_id_col: subject_id hosp/admissions: ed_registration: code: ED_REGISTRATION @@ -27,7 +27,7 @@ hosp/admissions: time: col(dischtime) time_format: "%Y-%m-%d %H:%M:%S" hadm_id: hadm_id - # We omit the death event here as it is joined to the data in the patients table in the pre-MEDS step. + # We omit the death event here as it is joined to the data in the subjects table in the pre-MEDS step. 
hosp/diagnoses_icd: diagnosis: @@ -108,7 +108,7 @@ hosp/omr: time: col(chartdate) time_format: "%Y-%m-%d" -hosp/patients: +hosp/subjects: gender: code: - GENDER @@ -295,18 +295,18 @@ icu/inputevents: description: ["omop_concept_name", "label"] # List of strings are columns to be collated itemid: "itemid (omop_source_code)" parent_codes: "{omop_vocabulary_id}/{omop_concept_code}" - patient_weight: + subject_weight: code: - - PATIENT_WEIGHT_AT_INFUSION + - SUBJECT_WEIGHT_AT_INFUSION - KG time: col(starttime) time_format: "%Y-%m-%d %H:%M:%S" - numeric_value: patientweight + numeric_value: subjectweight icu/outputevents: output: code: - - PATIENT_FLUID_OUTPUT + - SUBJECT_FLUID_OUTPUT - col(itemid) - col(valueuom) time: col(charttime) diff --git a/MIMIC-IV_Example/joint_script.sh b/MIMIC-IV_Example/joint_script.sh index a98fee7b..dd1459c4 100755 --- a/MIMIC-IV_Example/joint_script.sh +++ b/MIMIC-IV_Example/joint_script.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo echo "Arguments:" echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." @@ -88,11 +88,11 @@ MEDS_extract-shard_events \ etl_metadata.dataset_version="2.2" \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" -echo "Splitting patients in serial" -MEDS_extract-split_and_shard_patients \ +echo "Splitting subjects in serial" +MEDS_extract-split_and_shard_subjects \ input_dir="$MIMICIV_PREMEDS_DIR" \ cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="split_and_shard_patients" \ + stage="split_and_shard_subjects" \ etl_metadata.dataset_name="MIMIC-IV" \ etl_metadata.dataset_version="2.2" \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" diff --git a/MIMIC-IV_Example/joint_script_slurm.sh b/MIMIC-IV_Example/joint_script_slurm.sh index 3ff96846..e13fb7e9 100755 --- a/MIMIC-IV_Example/joint_script_slurm.sh +++ b/MIMIC-IV_Example/joint_script_slurm.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." echo echo "Arguments:" @@ -72,8 +72,8 @@ MEDS_extract-shard_events \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml \ stage=shard_events -echo "Splitting patients on one worker" -MEDS_extract-split_and_shard_patients \ +echo "Splitting subjects on one worker" +MEDS_extract-split_and_shard_subjects \ --multirun \ worker="range(0,1)" \ hydra/launcher=submitit_slurm \ diff --git a/README.md b/README.md index e797d3fc..a9e057d2 100644 --- a/README.md +++ b/README.md @@ -45,12 +45,12 @@ directories. The fundamental design philosophy of this repository can be summarized as follows: 1. 
_(The MEDS Assumption)_: All structured electronic health record (EHR) data can be represented as a - series of events, each of which is associated with a patient, a time, and a set of codes and + series of events, each of which is associated with a subject, a time, and a set of codes and numeric values. This representation is the Medical Event Data Standard (MEDS) format, and in this - repository we use it in the "flat" format, where data is organized as rows of `patient_id`, + repository we use it in the "flat" format, where data is organized as rows of `subject_id`, `time`, `code`, `numeric_value` columns. 2. _Easy Efficiency through Sharding_: MEDS datasets in this repository are sharded into smaller, more - manageable pieces (organized as separate files) at the patient level (and, during the raw-data extraction + manageable pieces (organized as separate files) at the subject level (and, during the raw-data extraction process, the event level). This enables users to scale up their processing capabilities ad nauseum by leveraging more workers to process these shards in parallel. This parallelization is seamlessly enabled with the configuration schema used in the scripts in this repository. This style of parallelization @@ -62,7 +62,7 @@ The fundamental design philosophy of this repository can be summarized as follow the others, and each stage is designed to do a small amount of work and be easily testable in isolation. This design philosophy ensures that the pipeline is robust to changes, easy to debug, and easy to extend. In particular, to add new operations specific to a given model or dataset, the user need only write - simple functions that take in a flat MEDS dataframe (representing a single patient level shard) and + simple functions that take in a flat MEDS dataframe (representing a single subject level shard) and return a new flat MEDS dataframe, and then wrap that function in a script by following the examples provided in this repository. These individual functions can use the same configuration schema as other stages in the pipeline or include a separate, stage-specific configuration, and can use whatever @@ -198,9 +198,9 @@ To use this repository as a template, the user should follow these steps: Assumptions: 1. Your data is organized in a set of parquet files on disk such that each row of each file corresponds to - one or more measurements per patient and has all necessary information in that row to extract said - measurement, organized in a simple, columnar format. Each of these parquet files stores the patient's ID in - a column called `patient_id` in the same type. + one or more measurements per subject and has all necessary information in that row to extract said + measurement, organized in a simple, columnar format. Each of these parquet files stores the subject's ID in + a column called `subject_id` in the same type. 2. You have a pre-defined or can externally define the requisite MEDS base `code_metadata` file that describes the codes in your data as necessary. This file is not used in the provided pre-processing pipeline in this package, but is necessary for other uses of the MEDS data. @@ -221,16 +221,16 @@ The provided ETL consists of the following steps, which can be performed as need degree of parallelism is desired per step. 1. It re-shards the input data into a set of smaller, event-level shards to facilitate parallel processing. 
- This can be skipped if your input data is already suitably sharded at either a per-patient or per-event + This can be skipped if your input data is already suitably sharded at either a per-subject or per-event level. 2. It extracts the subject IDs from the sharded data and computes the set of ML splits and (per split) the - patient shards. These are stored in a JSON file in the output cohort directory. + subject shards. These are stored in a JSON file in the output cohort directory. 3. It converts the input, event level shards into the MEDS flat format and joins and shards these data into - patient-level shards for MEDS use and stores them in a nested format in the output cohort directory, + subject-level shards for MEDS use and stores them in a nested format in the output cohort directory, again in the flat format. This step can be broken down into two sub-steps: - - First, each input shard is converted to the MEDS flat format and split into sub patient-level shards. - - Second, the appropriate sub patient-level shards are joined and re-organized into the final - patient-level shards. This method ensures that we minimize the amount of read contention on the input + - First, each input shard is converted to the MEDS flat format and split into sub subject-level shards. + - Second, the appropriate sub subject-level shards are joined and re-organized into the final + subject-level shards. This method ensures that we minimize the amount of read contention on the input shards during the join process and can maximize parallel throughput, as (theoretically, with sufficient workers) all input shards can be sub-sharded in parallel and then all output shards can be joined in parallel. @@ -239,7 +239,7 @@ The ETL scripts all use [Hydra](https://hydra.cc/) for configuration management, `configs/extraction.yaml` file for configuration. The user can override any of these settings in the normal way for Hydra configurations. -If desired, appropriate scripts can be written and run at a per-patient shard level to convert between the +If desired, appropriate scripts can be written and run at a per-subject shard level to convert between the flat format and any of the other valid nested MEDS format, but for now we leave that up to the user. #### Input Event Extraction @@ -250,11 +250,11 @@ dataframes should be parsed into different event formats. The YAML file stores a following structure: ```yaml -patient_id: $GLOBAL_PATIENT_ID_OVERWRITE # Optional, if you want to overwrite the patient ID column name for - # all inputs. If not specified, defaults to "patient_id". +subject_id: $GLOBAL_SUBJECT_ID_OVERWRITE # Optional, if you want to overwrite the subject ID column name for + # all inputs. If not specified, defaults to "subject_id". $INPUT_FILE_STEM: - patient_id: $INPUT_FILE_PATIENT_ID # Optional, if you want to overwrite the patient ID column name for - # this input. IF not specified, defaults to the global patient ID. + subject_id: $INPUT_FILE_SUBJECT_ID # Optional, if you want to overwrite the subject ID column name for + # this input. IF not specified, defaults to the global subject ID. $EVENT_NAME: code: - $CODE_PART_1 @@ -287,18 +287,18 @@ script is a functional test that is also run with `pytest` to verify correctness 1. `scripts/extraction/shard_events.py` shards the input data into smaller, event-level shards by splitting raw files into chunks of a configurable number of rows. Files are split sequentially, with no regard for - data content or patient boundaries. 
The resulting files are stored in the `subsharded_events` + data content or subject boundaries. The resulting files are stored in the `subsharded_events` subdirectory of the output directory. -2. `scripts/extraction/split_and_shard_patients.py` splits the patient population into ML splits and shards - these splits into patient-level shards. The result of this process is only a simple `JSON` file - containing the patient IDs belonging to individual splits and shards. This file is stored in the +2. `scripts/extraction/split_and_shard_subjects.py` splits the subject population into ML splits and shards + these splits into subject-level shards. The result of this process is only a simple `JSON` file + containing the subject IDs belonging to individual splits and shards. This file is stored in the `output_directory/splits.json` file. 3. `scripts/extraction/convert_to_sharded_events.py` converts the input, event-level shards into the MEDS - event format and splits them into patient-level sub-shards. So, the resulting files are sharded into - patient-level, then event-level groups and are not merged into full patient-level shards or appropriately + event format and splits them into subject-level sub-shards. So, the resulting files are sharded into + subject-level, then event-level groups and are not merged into full subject-level shards or appropriately sorted for downstream use. -4. `scripts/extraction/merge_to_MEDS_cohort.py` merges the patient-level, event-level shards into full - patient-level shards and sorts them appropriately for downstream use. The resulting files are stored in +4. `scripts/extraction/merge_to_MEDS_cohort.py` merges the subject-level, event-level shards into full + subject-level shards and sorts them appropriately for downstream use. The resulting files are stored in the `output_directory/final_cohort` directory. ## MEDS Pre-processing Transformations @@ -308,9 +308,9 @@ contains a variety of pre-processing transformations and scripts that can be app in various ways to prepare them for downstream modeling. Broadly speaking, the pre-processing pipeline can be broken down into the following steps: -1. Filtering the dataset by criteria that do not require cross-patient analyses, e.g., +1. Filtering the dataset by criteria that do not require cross-subject analyses, e.g., - - Filtering patients by the number of events or unique times they have. + - Filtering subjects by the number of events or unique times they have. - Removing numeric values that fall outside of pre-specified, per-code ranges (e.g., for outlier removal). @@ -318,9 +318,9 @@ broken down into the following steps: - Adding time-derived measurements, e.g., - The time since the last event of a certain type. - - The patient's age as of each unique timepoint. + - The subject's age as of each unique timepoint. - The time-of-day of each event. - - Adding a "dummy" event to the dataset for each patient that occurs at the end of the observation + - Adding a "dummy" event to the dataset for each subject that occurs at the end of the observation period. 3. Iteratively (a) grouping the dataset by `code` and associated code modifier columns and collecting @@ -344,11 +344,11 @@ broken down into the following steps: 5. Normalizing the data to convert codes to indices and numeric values to the desired form (either categorical indices or normalized numeric values). -6. 
Tokenizing the data in time to create a pre-tensorized dataset with clear delineations between patients, - patient sequence elements, and measurements per sequence element (note that various of these delineations +6. Tokenizing the data in time to create a pre-tensorized dataset with clear delineations between subjects, + subject sequence elements, and measurements per sequence element (note that various of these delineations may be fully flat/trivial for unnested formats). -7. Tensorizing the data to permit efficient retrieval from disk of patient data for deep-learning modeling +7. Tensorizing the data to permit efficient retrieval from disk of subject data for deep-learning modeling via PyTorch. Much like how the entire MEDS ETL pipeline is controlled by a single configuration file, the pre-processing @@ -363,7 +363,7 @@ be a bottleneck. Tokenization is the process of producing dataframes that are arranged into the sequences that will eventually be processed by deep-learning methods. Generally, these dataframes will be arranged such that each row -corresponds to a unique patient, with nested list-type columns corresponding either to _events_ (unique +corresponds to a unique subject, with nested list-type columns corresponding either to _events_ (unique timepoints), themselves with nested, list-type measurements, or to _measurements_ (unique measurements within a timepoint) directly. Importantly, _tokenized files are generally not ideally suited to direct ingestion by PyTorch datasets_. Instead, they should undergo a _tensorization_ process to be converted into a format that @@ -379,7 +379,7 @@ does not inhibit rapid training, and (3) be organized such that CPU and GPU reso during training. Similarly, by _scalability_, we mean that the three desiderata above should hold true even as the dataset size grows much larger---while total training time can increase, time to begin training, to process the data per-item, and CPU/GPU resources required should remain constant, or only grow negligibly, -such as the cost of maintaining a larger index of patient IDs to file offsets or paths (though disk space will +such as the cost of maintaining a larger index of subject IDs to file offsets or paths (though disk space will of course increase). Depending on one's performance needs and dataset sizes, there are 3 modes of deep learning training that can @@ -398,7 +398,7 @@ on an as-needed basis. This mode is extremely scalable, because the entire datas loaded or stored in memory in its entirety. When done properly, retrieving data from disk can be done in a manner that is independent of the total dataset size as well, thereby rendering the load time similarly unconstrained by total dataset size. This mode is also extremely flexible, because different cohorts can be -loaded from the same base dataset simply by changing which patients and what offsets within patient data are +loaded from the same base dataset simply by changing which subjects and what offsets within subject data are read on any given cohort, all without changing the base files or underlying code. However, this mode does require ragged dataset collation which can be more resource intensive than pre-batched iteration, so it is slower than the "Fixed-batch retrieval" approach. This mode is what is currently supported by this repository. diff --git a/eICU_Example/README.md b/eICU_Example/README.md index c0494c94..37eb9d03 100644 --- a/eICU_Example/README.md +++ b/eICU_Example/README.md @@ -19,7 +19,7 @@ up from this one). 
- [ ] Testing the MEDS extraction ETL runs on eICU-CRD (this should be expected to work, but needs live testing). - [ ] Sub-sharding - - [ ] Patient split gathering + - [ ] Subject split gathering - [ ] Event extraction - [ ] Merging - [ ] Validating the output MEDS cohort @@ -58,10 +58,10 @@ This is a step in a few parts: 1. Join a few tables by `hadm_id` to get the right timestamps in the right rows for processing. In particular, we need to join: - TODO -2. Convert the patient's static data to a more parseable form. This entails: - - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and +2. Convert the subject's static data to a more parseable form. This entails: + - Get the subject's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and `anchor_offset` fields. - - Merge the patient's `dod` with the `deathtime` from the `admissions` table. + - Merge the subject's `dod` with the `deathtime` from the `admissions` table. After these steps, modified files or symlinks to the original files will be written in a new directory which will be used as the input to the actual MEDS extraction ETL. We'll use `$EICU_PREMEDS_DIR` to denote this @@ -78,12 +78,12 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less ## Step 3: Run the MEDS extraction ETL -Note that eICU has a lot more observations per patient than does MIMIC-IV, so to keep to a reasonable memory +Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable memory burden (e.g., \< 150GB per worker), you will want a smaller shard size, as well as to turn off the final unique check (which should not be necessary given the structure of eICU and is expensive) in the merge stage. You can do this by setting the following parameters at the end of the mandatory args when running this script: -- `stage_configs.split_and_shard_patients.n_patients_per_shard=10000` +- `stage_configs.split_and_shard_subjects.n_subjects_per_shard=10000` - `stage_configs.merge_to_MEDS_cohort.unique_by=null` ### Running locally, serially @@ -106,10 +106,10 @@ This is a step in 4 parts: In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. -1. Extract and form the patient splits and sub-shards. +1. Extract and form the subject splits and sub-shards. ```bash -./scripts/extraction/split_and_shard_patients.py \ +./scripts/extraction/split_and_shard_subjects.py \ input_dir=$EICU_PREMEDS_DIR \ cohort_dir=$EICU_MEDS_DIR \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml @@ -117,7 +117,7 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -1. Extract patient sub-shards and convert to MEDS events. +1. Extract subject sub-shards and convert to MEDS events. ```bash ./scripts/extraction/convert_to_sharded_events.py \ @@ -132,7 +132,7 @@ multiple times (though this will, of course, consume more resources). If your fi commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization and performance is not necessary; however, for larger datasets, it can be. -1. Merge the MEDS events into a single file per patient sub-shard. +1. Merge the MEDS events into a single file per subject sub-shard. 
```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ @@ -172,10 +172,10 @@ to finish before moving on to the next stage. Let `$N_PARALLEL_WORKERS` be the n In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. -1. Extract and form the patient splits and sub-shards. +1. Extract and form the subject splits and sub-shards. ```bash -./scripts/extraction/split_and_shard_patients.py \ +./scripts/extraction/split_and_shard_subjects.py \ input_dir=$EICU_PREMEDS_DIR \ cohort_dir=$EICU_MEDS_DIR \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml @@ -183,7 +183,7 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -1. Extract patient sub-shards and convert to MEDS events. +1. Extract subject sub-shards and convert to MEDS events. ```bash ./scripts/extraction/convert_to_sharded_events.py \ @@ -198,7 +198,7 @@ multiple times (though this will, of course, consume more resources). If your fi commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization and performance is not necessary; however, for larger datasets, it can be. -1. Merge the MEDS events into a single file per patient sub-shard. +1. Merge the MEDS events into a single file per subject sub-shard. ```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ @@ -221,7 +221,7 @@ timeline which is otherwise stored at the _datetime_ resolution? Other questions: -1. How to handle merging the deathtimes between the hosp table and the patients table? +1. How to handle merging the deathtimes between the hosp table and the subjects table? 2. How to handle the dob nonsense MIMIC has? ## Future Work @@ -230,4 +230,4 @@ Other questions: If you wanted, some other processing could also be done here, such as: -1. Converting the patient's dynamically recorded race into a static, most commonly recorded race field. +1. Converting the subject's dynamically recorded race into a static, most commonly recorded race field. diff --git a/eICU_Example/configs/event_configs.yaml b/eICU_Example/configs/event_configs.yaml index e6f2e7ab..fb7901cf 100644 --- a/eICU_Example/configs/event_configs.yaml +++ b/eICU_Example/configs/event_configs.yaml @@ -1,7 +1,7 @@ -# Note that there is no "patient_id" for eICU -- patients are only differentiable during the course of a +# Note that there is no "subject_id" for eICU -- patients are only differentiable during the course of a # single health system stay. Accordingly, we set the "patient" id here as the "patientHealthSystemStayID" -patient_id_col: patienthealthsystemstayid +subject_id_col: patienthealthsystemstayid patient: dob: @@ -131,7 +131,7 @@ infusionDrug: volume_of_fluid: "volumeoffluid" patient_weight: code: - - "INFUSION_PATIENT_WEIGHT" + - "INFUSION_SUBJECT_WEIGHT" time: col(infusionEnteredTimestamp) numeric_value: "patientweight" diff --git a/eICU_Example/joint_script.sh b/eICU_Example/joint_script.sh index fd76ee28..0b3ad6c5 100755 --- a/eICU_Example/joint_script.sh +++ b/eICU_Example/joint_script.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes eICU data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." 
echo echo "Arguments:" echo " EICU_RAW_DIR Directory containing raw eICU data files." @@ -39,12 +39,12 @@ N_PARALLEL_WORKERS="$4" shift 4 -echo "Note that eICU has a lot more observations per patient than does MIMIC-IV, so to keep to a reasonable " +echo "Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable " echo "memory burden (e.g., < 150GB per worker), you will want a smaller shard size, as well as to turn off " echo "the final unique check (which should not be necessary given the structure of eICU and is expensive) " echo "in the merge stage. You can do this by setting the following parameters at the end of the mandatory " echo "args when running this script:" -echo " * stage_configs.split_and_shard_patients.n_patients_per_shard=10000" +echo " * stage_configs.split_and_shard_subjects.n_subjects_per_shard=10000" echo " * stage_configs.merge_to_MEDS_cohort.unique_by=null" echo "Running pre-MEDS conversion." @@ -59,8 +59,8 @@ echo "Running shard_events.py with $N_PARALLEL_WORKERS workers in parallel" cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" -echo "Splitting patients in serial" -./scripts/extraction/split_and_shard_patients.py \ +echo "Splitting subjects in serial" +./scripts/extraction/split_and_shard_subjects.py \ input_dir="$EICU_PREMEDS_DIR" \ cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" diff --git a/eICU_Example/joint_script_slurm.sh b/eICU_Example/joint_script_slurm.sh index 78802860..bdd7abe8 100755 --- a/eICU_Example/joint_script_slurm.sh +++ b/eICU_Example/joint_script_slurm.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes eICU data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." echo echo "Arguments:" @@ -71,8 +71,8 @@ echo "Trying submitit launching with $N_PARALLEL_WORKERS jobs." 
cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml -echo "Splitting patients on one worker" -./scripts/extraction/split_and_shard_patients.py \ +echo "Splitting subjects on one worker" +./scripts/extraction/split_and_shard_subjects.py \ --multirun \ worker="range(0,1)" \ hydra/launcher=submitit_slurm \ diff --git a/pyproject.toml b/pyproject.toml index d4f64501..8073111f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3", + "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.2", ] [tool.setuptools_scm] @@ -35,7 +35,7 @@ docs = [ [project.scripts] # MEDS_extract -MEDS_extract-split_and_shard_patients = "MEDS_transforms.extract.split_and_shard_patients:main" +MEDS_extract-split_and_shard_subjects = "MEDS_transforms.extract.split_and_shard_subjects:main" MEDS_extract-shard_events = "MEDS_transforms.extract.shard_events:main" MEDS_extract-convert_to_sharded_events = "MEDS_transforms.extract.convert_to_sharded_events:main" MEDS_extract-merge_to_MEDS_cohort = "MEDS_transforms.extract.merge_to_MEDS_cohort:main" @@ -50,7 +50,7 @@ MEDS_transform-fit_vocabulary_indices = "MEDS_transforms.fit_vocabulary_indices: MEDS_transform-reshard_to_split = "MEDS_transforms.reshard_to_split:main" ## Filters MEDS_transform-filter_measurements = "MEDS_transforms.filters.filter_measurements:main" -MEDS_transform-filter_patients = "MEDS_transforms.filters.filter_patients:main" +MEDS_transform-filter_subjects = "MEDS_transforms.filters.filter_subjects:main" ## Transforms MEDS_transform-reorder_measurements = "MEDS_transforms.transforms.reorder_measurements:main" MEDS_transform-add_time_derived_measurements = "MEDS_transforms.transforms.add_time_derived_measurements:main" diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index e0aaaf3a..c40e2d92 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -2,6 +2,7 @@ from importlib.resources import files import polars as pl +from meds import code_field, subject_id_field, time_field __package_name__ = "MEDS_transforms" try: @@ -12,12 +13,12 @@ PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/preprocess.yaml") EXTRACT_CONFIG_YAML = files(__package_name__).joinpath("configs/extract.yaml") -MANDATORY_COLUMNS = ["patient_id", "time", "code", "numeric_value"] +MANDATORY_COLUMNS = [subject_id_field, time_field, code_field, "numeric_value"] MANDATORY_TYPES = { - "patient_id": pl.Int64, - "time": pl.Datetime("us"), - "code": pl.String, + subject_id_field: pl.Int64, + time_field: pl.Datetime("us"), + code_field: pl.String, "numeric_value": pl.Float32, "categorical_value": pl.String, "text_value": pl.String, @@ -29,5 +30,5 @@ "category_value": "categoric_value", "textual_value": "text_value", "timestamp": "time", - "subject_id": "patient_id", + "patient_id": subject_id_field, } diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 13e9b345..d5c2f81a 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -12,6 +12,7 @@ import polars as pl import polars.selectors as cs from loguru import logger +from meds import subject_id_field from omegaconf import DictConfig, ListConfig, OmegaConf from MEDS_transforms import 
PREPROCESS_CONFIG_YAML @@ -26,7 +27,7 @@ class METADATA_FN(StrEnum): This enumeration contains the supported code-metadata collection and aggregation function names that can be applied to codes (or, rather, unique code & modifier units) in a MEDS cohort. Each function name is mapped, in the below `CODE_METADATA_AGGREGATIONS` dictionary, to mapper and reducer functions that (a) - collect the raw data at a per code-modifier level from MEDS patient-level shards and (b) aggregates two or + collect the raw data at a per code-modifier level from MEDS subject-level shards and (b) aggregates two or more per-shard metadata files into a single metadata file, which can be used to merge metadata across all shards into a single file. @@ -45,14 +46,14 @@ class METADATA_FN(StrEnum): or on the command line. Args: - "code/n_patients": Collects the number of unique patients who have (anywhere in their record) the code + "code/n_subjects": Collects the number of unique subjects who have (anywhere in their record) the code & modifiers group. "code/n_occurrences": Collects the total number of occurrences of the code & modifiers group across - all observations for all patients. - "values/n_patients": Collects the number of unique patients who have a non-null, non-nan + all observations for all subjects. + "values/n_subjects": Collects the number of unique subjects who have a non-null, non-nan numeric_value field for the code & modifiers group. "values/n_occurrences": Collects the total number of non-null, non-nan numeric_value occurrences for - the code & modifiers group across all observations for all patients. + the code & modifiers group across all observations for all subjects. "values/n_ints": Collects the number of times the observed, non-null numeric_value for the code & modifiers group is an integral value (i.e., a whole number, not an integral type). "values/sum": Collects the sum of the non-null, non-nan numeric_value values for the code & @@ -67,9 +68,9 @@ class METADATA_FN(StrEnum): the configuration file using the dictionary syntax for the aggregation. """ - CODE_N_PATIENTS = "code/n_patients" + CODE_N_PATIENTS = "code/n_subjects" CODE_N_OCCURRENCES = "code/n_occurrences" - VALUES_N_PATIENTS = "values/n_patients" + VALUES_N_PATIENTS = "values/n_subjects" VALUES_N_OCCURRENCES = "values/n_occurrences" VALUES_N_INTS = "values/n_ints" VALUES_SUM = "values/sum" @@ -157,10 +158,10 @@ def quantile_reducer(cols: cs._selector_proxy_, quantiles: list[float]) -> pl.Ex PRESENT_VALS = VAL.filter(VAL_PRESENT) CODE_METADATA_AGGREGATIONS: dict[METADATA_FN, MapReducePair] = { - METADATA_FN.CODE_N_PATIENTS: MapReducePair(pl.col("patient_id").n_unique(), pl.sum_horizontal), + METADATA_FN.CODE_N_PATIENTS: MapReducePair(pl.col(subject_id_field).n_unique(), pl.sum_horizontal), METADATA_FN.CODE_N_OCCURRENCES: MapReducePair(pl.len(), pl.sum_horizontal), METADATA_FN.VALUES_N_PATIENTS: MapReducePair( - pl.col("patient_id").filter(VAL_PRESENT).n_unique(), pl.sum_horizontal + pl.col(subject_id_field).filter(VAL_PRESENT).n_unique(), pl.sum_horizontal ), METADATA_FN.VALUES_N_OCCURRENCES: MapReducePair(PRESENT_VALS.len(), pl.sum_horizontal), METADATA_FN.VALUES_N_INTS: MapReducePair(VAL.filter(VAL_PRESENT & IS_INT).len(), pl.sum_horizontal), @@ -203,9 +204,9 @@ def validate_args_and_get_code_cols(stage_cfg: DictConfig, code_modifiers: list[ Traceback (most recent call last): ... ValueError: Metadata aggregation function INVALID not found in METADATA_FN enumeration. 
Values are: - code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints, + code/n_subjects, code/n_occurrences, values/n_subjects, values/n_occurrences, values/n_ints, values/sum, values/sum_sqd, values/min, values/max, values/quantiles - >>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", {"name": "values/n_ints"}]}) + >>> valid_cfg = DictConfig({"aggregations": ["code/n_subjects", {"name": "values/n_ints"}]}) >>> validate_args_and_get_code_cols(valid_cfg, 33) Traceback (most recent call last): ... @@ -264,7 +265,7 @@ def mapper_fntr( A function that extracts the specified metadata from a MEDS cohort shard after grouping by the specified code & modifier columns. **Note**: The output of this function will, if ``stage_cfg.do_summarize_over_all_codes`` is True, contain the metadata summarizing all observations - across all codes and patients in the shard, with both ``code`` and all ``code_modifiers`` set + across all codes and subjects in the shard, with both ``code`` and all ``code_modifiers`` set to `None` in the output dataframe, in the same format as the code/modifier specific rows with non-null values. @@ -274,13 +275,13 @@ def mapper_fntr( ... "code": ["A", "B", "A", "B", "C", "A", "C", "B", "D"], ... "modifier1": [1, 2, 1, 2, 1, 2, 1, 2, None], ... "modifier_ignored": [3, 3, 4, 4, 5, 5, 6, 6, 7], - ... "patient_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], + ... "subject_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], ... "numeric_value": [1.1, 2., 1.1, 4., 5., 6., 7.5, float('nan'), None], ... }) >>> df shape: (9, 5) ┌──────┬───────────┬──────────────────┬────────────┬───────────────┐ - │ code ┆ modifier1 ┆ modifier_ignored ┆ patient_id ┆ numeric_value │ + │ code ┆ modifier1 ┆ modifier_ignored ┆ subject_id ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 ┆ f64 │ ╞══════╪═══════════╪══════════════════╪════════════╪═══════════════╡ @@ -295,14 +296,14 @@ def mapper_fntr( │ D ┆ null ┆ 7 ┆ 1 ┆ null │ └──────┴───────────┴──────────────────┴────────────┴───────────────┘ >>> stage_cfg = DictConfig({ - ... "aggregations": ["code/n_patients", "values/n_ints"], + ... "aggregations": ["code/n_subjects", "values/n_ints"], ... "do_summarize_over_all_codes": True ... 
}) >>> mapper = mapper_fntr(stage_cfg, None) >>> mapper(df.lazy()).collect() shape: (5, 3) ┌──────┬─────────────────┬───────────────┐ - │ code ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 │ ╞══════╪═════════════════╪═══════════════╡ @@ -312,12 +313,12 @@ def mapper_fntr( │ C ┆ 2 ┆ 1 │ │ D ┆ 1 ┆ 0 │ └──────┴─────────────────┴───────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> mapper = mapper_fntr(stage_cfg, None) >>> mapper(df.lazy()).collect() shape: (4, 3) ┌──────┬─────────────────┬───────────────┐ - │ code ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 │ ╞══════╪═════════════════╪═══════════════╡ @@ -327,12 +328,12 @@ def mapper_fntr( │ D ┆ 1 ┆ 0 │ └──────┴─────────────────┴───────────────┘ >>> code_modifiers = ["modifier1"] - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> mapper = mapper_fntr(stage_cfg, ListConfig(code_modifiers)) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ - │ code ┆ modifier1 ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ modifier1 ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ u32 ┆ u32 │ ╞══════╪═══════════╪═════════════════╪═══════════════╡ @@ -376,12 +377,12 @@ def mapper_fntr( │ C ┆ 1 ┆ 2 ┆ 12.5 │ │ D ┆ null ┆ 1 ┆ 0.0 │ └──────┴───────────┴────────────────────┴────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) + >>> stage_cfg = DictConfig({"aggregations": ["values/n_subjects", "values/n_occurrences"]}) >>> mapper = mapper_fntr(stage_cfg, code_modifiers) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ - │ code ┆ modifier1 ┆ values/n_patients ┆ values/n_occurrences │ + │ code ┆ modifier1 ┆ values/n_subjects ┆ values/n_occurrences │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ u32 ┆ u32 │ ╞══════╪═══════════╪═══════════════════╪══════════════════════╡ @@ -455,7 +456,7 @@ def mapper_fntr( def by_code_mapper(df: pl.LazyFrame) -> pl.LazyFrame: return df.group_by(code_key_columns).agg(**agg_operations).sort(code_key_columns) - def all_patients_mapper(df: pl.LazyFrame) -> pl.LazyFrame: + def all_subjects_mapper(df: pl.LazyFrame) -> pl.LazyFrame: local_agg_operations = agg_operations.copy() if METADATA_FN.VALUES_QUANTILES in agg_operations: local_agg_operations[METADATA_FN.VALUES_QUANTILES] = agg_operations[ @@ -467,8 +468,8 @@ def all_patients_mapper(df: pl.LazyFrame) -> pl.LazyFrame: def mapper(df: pl.LazyFrame) -> pl.LazyFrame: by_code = by_code_mapper(df) - all_patients = all_patients_mapper(df) - return pl.concat([all_patients, by_code], how="diagonal_relaxed").select( + all_subjects = all_subjects_mapper(df) + return pl.concat([all_subjects, by_code], how="diagonal_relaxed").select( *code_key_columns, *agg_operations.keys() ) @@ -502,9 +503,9 @@ def reducer_fntr( >>> df_1 = pl.DataFrame({ ... "code": [None, "A", "A", "B", "C"], ... "modifier1": [None, 1, 2, 1, 2], - ... "code/n_patients": [10, 1, 1, 2, 2], + ... "code/n_subjects": [10, 1, 1, 2, 2], ... "code/n_occurrences": [13, 2, 1, 3, 2], - ... "values/n_patients": [8, 1, 1, 2, 2], + ... 
"values/n_subjects": [8, 1, 1, 2, 2], ... "values/n_occurrences": [12, 2, 1, 3, 2], ... "values/n_ints": [4, 0, 1, 3, 1], ... "values/sum": [13.2, 2.2, 6.0, 14.0, 12.5], @@ -516,9 +517,9 @@ def reducer_fntr( >>> df_2 = pl.DataFrame({ ... "code": ["A", "A", "B", "C"], ... "modifier1": [1, 2, 1, None], - ... "code/n_patients": [3, 3, 4, 4], + ... "code/n_subjects": [3, 3, 4, 4], ... "code/n_occurrences": [10, 11, 8, 11], - ... "values/n_patients": [0, 1, 2, 2], + ... "values/n_subjects": [0, 1, 2, 2], ... "values/n_occurrences": [0, 4, 3, 2], ... "values/n_ints": [0, 1, 3, 1], ... "values/sum": [0., 7.0, 14.0, 12.5], @@ -530,9 +531,9 @@ def reducer_fntr( >>> df_3 = pl.DataFrame({ ... "code": ["D"], ... "modifier1": [1], - ... "code/n_patients": [2], + ... "code/n_subjects": [2], ... "code/n_occurrences": [2], - ... "values/n_patients": [1], + ... "values/n_subjects": [1], ... "values/n_occurrences": [3], ... "values/n_ints": [3], ... "values/sum": [2], @@ -542,12 +543,12 @@ def reducer_fntr( ... "values/quantiles": [[]], ... }) >>> code_modifiers = ["modifier1"] - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifiers) >>> reducer(df_1, df_2, df_3) shape: (7, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ - │ code ┆ modifier1 ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ modifier1 ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 │ ╞══════╪═══════════╪═════════════════╪═══════════════╡ @@ -562,9 +563,9 @@ def reducer_fntr( >>> cfg = DictConfig({ ... "code_modifiers": ["modifier1"], ... "code_processing_stages": { - ... "stage1": ["code/n_patients", "values/n_ints"], + ... "stage1": ["code/n_subjects", "values/n_ints"], ... "stage2": ["code/n_occurrences", "values/sum"], - ... "stage3.A": ["values/n_patients", "values/n_occurrences"], + ... "stage3.A": ["values/n_subjects", "values/n_occurrences"], ... "stage3.B": ["values/sum_sqd", "values/min", "values/max"], ... "stage4": ["INVALID"], ... 
} @@ -586,12 +587,12 @@ def reducer_fntr( │ C ┆ 2 ┆ 2 ┆ 12.5 │ │ D ┆ 1 ┆ 2 ┆ 2.0 │ └──────┴───────────┴────────────────────┴────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) + >>> stage_cfg = DictConfig({"aggregations": ["values/n_subjects", "values/n_occurrences"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifiers) >>> reducer(df_1, df_2, df_3) shape: (7, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ - │ code ┆ modifier1 ┆ values/n_patients ┆ values/n_occurrences │ + │ code ┆ modifier1 ┆ values/n_subjects ┆ values/n_occurrences │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 │ ╞══════╪═══════════╪═══════════════════╪══════════════════════╡ diff --git a/src/MEDS_transforms/configs/extract.yaml b/src/MEDS_transforms/configs/extract.yaml index ae3bf076..3abd498e 100644 --- a/src/MEDS_transforms/configs/extract.yaml +++ b/src/MEDS_transforms/configs/extract.yaml @@ -2,7 +2,7 @@ defaults: - pipeline - stage_configs: - shard_events - - split_and_shard_patients + - split_and_shard_subjects - merge_to_MEDS_cohort - extract_code_metadata - finalize_MEDS_metadata @@ -32,7 +32,7 @@ shards_map_fp: "${cohort_dir}/metadata/.shards.json" stages: - shard_events - - split_and_shard_patients + - split_and_shard_subjects - convert_to_sharded_events - merge_to_MEDS_cohort - extract_code_metadata diff --git a/src/MEDS_transforms/configs/preprocess.yaml b/src/MEDS_transforms/configs/preprocess.yaml index ea509cdc..dab87a9a 100644 --- a/src/MEDS_transforms/configs/preprocess.yaml +++ b/src/MEDS_transforms/configs/preprocess.yaml @@ -2,7 +2,7 @@ defaults: - pipeline - stage_configs: - reshard_to_split - - filter_patients + - filter_subjects - add_time_derived_measurements - count_code_occurrences - filter_measurements @@ -24,7 +24,7 @@ code_modifiers: ??? 
# Pipeline Structure stages: - - filter_patients + - filter_subjects - add_time_derived_measurements - preliminary_counts - filter_measurements diff --git a/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml b/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml index 076a1a05..b17b74e4 100644 --- a/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml +++ b/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml @@ -1,5 +1,5 @@ count_code_occurrences: aggregations: - "code/n_occurrences" - - "code/n_patients" + - "code/n_subjects" do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts diff --git a/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml b/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml index 12ff62f0..0d0a5bd9 100644 --- a/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml +++ b/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml @@ -1,3 +1,3 @@ filter_measurements: - min_patients_per_code: null + min_subjects_per_code: null min_occurrences_per_code: null diff --git a/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml b/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml deleted file mode 100644 index 70332b14..00000000 --- a/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml +++ /dev/null @@ -1,3 +0,0 @@ -filter_patients: - min_events_per_patient: null - min_measurements_per_patient: null diff --git a/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml b/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml new file mode 100644 index 00000000..2706ffc1 --- /dev/null +++ b/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml @@ -0,0 +1,3 @@ +filter_subjects: + min_events_per_subject: null + min_measurements_per_subject: null diff --git a/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml b/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml index e522470e..6bd90cb6 100644 --- a/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml +++ b/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml @@ -1,7 +1,7 @@ fit_normalization: aggregations: - "code/n_occurrences" - - "code/n_patients" + - "code/n_subjects" - "values/n_occurrences" - "values/sum" - "values/sum_sqd" diff --git a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml index 16dc5051..fd0dc8aa 100644 --- a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml +++ b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml @@ -1,2 +1,2 @@ reshard_to_split: - n_patients_per_shard: 50000 + n_subjects_per_shard: 50000 diff --git a/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml b/src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml similarity index 73% rename from src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml rename to src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml index c4015bde..7dbed118 100644 --- a/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml +++ b/src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml @@ -1,7 +1,7 @@ -split_and_shard_patients: +split_and_shard_subjects: is_metadata: True output_dir: ${cohort_dir}/metadata - n_patients_per_shard: 50000 + n_subjects_per_shard: 50000 external_splits_json_fp: 
null split_fracs: train: 0.8 diff --git a/src/MEDS_transforms/extract/README.md b/src/MEDS_transforms/extract/README.md index 47f7e7d9..a60a8c81 100644 --- a/src/MEDS_transforms/extract/README.md +++ b/src/MEDS_transforms/extract/README.md @@ -4,8 +4,8 @@ This directory contains the scripts and functions used to extract raw data into dataset is: 1. Arranged in a series of files on disk of an allowed format (e.g., `.csv`, `.csv.gz`, `.parquet`)... -2. Such that each file stores a dataframe containing data about patients such that each row of any given - table corresponds to zero or more observations about a patient at a given time... +2. Such that each file stores a dataframe containing data about subjects such that each row of any given + table corresponds to zero or more observations about a subject at a given time... 3. And you can configure how to extract those observations in the time, code, and numeric value format of MEDS in the event conversion `yaml` file format specified below, then... this tool can automatically extract your raw data into a MEDS dataset for you in an efficient, reproducible, @@ -53,7 +53,7 @@ step](#step-0-pre-meds) and the [Data Cleaning step](#step-3-data-cleanup), for ### Event Conversion Configuration The event conversion configuration file tells MEDS Extract how to convert each row of a file among your raw -data files into one or more MEDS measurements (meaning a tuple of a patient ID, a time, a categorical +data files into one or more MEDS measurements (meaning a tuple of a subject ID, a time, a categorical code, and/or various other value or properties columns, most commonly a numeric value). This file is written in yaml and has the following format: @@ -93,11 +93,11 @@ each row of the file will be converted into a MEDS event according to the logic here, as string literals _cannot_ be used for these columns. There are several more nuanced aspects to the configuration file that have not yet been discussed. First, the -configuration file also specifies how to identify the patient ID from the raw data. This can be done either by -specifying a global `patient_id_col` field at the top level of the configuration file, or by specifying a -`patient_id_col` field at the per-file or per-event level. Multiple specifications can be used simultaneously, -with the most local taking precedent. If no patient ID column is specified, the patient ID will be assumed to -be stored in a `patient_id` column. If the patient ID column is not found, an error will be raised. +configuration file also specifies how to identify the subject ID from the raw data. This can be done either by +specifying a global `subject_id_col` field at the top level of the configuration file, or by specifying a +`subject_id_col` field at the per-file or per-event level. Multiple specifications can be used simultaneously, +with the most local taking precedent. If no subject ID column is specified, the subject ID will be assumed to +be stored in a `subject_id` column. If the subject ID column is not found, an error will be raised. Second, you can also specify how to link the codes constructed for each event block to code-specific metadata in these blocks. This is done by specifying a `_metadata` block in the event block. 
The format of this block @@ -117,7 +117,7 @@ the [Partial MIMIC-IV Example](#partial-mimic-iv-example) below for an example o ```yaml subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -144,7 +144,7 @@ admit_vitals: ##### Partial MIMIC-IV Example ```yaml -patient_id_col: subject_id +subject_id_col: subject_id hosp/admissions: admission: code: @@ -259,4 +259,4 @@ Note that this tool is _not_: TODO: Add issues for all of these. 1. Single event blocks for files should be specifiable directly, without an event block name. -2. Time format should be specifiable at the file or global level, like patient ID. +2. Time format should be specifiable at the file or global level, like subject ID. diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index ee4e9d70..c83f9347 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -97,7 +97,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy """Extracts a single event dataframe from the raw data. Args: - df: The raw data DataFrame. This must have a `"patient_id"` column containing the patient ID. The + df: The raw data DataFrame. This must have a `"subject_id"` column containing the subject ID. The other columns it must have are determined by the `event_cfg` configuration dictionary. event_cfg: A dictionary containing the configuration for the event. This must contain two critical keys (`"code"` and `"time"`) and may contain additional keys for other columns to include @@ -128,7 +128,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy A DataFrame containing the event data extracted from the raw data, containing only unique rows across all columns. If the raw data has no duplicates when considering the event column space, the output dataframe will have the same number of rows as the raw data and be in the same order. The output - dataframe will contain at least three columns: `"patient_id"`, `"code"`, and `"time"`. If the + dataframe will contain at least three columns: `"subject_id"`, `"code"`, and `"time"`. If the event has additional columns, they will be included in the output dataframe as well. **_Events that would be extracted with a null code or a time that should be specified via a column with or without a formatting option but in practice is null will be dropped._** Note that this dropping logic @@ -145,7 +145,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> _ = pl.Config.set_tbl_rows(20) >>> _ = pl.Config.set_tbl_cols(20) >>> raw_data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "C", "D"], ... "code_modifier": ["1", "2", "3", "4"], ... 
"time": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04"], @@ -160,7 +160,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(raw_data, event_cfg) shape: (4, 4) ┌────────────┬───────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ datetime[μs] ┆ i64 │ ╞════════════╪═══════════╪═════════════════════╪═══════════════╡ @@ -170,7 +170,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy │ 2 ┆ FOO//D//4 ┆ 2021-01-04 00:00:00 ┆ 4 │ └────────────┴───────────┴─────────────────────┴───────────────┘ >>> data_with_nulls = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", None, "C", "D"], ... "code_modifier": ["1", "2", "3", None], ... "time": [None, "2021-01-02", "2021-01-03", "2021-01-04"], @@ -185,7 +185,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(data_with_nulls, event_cfg) shape: (2, 4) ┌────────────┬────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ datetime[μs] ┆ i64 │ ╞════════════╪════════╪═════════════════════╪═══════════════╡ @@ -195,7 +195,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> from datetime import datetime >>> complex_raw_data = pl.DataFrame( ... { - ... "patient_id": [1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 2, 2, 2, 3], ... "admission_time": [ ... "2021-01-01 00:00:00", ... "2021-01-02 00:00:00", @@ -227,7 +227,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... "eye_color": ["blue", "blue", "green", "green", "green", "brown"], ... }, ... schema={ - ... "patient_id": pl.UInt8, + ... "subject_id": pl.UInt8, ... "admission_time": pl.Utf8, ... "discharge_time": pl.Datetime, ... "admission_type": pl.Utf8, @@ -263,7 +263,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> complex_raw_data shape: (6, 9) ┌────────────┬─────────────────────┬─────────────────────┬────────────────┬────────────────────┬──────────────────┬────────────────┬────────────┬───────────┐ - │ patient_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ discharge_status ┆ severity_score ┆ death_time ┆ eye_color │ + │ subject_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ discharge_status ┆ severity_score ┆ death_time ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ cat ┆ str ┆ f64 ┆ str ┆ cat │ ╞════════════╪═════════════════════╪═════════════════════╪════════════════╪════════════════════╪══════════════════╪════════════════╪════════════╪═══════════╡ @@ -277,7 +277,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_admission_event_cfg) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -294,7 +294,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... 
) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -311,7 +311,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... ) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -325,7 +325,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_discharge_event_cfg) shape: (6, 5) ┌────────────┬─────────────────┬─────────────────────┬───────────────────┬────────────┐ - │ patient_id ┆ code ┆ time ┆ categorical_value ┆ text_value │ + │ subject_id ┆ code ┆ time ┆ categorical_value ┆ text_value │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ str │ ╞════════════╪═════════════════╪═════════════════════╪═══════════════════╪════════════╡ @@ -339,7 +339,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_death_event_cfg) shape: (3, 3) ┌────────────┬───────┬─────────────────────┐ - │ patient_id ┆ code ┆ time │ + │ subject_id ┆ code ┆ time │ │ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] │ ╞════════════╪═══════╪═════════════════════╡ @@ -351,7 +351,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_static_event_cfg) shape: (3, 3) ┌────────────┬──────────────────┬──────────────┐ - │ patient_id ┆ code ┆ time │ + │ subject_id ┆ code ┆ time │ │ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] │ ╞════════════╪══════════════════╪══════════════╡ @@ -371,10 +371,10 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy Traceback (most recent call last): ... ValueError: Invalid time literal: 12-01-23 - >>> extract_event(complex_raw_data, {"code": "test", "time": None, "patient_id": 3}) + >>> extract_event(complex_raw_data, {"code": "test", "time": None, "subject_id": 3}) Traceback (most recent call last): ... - KeyError: "Event column name 'patient_id' cannot be overridden." + KeyError: "Event column name 'subject_id' cannot be overridden." >>> extract_event(complex_raw_data, {"code": "test", "time": None, "foobar": "fuzz"}) Traceback (most recent call last): ... @@ -389,7 +389,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ValueError: Source column 'discharge_time' for event column foobar is not numeric, string, or categorical! Cannot be used as an event col. """ # noqa: E501 event_cfg = copy.deepcopy(event_cfg) - event_exprs = {"patient_id": pl.col("patient_id")} + event_exprs = {"subject_id": pl.col("subject_id")} if "code" not in event_cfg: raise KeyError( @@ -401,8 +401,8 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy "Event configuration dictionary must contain 'time' key. " f"Got: [{', '.join(event_cfg.keys())}]." 
) - if "patient_id" in event_cfg: - raise KeyError("Event column name 'patient_id' cannot be overridden.") + if "subject_id" in event_cfg: + raise KeyError("Event column name 'subject_id' cannot be overridden.") code_expr, code_null_filter_expr, needed_cols = get_code_expr(event_cfg.pop("code")) @@ -502,7 +502,7 @@ def convert_to_events( """Converts a DataFrame of raw data into a DataFrame of events. Args: - df: The raw data DataFrame. This must have a `"patient_id"` column containing the patient ID. The + df: The raw data DataFrame. This must have a `"subject_id"` column containing the subject ID. The other columns it must have are determined by the `event_cfgs` configuration dictionary. For the precise mechanism of column determination, see the `extract_event` function. event_cfgs: A dictionary containing the configurations for the events to extract. The keys of this @@ -518,7 +518,7 @@ def convert_to_events( events extracted from the raw data, with the rows from each event DataFrame concatenated together. After concatenation, this dataframe will not be deduplicated, so if the raw data results in duplicates across events of different name, these will be preserved in the output DataFrame. - The output DataFrame will contain at least three columns: `"patient_id"`, `"code"`, and `"time"`. + The output DataFrame will contain at least three columns: `"subject_id"`, `"code"`, and `"time"`. If any events have additional columns, these will be included in the output DataFrame as well. All columns across all event configurations will be included in the output DataFrame, with `null` values filled in for events that do not have a particular column. @@ -533,7 +533,7 @@ def convert_to_events( >>> from datetime import datetime >>> complex_raw_data = pl.DataFrame( ... { - ... "patient_id": [1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 2, 2, 2, 3], ... "admission_time": [ ... "2021-01-01 00:00:00", ... "2021-01-02 00:00:00", @@ -564,7 +564,7 @@ def convert_to_events( ... "eye_color": ["blue", "blue", "green", "green", "green", "brown"], ... }, ... schema={ - ... "patient_id": pl.UInt8, + ... "subject_id": pl.UInt8, ... "admission_time": pl.Utf8, ... "discharge_time": pl.Datetime, ... 
"admission_type": pl.Utf8, @@ -602,7 +602,7 @@ def convert_to_events( >>> complex_raw_data shape: (6, 8) ┌────────────┬─────────────────────┬─────────────────────┬────────────────┬────────────────────┬────────────────┬────────────┬───────────┐ - │ patient_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ severity_score ┆ death_time ┆ eye_color │ + │ subject_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ severity_score ┆ death_time ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ cat ┆ f64 ┆ str ┆ cat │ ╞════════════╪═════════════════════╪═════════════════════╪════════════════╪════════════════════╪════════════════╪════════════╪═══════════╡ @@ -616,7 +616,7 @@ def convert_to_events( >>> convert_to_events(complex_raw_data, event_cfgs) shape: (18, 7) ┌────────────┬───────────┬─────────────────────┬────────────────┬───────────────────────┬────────────────────┬───────────┐ - │ patient_id ┆ code ┆ time ┆ admission_type ┆ severity_on_admission ┆ discharge_location ┆ eye_color │ + │ subject_id ┆ code ┆ time ┆ admission_type ┆ severity_on_admission ┆ discharge_location ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ f64 ┆ cat ┆ cat │ ╞════════════╪═══════════╪═════════════════════╪════════════════╪═══════════════════════╪════════════════════╪═══════════╡ @@ -666,7 +666,7 @@ def convert_to_events( @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) def main(cfg: DictConfig): - """Converts the event-sharded raw data into MEDS events and storing them in patient subsharded flat files. + """Converts the event-sharded raw data into MEDS events and storing them in subject subsharded flat files. All arguments are specified through the command line into the `cfg` object through Hydra. @@ -680,7 +680,7 @@ def main(cfg: DictConfig): file. 
""" - input_dir, patient_subsharded_dir, metadata_input_dir = stage_init(cfg) + input_dir, subject_subsharded_dir, metadata_input_dir = stage_init(cfg) shards = json.loads(Path(cfg.shards_map_fp).read_text()) @@ -694,13 +694,13 @@ def main(cfg: DictConfig): event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", "subject_id") - patient_subsharded_dir.mkdir(parents=True, exist_ok=True) - OmegaConf.save(event_conversion_cfg, patient_subsharded_dir / "event_conversion_config.yaml") + subject_subsharded_dir.mkdir(parents=True, exist_ok=True) + OmegaConf.save(event_conversion_cfg, subject_subsharded_dir / "event_conversion_config.yaml") - patient_splits = list(shards.items()) - random.shuffle(patient_splits) + subject_splits = list(shards.items()) + random.shuffle(subject_splits) event_configs = list(event_conversion_cfg.items()) random.shuffle(event_configs) @@ -708,28 +708,28 @@ def main(cfg: DictConfig): # Here, we'll be reading files directly, so we'll turn off globbing read_fn = partial(pl.scan_parquet, glob=False) - for sp, patients in patient_splits: + for sp, subjects in subject_splits: for input_prefix, event_cfgs in event_configs: event_cfgs = copy.deepcopy(event_cfgs) - input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.pop("subject_id_col", default_subject_id_col) event_shards = list((input_dir / input_prefix).glob("*.parquet")) random.shuffle(event_shards) for shard_fp in event_shards: - out_fp = patient_subsharded_dir / sp / input_prefix / shard_fp.name + out_fp = subject_subsharded_dir / sp / input_prefix / shard_fp.name logger.info(f"Converting {shard_fp} to events and saving to {out_fp}") def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: - typed_patients = pl.Series(patients, dtype=df.schema[input_patient_id_column]) + typed_subjects = pl.Series(subjects, dtype=df.schema[input_subject_id_column]) - if input_patient_id_column != "patient_id": - df = df.rename({input_patient_id_column: "patient_id"}) + if input_subject_id_column != "subject_id": + df = df.rename({input_subject_id_column: "subject_id"}) try: logger.info(f"Extracting events for {input_prefix}/{shard_fp.name}") return convert_to_events( - df.filter(pl.col("patient_id").is_in(typed_patients)), + df.filter(pl.col("subject_id").is_in(typed_subjects)), event_cfgs=copy.deepcopy(event_cfgs), ) except Exception as e: diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py b/src/MEDS_transforms/extract/extract_code_metadata.py index e9133eb6..818d88a2 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -250,9 +250,9 @@ def get_events_and_metadata_by_metadata_fp(event_configs: dict | DictConfig) -> Examples: >>> event_configs = { - ... "patient_id_col": "MRN", + ... "subject_id_col": "MRN", ... "icu/procedureevents": { - ... "patient_id_col": "subject_id", + ... "subject_id_col": "subject_id", ... "start": { ... "code": ["PROCEDURE", "START", "col(itemid)"], ... 
"_metadata": { @@ -304,11 +304,11 @@ def get_events_and_metadata_by_metadata_fp(event_configs: dict | DictConfig) -> out = {} for file_pfx, event_cfgs_for_pfx in event_configs.items(): - if file_pfx == "patient_id_col": + if file_pfx == "subject_id_col": continue for event_key, event_cfg in event_cfgs_for_pfx.items(): - if event_key == "patient_id_col": + if event_key == "subject_id_col": continue for metadata_pfx, metadata_cfg in event_cfg.get("_metadata", {}).items(): diff --git a/src/MEDS_transforms/extract/finalize_MEDS_data.py b/src/MEDS_transforms/extract/finalize_MEDS_data.py index f9d68730..54e1d20e 100644 --- a/src/MEDS_transforms/extract/finalize_MEDS_data.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_data.py @@ -38,7 +38,7 @@ def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa. >>> get_and_validate_data_schema(df.lazy(), dict(do_retype=False)) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64. + ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64. MEDS Data DataFrame must have a 'time' column of type Datetime(time_unit='us', time_zone=None). MEDS Data DataFrame must have a 'code' column of type String. @@ -46,28 +46,28 @@ def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa. >>> get_and_validate_data_schema(df.lazy(), {}) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64. + ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64. MEDS Data DataFrame must have a 'code' column of type String. >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": pl.Series([1, 2], dtype=pl.UInt32), + ... "subject_id": pl.Series([1, 2], dtype=pl.UInt32), ... "time": [datetime(2021, 1, 1), datetime(2021, 1, 2)], ... "code": ["A", "B"], "text_value": ["1", None], "numeric_value": [None, 34.2] ... }) >>> get_and_validate_data_schema(df.lazy(), dict(do_retype=False)) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data 'patient_id' column must be of type Int64. Got UInt32. + ValueError: MEDS Data 'subject_id' column must be of type Int64. Got UInt32. MEDS Data 'numeric_value' column must be of type Float32. Got Float64. >>> get_and_validate_data_schema(df.lazy(), {}) pyarrow.Table - patient_id: int64 + subject_id: int64 time: timestamp[us] code: string numeric_value: float text_value: large_string ---- - patient_id: [[1,2]] + subject_id: [[1,2]] time: [[2021-01-01 00:00:00.000000,2021-01-02 00:00:00.000000]] code: [["A","B"]] numeric_value: [[null,34.2]] @@ -111,7 +111,7 @@ def main(cfg: DictConfig): """Writes out schema compliant MEDS data files for the extracted dataset. 
In particular, this script ensures that all shard files are MEDS compliant with the mandatory columns - - `patient_id` (Int64) + - `subject_id` (Int64) - `time` (DateTime) - `code` (String) - `numeric_value` (Float32) diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index 366d89aa..7a6dcc07 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -15,7 +15,8 @@ code_metadata_schema, dataset_metadata_schema, held_out_split, - patient_split_schema, + subject_id_field, + subject_split_schema, train_split, tuning_split, ) @@ -121,8 +122,8 @@ def main(cfg: DictConfig): - `etl_name` (string) - `etl_version` (string) - `meds_version` (string) - (3) a `metadata/patient_splits.parquet` file exists that has the mandatory columns - - `patient_id` (Int64) + (3) a `metadata/subject_splits.parquet` file exists that has the mandatory columns + - `subject_id` (Int64) - `split` (string) This stage *_should almost always be the last metadata stage in an extraction pipeline._* @@ -151,9 +152,9 @@ def main(cfg: DictConfig): output_code_metadata_fp = output_metadata_dir / "codes.parquet" dataset_metadata_fp = output_metadata_dir / "dataset.json" - patient_splits_fp = output_metadata_dir / "patient_splits.parquet" + subject_splits_fp = output_metadata_dir / "subject_splits.parquet" - for out_fp in [output_code_metadata_fp, dataset_metadata_fp, patient_splits_fp]: + for out_fp in [output_code_metadata_fp, dataset_metadata_fp, subject_splits_fp]: out_fp.parent.mkdir(parents=True, exist_ok=True) if out_fp.exists() and cfg.do_overwrite: out_fp.unlink() @@ -194,28 +195,28 @@ def main(cfg: DictConfig): # Split creation shards_map_fp = Path(cfg.shards_map_fp) - logger.info("Creating patient splits from {str(shards_map_fp.resolve())}") + logger.info(f"Creating subject splits from {str(shards_map_fp.resolve())}") shards_map = json.loads(shards_map_fp.read_text()) - patient_splits = [] + subject_splits = [] seen_splits = {train_split: 0, tuning_split: 0, held_out_split: 0} - for shard, patient_ids in shards_map.items(): + for shard, subject_ids in shards_map.items(): split = "/".join(shard.split("/")[:-1]) if split not in seen_splits: seen_splits[split] = 0 - seen_splits[split] += len(patient_ids) + seen_splits[split] += len(subject_ids) - patient_splits.extend([{"patient_id": pid, "split": split} for pid in patient_ids]) + subject_splits.extend([{subject_id_field: pid, "split": split} for pid in subject_ids]) for split, cnt in seen_splits.items(): if cnt: - logger.info(f"Split {split} has {cnt} patients") + logger.info(f"Split {split} has {cnt} subjects") else: logger.warning(f"Split {split} not found in shards map") - patient_splits_tbl = pa.Table.from_pylist(patient_splits, schema=patient_split_schema) - logger.info(f"Writing finalized patient splits to {str(patient_splits_fp.resolve())}") - pq.write_table(patient_splits_tbl, patient_splits_fp) + subject_splits_tbl = pa.Table.from_pylist(subject_splits, schema=subject_split_schema) + logger.info(f"Writing finalized subject splits to {str(subject_splits_fp.resolve())}") + pq.write_table(subject_splits_tbl, subject_splits_fp) if __name__ == "__main__": diff --git a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py index 49a45e1d..e611c773 100755 --- a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py +++ b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py @@ -30,8
+30,8 @@ def merge_subdirs_and_sort( warning is logged, but an error is *not* raised. Which rows are retained if the uniqeu-by columns are not all columns is not guaranteed, but is also *not* random, so this may have statistical implications. - additional_sort_by: Additional columns to sort by, in addition to the default sorting by patient ID - and time. If `None`, only patient ID and time are used. If a list of strings, these + additional_sort_by: Additional columns to sort by, in addition to the default sorting by subject ID + and time. If `None`, only subject ID and time are used. If a list of strings, these columns are used in addition to the default sorting. If a column is not found in the dataframe, it is omitted from the sort-by, a warning is logged, but an error is *not* raised. This functionality is useful both for deterministic testing and in cases where a data owner wants to impose @@ -41,7 +41,7 @@ def merge_subdirs_and_sort( A single dataframe containing all the data from the parquet files in the subdirs of `sp_dir`. These files will be concatenated diagonally, taking the union of all rows in all dataframes and all unique columns in all dataframes to form the merged output. The returned dataframe will be made unique by the - columns specified in `unique_by` and sorted by first patient ID, then time, then all columns in + columns specified in `unique_by` and sorted by first subject ID, then time, then all columns in `additional_sort_by`, if any. Raises: @@ -50,15 +50,15 @@ def merge_subdirs_and_sort( Examples: >>> from tempfile import TemporaryDirectory - >>> df1 = pl.DataFrame({"patient_id": [1, 2], "time": [10, 20], "code": ["A", "B"]}) + >>> df1 = pl.DataFrame({"subject_id": [1, 2], "time": [10, 20], "code": ["A", "B"]}) >>> df2 = pl.DataFrame({ - ... "patient_id": [1, 1, 3], + ... "subject_id": [1, 1, 3], ... "time": [2, 1, 8], ... "code": ["C", "D", "E"], ... "numeric_value": [None, 2.0, None], ... }) >>> df3 = pl.DataFrame({ - ... "patient_id": [1, 1, 3], + ... "subject_id": [1, 1, 3], ... "time": [2, 2, 8], ... "code": ["C", "D", "E"], ... "numeric_value": [6.2, 2.0, None], @@ -84,7 +84,7 @@ def merge_subdirs_and_sort( ... ).collect() shape: (8, 4) ┌────────────┬──────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ f64 │ ╞════════════╪══════╪══════╪═══════════════╡ @@ -112,7 +112,7 @@ def merge_subdirs_and_sort( ... ).collect() shape: (7, 4) ┌────────────┬──────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ f64 │ ╞════════════╪══════╪══════╪═══════════════╡ @@ -131,18 +131,18 @@ def merge_subdirs_and_sort( ... df2.write_parquet(sp_dir / "subdir1" / "file2.parquet") ... (sp_dir / "subdir2").mkdir() ... df3.write_parquet(sp_dir / "subdir2" / "df.parquet") - ... # We just display the patient ID, time, and code columns as the numeric value column + ... # We just display the subject ID, time, and code columns as the numeric value column ... # is not guaranteed to be deterministic in the output given some rows will be dropped due to ... # the unique-by constraint. ... merge_subdirs_and_sort( ... sp_dir, ... event_subsets=["subdir1", "subdir2"], - ... unique_by=["patient_id", "time", "code"], + ... unique_by=["subject_id", "time", "code"], ... additional_sort_by=["code", "numeric_value"] - ... ).select("patient_id", "time", "code").collect() + ... 
).select("subject_id", "time", "code").collect() shape: (6, 3) ┌────────────┬──────┬──────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪══════╡ @@ -188,7 +188,7 @@ def merge_subdirs_and_sort( case _: raise ValueError(f"Invalid unique_by value: {unique_by}") - sort_by = ["patient_id", "time"] + sort_by = ["subject_id", "time"] if additional_sort_by is not None: for s in additional_sort_by: if s in df_columns: @@ -201,12 +201,12 @@ def merge_subdirs_and_sort( @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) def main(cfg: DictConfig): - """Merges the patient sub-sharded events into a single parquet file per patient shard. + """Merges the subject sub-sharded events into a single parquet file per subject shard. This function takes all dataframes (in parquet files) in any subdirs of the `cfg.stage_cfg.input_dir` and merges them into a single dataframe. All dataframes in the subdirs are assumed to be in the unnested, MEDS - format, and cover the same group of patients (specific to the shard being processed). The merged dataframe - will also be sorted by patient ID and time. + format, and cover the same group of subjects (specific to the shard being processed). The merged dataframe + will also be sorted by subject ID and time. All arguments are specified through the command line into the `cfg` object through Hydra. @@ -219,14 +219,14 @@ def main(cfg: DictConfig): stage_configs.merge_to_MEDS_cohort.unique_by: The list of columns that should be ensured to be unique after the dataframes are merged. Defaults to `"*"`, which means all columns are used. stage_configs.merge_to_MEDS_cohort.additional_sort_by: Additional columns to sort by, in addition to - the default sorting by patient ID and time. Defaults to `None`, which means only patient ID + the default sorting by subject ID and time. Defaults to `None`, which means only subject ID and time are used. Returns: Writes the merged dataframes to the shard-specific output filepath in the `cfg.stage_cfg.output_dir`. """ event_conversion_cfg = OmegaConf.load(cfg.event_conversion_config_fp) - event_conversion_cfg.pop("patient_id_col", None) + event_conversion_cfg.pop("subject_id_col", None) read_fn = partial( merge_subdirs_and_sort, diff --git a/src/MEDS_transforms/extract/shard_events.py b/src/MEDS_transforms/extract/shard_events.py index 18450bbb..5eebc710 100755 --- a/src/MEDS_transforms/extract/shard_events.py +++ b/src/MEDS_transforms/extract/shard_events.py @@ -11,6 +11,7 @@ import hydra import polars as pl from loguru import logger +from meds import subject_id_field from omegaconf import DictConfig, OmegaConf from MEDS_transforms.extract import CONFIG_YAML @@ -169,7 +170,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: event conversion configurations that are specific to each file based on its stem (filename without the extension). It compiles a list of column names needed for each file from the configuration, which includes both general - columns like row index and patient ID, as well as specific columns defined + columns like row index and subject ID, as well as specific columns defined for medical events and times formatted in a special 'col(column_name)' syntax. Args: @@ -185,7 +186,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: Examples: >>> cfg = DictConfig({ - ... "patient_id_col": "patient_id_global", + ... 
"subject_id_col": "subject_id_global", ... "hosp/patients": { ... "eye_color": { ... "code": ["EYE_COLOR", "col(eye_color)"], "time": None, "mod": "mod_col" @@ -195,7 +196,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: ... } ... }, ... "icu/chartevents": { - ... "patient_id_col": "patient_id_icu", + ... "subject_id_col": "subject_id_icu", ... "heart_rate": { ... "code": "HEART_RATE", "time": "charttime", "numeric_value": "HR" ... }, @@ -212,19 +213,19 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: ... } ... }) >>> retrieve_columns(cfg) # doctest: +NORMALIZE_WHITESPACE - {'hosp/patients': ['eye_color', 'height', 'mod_col', 'patient_id_global'], - 'icu/chartevents': ['HR', 'charttime', 'itemid', 'mod_lab', 'patient_id_icu', 'value', 'valuenum', + {'hosp/patients': ['eye_color', 'height', 'mod_col', 'subject_id_global'], + 'icu/chartevents': ['HR', 'charttime', 'itemid', 'mod_lab', 'subject_id_icu', 'value', 'valuenum', 'valueuom'], - 'icu/meds': ['medication', 'medtime', 'patient_id_global']} + 'icu/meds': ['medication', 'medtime', 'subject_id_global']} >>> cfg = DictConfig({ ... "subjects": { - ... "patient_id_col": "MRN", + ... "subject_id_col": "MRN", ... "eye_color": {"code": ["col(eye_color)"], "time": None}, ... }, ... "labs": {"lab": {"code": "col(labtest)", "time": "charttime"}}, ... }) >>> retrieve_columns(cfg) - {'subjects': ['MRN', 'eye_color'], 'labs': ['charttime', 'labtest', 'patient_id']} + {'subjects': ['MRN', 'eye_color'], 'labs': ['charttime', 'labtest', 'subject_id']} """ event_conversion_cfg = copy.deepcopy(event_conversion_cfg) @@ -232,11 +233,11 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: # Initialize a dictionary to store file paths as keys and lists of column names as values. prefix_to_columns = {} - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", subject_id_field) for input_prefix, event_cfgs in event_conversion_cfg.items(): - input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.pop("subject_id_col", default_subject_id_col) - prefix_to_columns[input_prefix] = {input_patient_id_column} + prefix_to_columns[input_prefix] = {input_subject_id_column} for event_cfg in event_cfgs.values(): # If the config has a 'code' key and it contains column fields, parse and add them. diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_subjects.py similarity index 71% rename from src/MEDS_transforms/extract/split_and_shard_patients.py rename to src/MEDS_transforms/extract/split_and_shard_subjects.py index 61f17263..30814051 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_subjects.py @@ -13,77 +13,77 @@ from MEDS_transforms.utils import stage_init -def shard_patients[ +def shard_subjects[ SUBJ_ID_T ]( - patients: np.ndarray, - n_patients_per_shard: int = 50000, + subjects: np.ndarray, + n_subjects_per_shard: int = 50000, external_splits: dict[str, Sequence[SUBJ_ID_T]] | None = None, split_fracs_dict: dict[str, float] | None = {"train": 0.8, "tuning": 0.1, "held_out": 0.1}, seed: int = 1, ) -> dict[str, list[SUBJ_ID_T]]: - """Shard a list of patients, nested within train/tuning/held-out splits. + """Shard a list of subjects, nested within train/tuning/held-out splits. 
- This function takes a list of patients and shards them into train/tuning/held-out splits, with the shards + This function takes a list of subjects and shards them into train/tuning/held-out splits, with the shards of a consistent size, nested within the splits. The function will also respect external splits, if provided, such that mandated splits (such as prospective held out sets or pre-existing, task-specific held out sets) are with-held and sharded as separate splits from the IID splits defined by `split_fracs_dict`. It returns a dictionary mapping the split and shard names (realized as f"{split}/{shard}") to the list of - patients in that shard. + subjects in that shard. Args: - patients: The list of patients to shard. - n_patients_per_shard: The maximum number of patients to include in each shard. + subjects: The list of subjects to shard. + n_subjects_per_shard: The maximum number of subjects to include in each shard. external_splits: The externally defined splits to respect. If provided, the keys of this dictionary - will be used as split names, and the values as the list of patients in that split. These + will be used as split names, and the values as the list of subjects in that split. These pre-defined splits will be excluded from IID splits generated by this function, but will be sharded like normal. Note that this is largely only appropriate for held-out sets for pre-defined - tasks or test cases (e.g., prospective tests); training patients should often still be included in + tasks or test cases (e.g., prospective tests); training subjects should often still be included in the IID splits to maximize the amount of data that can be used for training. - split_fracs_dict: A dictionary mapping the split name to the fraction of patients to include in that + split_fracs_dict: A dictionary mapping the split name to the fraction of subjects to include in that split. Defaults to 80% train, 10% tuning, 10% held-out. This can be None or empty only when external splits fully specify the population. - seed: The random seed to use for shuffling the patients before seeding and sharding. This is useful + seed: The random seed to use for shuffling the subjects before seeding and sharding. This is useful for ensuring reproducibility. Returns: - A dictionary mapping f"{split}/{shard}" to the list of patients in that shard. This may include - overlapping patients across a subset of these splits, but never across shards within a split. Any + A dictionary mapping f"{split}/{shard}" to the list of subjects in that shard. This may include + overlapping subjects across a subset of these splits, but never across shards within a split. Any overlap will solely occur between an external split and another external split. Raises: ValueError: If the sum of the split fractions in `split_fracs_dict` is not equal to 1. Examples: - >>> patients = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int) - >>> shard_patients(patients, n_patients_per_shard=3) + >>> subjects = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int) + >>> shard_subjects(subjects, n_subjects_per_shard=3) {'train/0': [9, 4, 8], 'train/1': [2, 1, 10], 'train/2': [6, 5], 'tuning/0': [3], 'held_out/0': [7]} >>> external_splits = { ... 'taskA/held_out': np.array([8, 9, 10], dtype=int), ... 'taskB/held_out': np.array([10, 8, 9], dtype=int), ...
} - >>> shard_patients(patients, 3, external_splits) # doctest: +NORMALIZE_WHITESPACE + >>> shard_subjects(subjects, 3, external_splits) # doctest: +NORMALIZE_WHITESPACE {'train/0': [5, 7, 4], 'train/1': [1, 2], 'tuning/0': [3], 'held_out/0': [6], 'taskA/held_out/0': [8, 9, 10], 'taskB/held_out/0': [10, 8, 9]} - >>> shard_patients(patients, n_patients_per_shard=3, split_fracs_dict={'train': 0.5}) + >>> shard_subjects(subjects, n_subjects_per_shard=3, split_fracs_dict={'train': 0.5}) Traceback (most recent call last): ... ValueError: The sum of the split fractions must be equal to 1. - >>> shard_patients([1, 2], n_patients_per_shard=3) + >>> shard_subjects([1, 2], n_subjects_per_shard=3) Traceback (most recent call last): ... - ValueError: Unable to adjust splits to ensure all splits have at least 1 patient. + ValueError: Unable to adjust splits to ensure all splits have at least 1 subject. >>> external_splits = { ... 'train': np.array([1, 2, 3, 4, 5, 6], dtype=int), ... 'test': np.array([7, 8, 9, 10], dtype=int), ... } - >>> shard_patients(patients, 6, external_splits, split_fracs_dict=None) + >>> shard_subjects(subjects, 6, external_splits, split_fracs_dict=None) {'train/0': [1, 2, 3, 4, 5, 6], 'test/0': [7, 8, 9, 10]} - >>> shard_patients(patients, 3, external_splits) + >>> shard_subjects(subjects, 3, external_splits) {'train/0': [5, 1, 3], 'train/1': [2, 6, 4], 'test/0': [10, 7], 'test/1': [8, 9]} """ @@ -94,62 +94,62 @@ def shard_patients[ if not isinstance(external_splits[k], np.ndarray): logger.warning( f"External split {k} is not a numpy array and thus type safety is not guaranteed. " - f"Attempting to convert to numpy array of dtype {patients.dtype}." + f"Attempting to convert to numpy array of dtype {subjects.dtype}." ) - external_splits[k] = np.array(external_splits[k], dtype=patients.dtype) + external_splits[k] = np.array(external_splits[k], dtype=subjects.dtype) - patients = np.unique(patients) + subjects = np.unique(subjects) # Splitting all_external_splits = set().union(*external_splits.values()) - is_in_external_split = np.isin(patients, list(all_external_splits)) - patient_ids_to_split = patients[~is_in_external_split] + is_in_external_split = np.isin(subjects, list(all_external_splits)) + subject_ids_to_split = subjects[~is_in_external_split] splits = external_splits rng = np.random.default_rng(seed) - if n_patients := len(patient_ids_to_split): + if n_subjects := len(subject_ids_to_split): if sum(split_fracs_dict.values()) != 1: raise ValueError("The sum of the split fractions must be equal to 1.") split_names_idx = rng.permutation(len(split_fracs_dict)) split_names = np.array(list(split_fracs_dict.keys()))[split_names_idx] split_fracs = np.array([split_fracs_dict[k] for k in split_names]) - split_lens = np.round(split_fracs[:-1] * n_patients).astype(int) - split_lens = np.append(split_lens, n_patients - split_lens.sum()) + split_lens = np.round(split_fracs[:-1] * n_subjects).astype(int) + split_lens = np.append(split_lens, n_subjects - split_lens.sum()) if split_lens.min() == 0: logger.warning( - "Some splits are empty. Adjusting splits to ensure all splits have at least 1 patient." + "Some splits are empty. Adjusting splits to ensure all splits have at least 1 subject." 
) max_split = split_lens.argmax() split_lens[max_split] -= 1 split_lens[split_lens.argmin()] += 1 if split_lens.min() == 0: - raise ValueError("Unable to adjust splits to ensure all splits have at least 1 patient.") + raise ValueError("Unable to adjust splits to ensure all splits have at least 1 subject.") - patients = rng.permutation(patient_ids_to_split) - patients_per_split = np.split(patients, split_lens.cumsum()) + subjects = rng.permutation(subject_ids_to_split) + subjects_per_split = np.split(subjects, split_lens.cumsum()) - splits = {**{k: v for k, v in zip(split_names, patients_per_split)}, **splits} + splits = {**{k: v for k, v in zip(split_names, subjects_per_split)}, **splits} else: if split_fracs_dict: logger.warning( - "External splits were provided covering all patients, but split_fracs_dict was not empty. " + "External splits were provided covering all subjects, but split_fracs_dict was not empty. " "Ignoring the split_fracs_dict." ) else: - logger.info("External splits were provided covering all patients.") + logger.info("External splits were provided covering all subjects.") # Sharding final_shards = {} for sp, pts in splits.items(): - if len(pts) <= n_patients_per_shard: + if len(pts) <= n_subjects_per_shard: final_shards[f"{sp}/0"] = pts.tolist() else: pts = rng.permutation(pts) n_pts = len(pts) - n_shards = int(np.ceil(n_pts / n_patients_per_shard)) + n_shards = int(np.ceil(n_pts / n_subjects_per_shard)) shards = np.array_split(pts, n_shards) for i, shard in enumerate(shards): final_shards[f"{sp}/{i}"] = shard.tolist() @@ -157,12 +157,12 @@ def shard_patients[ seen = {} for k, pts in final_shards.items(): - logger.info(f"Split {k} has {len(pts)} patients.") + logger.info(f"Split {k} has {len(pts)} subjects.") for kk, v in seen.items(): shared = set(pts).intersection(v) if shared: - logger.info(f" - intersects {kk} on {len(shared)} patients.") + logger.info(f" - intersects {kk} on {len(shared)} subjects.") seen[k] = set(pts) @@ -171,9 +171,9 @@ def shard_patients[ @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) def main(cfg: DictConfig): - """Extracts the set of unique patients from the raw data and splits/shards them and saves the result. + """Extracts the set of unique subjects from the raw data and splits/shards them and saves the result. - This stage splits the patients into training, tuning, and held-out sets, and further splits those sets + This stage splits the subjects into training, tuning, and held-out sets, and further splits those sets into shards. All arguments are specified through the command line into the `cfg` object through Hydra. @@ -181,19 +181,19 @@ def main(cfg: DictConfig): The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten directly on the command line, but can be overwritten implicitly by overwriting components of the - `stage_configs.split_and_shard_patients` key. + `stage_configs.split_and_shard_subjects` key. Args: - stage_configs.split_and_shard_patients.n_patients_per_shard: The maximum number of patients to include - in any shard. Realized shards will not necessarily have this many patients, though they will never - exceed this number. 
Instead, the number of shards necessary to include all patients in a split - such that no shard exceeds this number will be calculated, then the patients will be evenly, + stage_configs.split_and_shard_subjects.n_subjects_per_shard: The maximum number of subjects to include + in any shard. Realized shards will not necessarily have this many subjects, though they will never + exceed this number. Instead, the number of shards necessary to include all subjects in a split + such that no shard exceeds this number will be calculated, then the subjects will be evenly, randomly split amongst those shards so that all shards within a split have approximately the same number of patietns. - stage_configs.split_and_shard_patients.external_splits_json_fp: The path to a json file containing any + stage_configs.split_and_shard_subjects.external_splits_json_fp: The path to a json file containing any pre-defined splits for specialty held-out test sets beyond the IID held out set that will be produced (e.g., for prospective datasets, etc.). - stage_configs.split_and_shard_patients.split_fracs: The fraction of patients to include in the IID + stage_configs.split_and_shard_subjects.split_fracs: The fraction of subjects to include in the IID training, tuning, and held-out sets. Split fractions can be changed for the default names by adding a hydra-syntax command line argument for the nested name; e.g., `split_fracs.train=0.7 split_fracs.tuning=0.1 split_fracs.held_out=0.2`. A split can be removed with the `~` override @@ -209,38 +209,38 @@ def main(cfg: DictConfig): raise FileNotFoundError(f"Event conversion config file not found: {event_conversion_cfg_fp}") logger.info( - f"Reading event conversion config from {event_conversion_cfg_fp} (needed for patient ID columns)" + f"Reading event conversion config from {event_conversion_cfg_fp} (needed for subject ID columns)" ) event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") dfs = [] - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", "subject_id") for input_prefix, event_cfgs in event_conversion_cfg.items(): - input_patient_id_column = event_cfgs.get("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.get("subject_id_col", default_subject_id_col) input_fps = list((subsharded_dir / input_prefix).glob("**/*.parquet")) input_fps_strs = "\n".join(f" - {str(fp.resolve())}" for fp in input_fps) - logger.info(f"Reading patient IDs from {input_prefix} files:\n{input_fps_strs}") + logger.info(f"Reading subject IDs from {input_prefix} files:\n{input_fps_strs}") for input_fp in input_fps: dfs.append( pl.scan_parquet(input_fp, glob=False) - .select(pl.col(input_patient_id_column).alias("patient_id")) + .select(pl.col(input_subject_id_column).alias("subject_id")) .unique() ) - logger.info(f"Joining all patient IDs from {len(dfs)} dataframes") - patient_ids = ( + logger.info(f"Joining all subject IDs from {len(dfs)} dataframes") + subject_ids = ( pl.concat(dfs) - .select(pl.col("patient_id").drop_nulls().drop_nans().unique()) - .collect(streaming=True)["patient_id"] + .select(pl.col("subject_id").drop_nulls().drop_nans().unique()) + .collect(streaming=True)["subject_id"] .to_numpy(use_pyarrow=True) ) - logger.info(f"Found {len(patient_ids)} unique patient IDs of type {patient_ids.dtype}") + logger.info(f"Found {len(subject_ids)} unique subject IDs of type 
{subject_ids.dtype}") if cfg.stage_cfg.external_splits_json_fp: external_splits_json_fp = Path(cfg.stage_cfg.external_splits_json_fp) @@ -255,21 +255,21 @@ def main(cfg: DictConfig): else: external_splits = None - logger.info("Sharding and splitting patients") + logger.info("Sharding and splitting subjects") - sharded_patients = shard_patients( - patients=patient_ids, + sharded_subjects = shard_subjects( + subjects=subject_ids, external_splits=external_splits, split_fracs_dict=cfg.stage_cfg.split_fracs, - n_patients_per_shard=cfg.stage_cfg.n_patients_per_shard, + n_subjects_per_shard=cfg.stage_cfg.n_subjects_per_shard, seed=cfg.seed, ) shards_map_fp = Path(cfg.shards_map_fp) - logger.info(f"Writing sharded patients to {str(shards_map_fp.resolve())}") + logger.info(f"Writing sharded subjects to {str(shards_map_fp.resolve())}") shards_map_fp.parent.mkdir(parents=True, exist_ok=True) - shards_map_fp.write_text(json.dumps(sharded_patients)) - logger.info("Done writing sharded patients") + shards_map_fp.write_text(json.dumps(sharded_subjects)) + logger.info("Done writing sharded subjects") if __name__ == "__main__": diff --git a/src/MEDS_transforms/filters/README.md b/src/MEDS_transforms/filters/README.md index 9d582f06..22baa4b7 100644 --- a/src/MEDS_transforms/filters/README.md +++ b/src/MEDS_transforms/filters/README.md @@ -1,5 +1,5 @@ # Filters -Filters remove wholesale events within the data, either at the patient or event level. For transformations +Filters remove wholesale events within the data, either at the subject or event level. For transformations that simply _occlude_ aspects of the data (e.g., by setting a code variable to `UNK`), see the `transforms` library section. diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index 36a69387..4c0db29b 100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -13,7 +13,7 @@ def filter_measurements_fntr( stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifiers: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: - """Returns a function that filters patient events to only encompass those with a set of permissible codes. + """Returns a function that filters subject events to only encompass those with a set of permissible codes. Args: df: The input DataFrame. @@ -26,44 +26,44 @@ def filter_measurements_fntr( >>> code_metadata_df = pl.DataFrame({ ... "code": ["A", "A", "B", "C"], ... "modifier1": [1, 2, 1, 2], - ... "code/n_patients": [2, 1, 3, 2], + ... "code/n_subjects": [2, 1, 3, 2], ... "code/n_occurrences": [4, 5, 3, 2], ... }) >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "A", "C"], ... "modifier1": [1, 1, 2, 2], ... 
}).lazy() - >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 2, "min_occurrences_per_code": 3}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (2, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ │ 1 ┆ A ┆ 1 │ │ 1 ┆ B ┆ 1 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": 1, "min_occurrences_per_code": 4}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 1, "min_occurrences_per_code": 4}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (2, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ │ 1 ┆ A ┆ 1 │ │ 2 ┆ A ┆ 2 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": 1}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 1}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (4, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ @@ -72,12 +72,12 @@ def filter_measurements_fntr( │ 2 ┆ A ┆ 2 │ │ 2 ┆ C ┆ 2 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": None, "min_occurrences_per_code": None}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": None, "min_occurrences_per_code": None}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (4, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ @@ -91,7 +91,7 @@ def filter_measurements_fntr( >>> fn(data).collect() shape: (1, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ @@ -99,12 +99,12 @@ def filter_measurements_fntr( └────────────┴──────┴───────────┘ """ - min_patients_per_code = stage_cfg.get("min_patients_per_code", None) + min_subjects_per_code = stage_cfg.get("min_subjects_per_code", None) min_occurrences_per_code = stage_cfg.get("min_occurrences_per_code", None) filter_exprs = [] - if min_patients_per_code is not None: - filter_exprs.append(pl.col("code/n_patients") >= min_patients_per_code) + if min_subjects_per_code is not None: + filter_exprs.append(pl.col("code/n_subjects") >= min_subjects_per_code) if min_occurrences_per_code is not None: filter_exprs.append(pl.col("code/n_occurrences") >= min_occurrences_per_code) @@ -118,10 +118,10 @@ def filter_measurements_fntr( allowed_code_metadata = (code_metadata.filter(pl.all_horizontal(filter_exprs)).select(join_cols)).lazy() def filter_measurements_fn(df: pl.LazyFrame) -> pl.LazyFrame: - f"""Filters patient events to only encompass those with a set of permissible codes. + f"""Filters subject events to only encompass those with a set of permissible codes. 
In particular, this function filters the DataFrame to only include (code, modifier) pairs that have - at least {min_patients_per_code} patients and {min_occurrences_per_code} occurrences. + at least {min_subjects_per_code} subjects and {min_occurrences_per_code} occurrences. """ idx_col = "_row_idx" diff --git a/src/MEDS_transforms/filters/filter_patients.py b/src/MEDS_transforms/filters/filter_subjects.py similarity index 67% rename from src/MEDS_transforms/filters/filter_patients.py rename to src/MEDS_transforms/filters/filter_subjects.py index 36dc3985..007168d6 100644 --- a/src/MEDS_transforms/filters/filter_patients.py +++ b/src/MEDS_transforms/filters/filter_subjects.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable from functools import partial @@ -12,25 +12,25 @@ from MEDS_transforms.mapreduce.mapper import map_over -def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_patient: int) -> pl.LazyFrame: - """Filters patients by the number of measurements they have. +def filter_subjects_by_num_measurements(df: pl.LazyFrame, min_measurements_per_subject: int) -> pl.LazyFrame: + """Filters subjects by the number of measurements they have. Args: df: The input DataFrame. - min_measurements_per_patient: The minimum number of measurements a patient must have to be included. + min_measurements_per_subject: The minimum number of measurements a subject must have to be included. Returns: The filtered DataFrame. Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 3], + ... "subject_id": [1, 1, 1, 2, 2, 3], ... "time": [1, 2, 1, 1, 2, 1], ... }) - >>> filter_patients_by_num_measurements(df, 1) + >>> filter_subjects_by_num_measurements(df, 1) shape: (6, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -41,10 +41,10 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 2 ┆ 2 │ │ 3 ┆ 1 │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 2) + >>> filter_subjects_by_num_measurements(df, 2) shape: (5, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -54,10 +54,10 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 2 ┆ 1 │ │ 2 ┆ 2 │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 3) + >>> filter_subjects_by_num_measurements(df, 3) shape: (3, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -65,47 +65,47 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 1 ┆ 2 │ │ 1 ┆ 1 │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 4) + >>> filter_subjects_by_num_measurements(df, 4) shape: (0, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 2.2) + >>> filter_subjects_by_num_measurements(df, 2.2) Traceback (most recent call last): ... 
- TypeError: min_measurements_per_patient must be an integer; got 2.2 + TypeError: min_measurements_per_subject must be an integer; got 2.2 """ - if not isinstance(min_measurements_per_patient, int): + if not isinstance(min_measurements_per_subject, int): raise TypeError( - f"min_measurements_per_patient must be an integer; got {type(min_measurements_per_patient)} " - f"{min_measurements_per_patient}" + f"min_measurements_per_subject must be an integer; got {type(min_measurements_per_subject)} " + f"{min_measurements_per_subject}" ) - return df.filter(pl.col("time").count().over("patient_id") >= min_measurements_per_patient) + return df.filter(pl.col("time").count().over("subject_id") >= min_measurements_per_subject) -def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) -> pl.LazyFrame: - """Filters patients by the number of events (unique timepoints) they have. +def filter_subjects_by_num_events(df: pl.LazyFrame, min_events_per_subject: int) -> pl.LazyFrame: + """Filters subjects by the number of events (unique timepoints) they have. Args: df: The input DataFrame. - min_events_per_patient: The minimum number of events a patient must have to be included. + min_events_per_subject: The minimum number of events a subject must have to be included. Returns: The filtered DataFrame. Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4], ... "time": [1, 1, 1, 1, 2, 1, 1, 2, 3, None, None, 1, 2, 3], ... }) - >>> filter_patients_by_num_events(df, 1) + >>> filter_subjects_by_num_events(df, 1) shape: (14, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -124,10 +124,10 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 2) + >>> filter_subjects_by_num_events(df, 2) shape: (11, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -143,10 +143,10 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 3) + >>> filter_subjects_by_num_events(df, 3) shape: (8, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -159,10 +159,10 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 4) + >>> filter_subjects_by_num_events(df, 4) shape: (5, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -172,48 +172,48 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 5) + >>> filter_subjects_by_num_events(df, 5) shape: (0, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 2.2) + >>> filter_subjects_by_num_events(df, 2.2) Traceback (most recent call last): ... 
- TypeError: min_events_per_patient must be an integer; got 2.2 + TypeError: min_events_per_subject must be an integer; got 2.2 """ - if not isinstance(min_events_per_patient, int): + if not isinstance(min_events_per_subject, int): raise TypeError( - f"min_events_per_patient must be an integer; got {type(min_events_per_patient)} " - f"{min_events_per_patient}" + f"min_events_per_subject must be an integer; got {type(min_events_per_subject)} " + f"{min_events_per_subject}" ) - return df.filter(pl.col("time").n_unique().over("patient_id") >= min_events_per_patient) + return df.filter(pl.col("time").n_unique().over("subject_id") >= min_events_per_subject) -def filter_patients_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: +def filter_subjects_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: compute_fns = [] - if stage_cfg.min_measurements_per_patient: + if stage_cfg.min_measurements_per_subject: logger.info( - f"Filtering patients with fewer than {stage_cfg.min_measurements_per_patient} measurements " + f"Filtering subjects with fewer than {stage_cfg.min_measurements_per_subject} measurements " "(observations of any kind)." ) compute_fns.append( partial( - filter_patients_by_num_measurements, - min_measurements_per_patient=stage_cfg.min_measurements_per_patient, + filter_subjects_by_num_measurements, + min_measurements_per_subject=stage_cfg.min_measurements_per_subject, ) ) - if stage_cfg.min_events_per_patient: + if stage_cfg.min_events_per_subject: logger.info( - f"Filtering patients with fewer than {stage_cfg.min_events_per_patient} events " + f"Filtering subjects with fewer than {stage_cfg.min_events_per_subject} events " "(unique timepoints)." ) compute_fns.append( - partial(filter_patients_by_num_events, min_events_per_patient=stage_cfg.min_events_per_patient) + partial(filter_subjects_by_num_events, min_events_per_subject=stage_cfg.min_events_per_subject) ) def fn(data: pl.LazyFrame) -> pl.LazyFrame: @@ -230,7 +230,7 @@ def fn(data: pl.LazyFrame) -> pl.LazyFrame: def main(cfg: DictConfig): """TODO.""" - map_over(cfg, compute_fn=filter_patients_fntr) + map_over(cfg, compute_fn=filter_subjects_fntr) if __name__ == "__main__": diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 76490826..9888d9da 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -309,7 +309,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO compute function with the ``local_arg_1=baz`` parameter. Both of these local compute functions will be applied to the input DataFrame in sequence, and the resulting DataFrames will be concatenated alongside any of the dataframe that matches no matcher (which will be left unmodified) and merged in a sorted way - that respects the ``patient_id``, ``time`` ordering first, then the order of the match & revise blocks + that respects the ``subject_id``, ``time`` ordering first, then the order of the match & revise blocks themselves, then the order of the rows in each match & revise block output. Each local compute function will also use the ``global_arg_1=foo`` parameter. @@ -331,7 +331,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 2], + ... "subject_id": [1, 1, 1, 2, 2, 2], ... "time": [1, 2, 2, 1, 1, 2], ... "initial_idx": [0, 1, 2, 3, 4, 5], ... 
"code": ["FINAL", "CODE//TEMP_2", "CODE//TEMP_1", "FINAL", "CODE//TEMP_2", "CODE//TEMP_1"] @@ -353,7 +353,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO >>> match_revise_fn(df.lazy()).collect() shape: (6, 4) ┌────────────┬──────┬─────────────┬────────────────┐ - │ patient_id ┆ time ┆ initial_idx ┆ code │ + │ subject_id ┆ time ┆ initial_idx ┆ code │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪═════════════╪════════════════╡ @@ -376,7 +376,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO >>> match_revise_fn(df.lazy()).collect() shape: (6, 4) ┌────────────┬──────┬─────────────┬─────────────────┐ - │ patient_id ┆ time ┆ initial_idx ┆ code │ + │ subject_id ┆ time ┆ initial_idx ┆ code │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪═════════════╪═════════════════╡ @@ -397,7 +397,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO ... ValueError: Missing needed columns {'missing'} for local matcher 0: [(col("missing")) == (String(CODE//TEMP_2))].all_horizontal() - Columns available: 'code', 'initial_idx', 'patient_id', 'time' + Columns available: 'code', 'initial_idx', 'subject_id', 'time' >>> stage_cfg = DictConfig({"global_code_end": "foo"}) >>> cfg = DictConfig({"stage_cfg": stage_cfg}) >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) @@ -439,7 +439,7 @@ def match_revise_fn(df: DF_T) -> DF_T: revision_parts.append(local_compute_fn(matched_df)) revision_parts.append(unmatched_df) - return pl.concat(revision_parts, how="vertical").sort(["patient_id", "time"], maintain_order=True) + return pl.concat(revision_parts, how="vertical").sort(["subject_id", "time"], maintain_order=True) return match_revise_fn @@ -580,7 +580,7 @@ def map_over( start = datetime.now() train_only = cfg.stage_cfg.get("train_only", False) - split_fp = Path(cfg.input_dir) / "metadata" / "patient_split.parquet" + split_fp = Path(cfg.input_dir) / "metadata" / "subject_split.parquet" shards, includes_only_train = shard_iterator_fntr(cfg) @@ -591,18 +591,18 @@ def map_over( ) elif split_fp.exists(): logger.info(f"Processing train split only by filtering read dfs via {str(split_fp.resolve())}") - train_patients = ( + train_subjects = ( pl.scan_parquet(split_fp) .filter(pl.col("split") == "train") - .select(pl.col("patient_id")) + .select(pl.col("subject_id")) .collect() .to_list() ) - read_fn = read_and_filter_fntr(train_patients, read_fn) + read_fn = read_and_filter_fntr(train_subjects, read_fn) else: raise FileNotFoundError( f"Train split requested, but shard prefixes can't be used and " - f"patient split file not found at {str(split_fp.resolve())}." + f"subject split file not found at {str(split_fp.resolve())}." ) elif includes_only_train: raise ValueError("All splits should be used, but shard iterator is returning only train splits?!?") diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 4aa4c535..2832c000 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -315,14 +315,14 @@ def shard_iterator( >>> from tempfile import TemporaryDirectory >>> import polars as pl >>> df = pl.DataFrame({ - ... "patient_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], + ... "subject_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], ... "code": ["A", "B", "C", "D", "E", "F", "G", "H", "I"], ... "time": [1, 2, 3, 4, 5, 6, 1, 2, 3], ... 
}) >>> shards = {"train/0": [1, 2, 3, 4], "train/1": [5, 6, 7], "tuning": [8], "held_out": [9]} >>> def write_dfs(input_dir: Path, df: pl.DataFrame=df, shards: dict=shards, sfx: str=".parquet"): - ... for shard_name, patient_ids in shards.items(): - ... df = df.filter(pl.col("patient_id").is_in(patient_ids)) + ... for shard_name, subject_ids in shards.items(): + ... df = df.filter(pl.col("subject_id").is_in(subject_ids)) ... shard_fp = input_dir / f"{shard_name}{sfx}" ... shard_fp.parent.mkdir(exist_ok=True, parents=True) ... if sfx == ".parquet": df.write_parquet(shard_fp) @@ -485,7 +485,7 @@ def shard_iterator( elif train_only: logger.info( f"train_only={train_only} requested but no dedicated train shards found; processing all shards " - "and relying on `patient_splits.parquet` for filtering." + "and relying on `subject_splits.parquet` for filtering." ) shards = shuffle_shards(shards, cfg) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index f9361965..fb3358c1 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -13,7 +13,7 @@ from omegaconf import DictConfig from MEDS_transforms import PREPROCESS_CONFIG_YAML -from MEDS_transforms.extract.split_and_shard_patients import shard_patients +from MEDS_transforms.extract.split_and_shard_subjects import shard_subjects from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator, shuffle_shards from MEDS_transforms.utils import stage_init, write_lazyframe @@ -34,9 +34,9 @@ def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) for pt_id, sp in df.iter_rows(): splits_map[sp].append(pt_id) - return shard_patients( - patients=df["patient_id"].to_numpy(), - n_patients_per_shard=stage_cfg.n_patients_per_shard, + return shard_subjects( + subjects=df["subject_id"].to_numpy(), + n_subjects_per_shard=stage_cfg.n_subjects_per_shard, external_splits=splits_map, split_fracs_dict=None, seed=cfg.get("seed", 1), @@ -51,13 +51,13 @@ def write_json(d: dict, fp: Path) -> None: version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem ) def main(cfg: DictConfig): - """Re-shard a MEDS cohort to in a manner that subdivides patient splits.""" + """Re-shard a MEDS cohort to in a manner that subdivides subject splits.""" stage_init(cfg) output_dir = Path(cfg.stage_cfg.output_dir) - splits_file = Path(cfg.input_dir) / "metadata" / "patient_splits.parquet" + splits_file = Path(cfg.input_dir) / "metadata" / "subject_splits.parquet" shards_fp = output_dir / ".shards.json" rwlock_wrap( @@ -92,22 +92,22 @@ def main(cfg: DictConfig): logger.info("Starting sub-sharding") for subshard_name, out_fp in new_shards_iter: - patients = new_sharded_splits[subshard_name] + subjects = new_sharded_splits[subshard_name] def read_fn(input_dir: Path) -> pl.LazyFrame: df = None logger.info(f"Reading shards for {subshard_name} (file names are in the input sharding scheme):") for in_fp, _ in orig_shards_iter: logger.info(f" - {str(in_fp.relative_to(input_dir).resolve())}") - new_df = pl.scan_parquet(in_fp, glob=False).filter(pl.col("patient_id").is_in(patients)) + new_df = pl.scan_parquet(in_fp, glob=False).filter(pl.col("subject_id").is_in(subjects)) if df is None: df = new_df else: - df = df.merge_sorted(new_df, key="patient_id") + df = df.merge_sorted(new_df, key="subject_id") return df def compute_fn(df: list[pl.DataFrame]) -> pl.LazyFrame: - return df.sort(by=["patient_id", "time"], maintain_order=True, 
multithreaded=False) + return df.sort(by=["subject_id", "time"], maintain_order=True, multithreaded=False) def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: write_lazyframe(df, out_fp) diff --git a/src/MEDS_transforms/transforms/add_time_derived_measurements.py b/src/MEDS_transforms/transforms/add_time_derived_measurements.py index 19a7abf0..3f48a8cd 100644 --- a/src/MEDS_transforms/transforms/add_time_derived_measurements.py +++ b/src/MEDS_transforms/transforms/add_time_derived_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""Transformations for adding time-derived measurements (e.g., a patient's age) to a MEDS dataset.""" +"""Transformations for adding time-derived measurements (e.g., a subject's age) to a MEDS dataset.""" from collections.abc import Callable import hydra @@ -25,7 +25,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... datetime(1990, 1, 1), @@ -38,12 +38,12 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... ) >>> df shape: (8, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -62,7 +62,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> age_fn(df) shape: (2, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -74,7 +74,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> add_age_fn(df) shape: (10, 4) ┌────────────┬─────────────────────┬────────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪════════╪═══════════════╡ @@ -96,7 +96,7 @@ def out_fn(df: pl.DataFrame) -> pl.DataFrame: df = df.with_row_index("__idx") new_events = new_events.with_columns(pl.lit(0, dtype=df.schema["__idx"]).alias("__idx")) return ( - pl.concat([df, new_events], how="diagonal").sort(by=["patient_id", "time", "__idx"]).drop("__idx") + pl.concat([df, new_events], how="diagonal").sort(by=["subject_id", "time", "__idx"]).drop("__idx") ) return out_fn @@ -170,7 +170,7 @@ def normalize_time_unit(unit: str) -> tuple[str, float]: def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: - """Create a function that adds a patient's age to a DataFrame. + """Create a function that adds a subject's age to a DataFrame. Args: cfg: The configuration for the age function. This must contain the following mandatory keys: @@ -179,8 +179,8 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: - "age_unit": The unit for the age event when converted to a numeric value in the output data. 
Returns: - A function that returns the to-be-added "age" events with the patient's age for all input events with - unique, non-null times in the data, for all patients who have an observed date of birth. It does + A function that returns the to-be-added "age" events with the subject's age for all input events with + unique, non-null times in the data, for all subjects who have an observed date of birth. It does not add an event for times that are equal to the date of birth. Raises: @@ -190,7 +190,7 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... datetime(1990, 1, 1), @@ -204,12 +204,12 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "rx", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... ) >>> df shape: (9, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -228,7 +228,7 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> age_fn(df) shape: (3, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -248,15 +248,15 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: microseconds_in_unit = int(1e6) * seconds_in_unit def fn(df: pl.LazyFrame) -> pl.LazyFrame: - dob_expr = pl.when(pl.col("code") == cfg.DOB_code).then(pl.col("time")).min().over("patient_id") + dob_expr = pl.when(pl.col("code") == cfg.DOB_code).then(pl.col("time")).min().over("subject_id") age_expr = (pl.col("time") - dob_expr).dt.total_microseconds() / microseconds_in_unit age_expr = age_expr.cast(pl.Float32, strict=False) return ( df.drop_nulls(subset=["time"]) - .unique(subset=["patient_id", "time"], maintain_order=True) + .unique(subset=["subject_id", "time"], maintain_order=True) .select( - "patient_id", + "subject_id", "time", pl.lit(cfg.age_code, dtype=df.schema["code"]).alias("code"), age_expr.alias("numeric_value"), @@ -283,7 +283,7 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... datetime(1990, 1, 1, 1, 0), @@ -296,12 +296,12 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... 
) >>> df shape: (8, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -319,7 +319,7 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> time_of_day_fn(df) shape: (6, 3) ┌────────────┬─────────────────────┬──────────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪══════════════════════╡ @@ -357,8 +357,8 @@ def tod_code(start: int, end: int) -> str: time_of_day = time_of_day.when(hour >= end).then(tod_code(end, 24)) return ( df.drop_nulls(subset=["time"]) - .unique(subset=["patient_id", "time"], maintain_order=True) - .select("patient_id", "time", time_of_day.alias("code")) + .unique(subset=["subject_id", "time"], maintain_order=True) + .select("subject_id", "time", time_of_day.alias("code")) ) return fn diff --git a/src/MEDS_transforms/transforms/normalization.py b/src/MEDS_transforms/transforms/normalization.py index fbad9acc..363192aa 100644 --- a/src/MEDS_transforms/transforms/normalization.py +++ b/src/MEDS_transforms/transforms/normalization.py @@ -16,7 +16,7 @@ def normalize( """Normalize a MEDS dataset across both categorical and continuous dimensions. This function expects a MEDS dataset in flattened form, with columns for: - - `patient_id` + - `subject_id` - `time` - `code` - `numeric_value` @@ -61,7 +61,7 @@ def normalize( >>> from datetime import datetime >>> MEDS_df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3], ... "time": [ ... datetime(2021, 1, 1), ... datetime(2021, 1, 1), @@ -76,7 +76,7 @@ def normalize( ... "unit": ["mg/dL", "g/dL", None, "mg/dL", None, None, None], ... }, ... schema = { - ... "patient_id": pl.UInt32, + ... "subject_id": pl.UInt32, ... "time": pl.Datetime, ... "code": pl.Utf8, ... "numeric_value": pl.Float64, @@ -100,7 +100,7 @@ def normalize( >>> normalize(MEDS_df.lazy(), code_metadata).collect() shape: (6, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ u32 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -113,7 +113,7 @@ def normalize( └────────────┴─────────────────────┴──────┴───────────────┘ >>> MEDS_df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3], ... "time": [ ... datetime(2021, 1, 1), ... datetime(2021, 1, 1), @@ -128,7 +128,7 @@ def normalize( ... "unit": ["mg/dL", "g/dL", None, "mg/dL", None, None, None], ... }, ... schema = { - ... "patient_id": pl.UInt32, + ... "subject_id": pl.UInt32, ... "time": pl.Datetime, ... "code": pl.Utf8, ... 
"numeric_value": pl.Float64, @@ -154,7 +154,7 @@ def normalize( >>> normalize(MEDS_df.lazy(), code_metadata, ["unit"]).collect() shape: (6, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ u32 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -201,7 +201,7 @@ def normalize( ) .select( idx_col, - "patient_id", + "subject_id", "time", pl.col("code/vocab_index").alias("code"), ((pl.col("numeric_value") - pl.col("values/mean")) / pl.col("values/std")).alias("numeric_value"), diff --git a/src/MEDS_transforms/transforms/occlude_outliers.py b/src/MEDS_transforms/transforms/occlude_outliers.py index 107407de..528977ba 100644 --- a/src/MEDS_transforms/transforms/occlude_outliers.py +++ b/src/MEDS_transforms/transforms/occlude_outliers.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -13,7 +13,7 @@ def occlude_outliers_fntr( stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifiers: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: - """Filters patient events to only encompass those with a set of permissible codes. + """Filters subject events to only encompass those with a set of permissible codes. Args: df: The input DataFrame. @@ -33,7 +33,7 @@ def occlude_outliers_fntr( ... # for clarity: --- stddev = [3.0, 0.0, 3.0, 1.0] ... }) >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "A", "C"], ... "modifier1": [1, 1, 2, 2], ... # for clarity: mean [0.0, 4.0, 4.0, 1.0] @@ -45,7 +45,7 @@ def occlude_outliers_fntr( >>> fn(data).collect() shape: (4, 5) ┌────────────┬──────┬───────────┬───────────────┬─────────────────────────┐ - │ patient_id ┆ code ┆ modifier1 ┆ numeric_value ┆ numeric_value/is_inlier │ + │ subject_id ┆ code ┆ modifier1 ┆ numeric_value ┆ numeric_value/is_inlier │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 ┆ f64 ┆ bool │ ╞════════════╪══════╪═══════════╪═══════════════╪═════════════════════════╡ @@ -78,7 +78,7 @@ def occlude_outliers_fntr( code_metadata = code_metadata.lazy().select(cols_to_select) def occlude_outliers_fn(df: pl.LazyFrame) -> pl.LazyFrame: - f"""Filters out outlier numeric values from patient events. + f"""Filters out outlier numeric values from subject events. In particular, this function filters the DataFrame to only include numeric values that are within {stddev_cutoff} standard deviations of the mean for the corresponding (code, modifier) pair. 
diff --git a/src/MEDS_transforms/transforms/reorder_measurements.py b/src/MEDS_transforms/transforms/reorder_measurements.py index 1205f771..5218551e 100644 --- a/src/MEDS_transforms/transforms/reorder_measurements.py +++ b/src/MEDS_transforms/transforms/reorder_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -19,7 +19,7 @@ def reorder_by_code_fntr( Args: stage_cfg: The stage-specific configuration object which contains the `ordered_code_patterns` field - that defines the order of the codes within each patient event (unique timepoint). Each element of + that defines the order of the codes within each subject event (unique timepoint). Each element of this list should be a regex pattern that matches codes that should be re-ordered at the index of the regex pattern in the list. Codes are matched in the order of the list, and if a code matches multiple regex patterns, it will be ordered by the first regex pattern that matches it. @@ -34,7 +34,7 @@ def reorder_by_code_fntr( Examples: >>> code_metadata_df = pl.DataFrame({"code": ["A", "A", "B", "C"], "modifier1": [1, 2, 1, 2]}) >>> data = pl.DataFrame({ - ... "patient_id":[1, 1, 2, 2], "time": [1, 1, 1, 1], + ... "subject_id":[1, 1, 2, 2], "time": [1, 1, 1, 1], ... "code": ["A", "B", "A", "C"], "modifier1": [1, 2, 1, 2] ... }) >>> stage_cfg = DictConfig({"ordered_code_patterns": ["B", "A"]}) @@ -42,7 +42,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (4, 4) ┌────────────┬──────┬──────┬───────────┐ - │ patient_id ┆ time ┆ code ┆ modifier1 │ + │ subject_id ┆ time ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪══════╪═══════════╡ @@ -55,7 +55,7 @@ def reorder_by_code_fntr( ... "code": ["LAB//foo", "ADMISSION//bar", "LAB//baz", "ADMISSION//qux", "DISCHARGE"], ... }) >>> data = pl.DataFrame({ - ... "patient_id":[1, 1, 1, 2, 2, 2], + ... "subject_id":[1, 1, 1, 2, 2, 2], ... "time": [1, 1, 1, 1, 2, 3], ... "code": ["LAB//foo", "ADMISSION//bar", "LAB//baz", "ADMISSION//qux", "DISCHARGE", "LAB//baz"], ... }) @@ -66,7 +66,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (6, 3) ┌────────────┬──────┬────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪════════════════╡ @@ -81,7 +81,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (6, 3) ┌────────────┬──────┬────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪════════════════╡ @@ -141,7 +141,7 @@ def reorder_fn(df: pl.LazyFrame) -> pl.LazyFrame: return ( df.join(code_indices, on=join_cols, how="left", coalesce=True) - .sort("patient_id", "time", "code_order_idx", maintain_order=True) + .sort("subject_id", "time", "code_order_idx", maintain_order=True) .drop("code_order_idx") ) @@ -152,13 +152,13 @@ def reorder_fn(df: pl.LazyFrame) -> pl.LazyFrame: version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem ) def main(cfg: DictConfig): - """Reorders measurements within each patient event (unique timepoint) by the specified code order. + """Reorders measurements within each subject event (unique timepoint) by the specified code order. 
In particular, given a set of [regex crate compatible](https://docs.rs/regex/latest/regex/) regexes in the `stage_cfg.ordered_code_patterns` list, this script will re-order the measurements within each event (unique timepoint) such that the measurements are sorted by the index of the first regex that matches their code in the `ordered_code_patterns` list. So, if the `ordered_code_patterns` list is - `["foo$", "bar", "foo.*"]`, and a single patient event has measurements with codes + `["foo$", "bar", "foo.*"]`, and a single subject event has measurements with codes `["foobar", "barbaz", "foo", "quat"]`, the measurements will be re-ordered to the order: `["foo", "foobar", "barbaz", "quat"]`, because: - "foo" matches the first regex in the list (the `foo$` matches any string with "foo" at the end). @@ -176,7 +176,7 @@ def main(cfg: DictConfig): Args: stage_configs.reorder_measurements.ordered_code_patterns: A list of regex patterns that specify the - order of the codes within each patient event (unique timepoint). To specify this on the command + order of the codes within each subject event (unique timepoint). To specify this on the command line, use the hydra list syntax by enclosing the entire key-value string argument in single quotes: ``'stage_configs.reorder_measurements.ordered_code_patterns=["foo$", "bar", "foo.*"]'``. """ diff --git a/src/MEDS_transforms/transforms/tensorization.py b/src/MEDS_transforms/transforms/tensorization.py index 0266ce21..92fa7e00 100644 --- a/src/MEDS_transforms/transforms/tensorization.py +++ b/src/MEDS_transforms/transforms/tensorization.py @@ -32,7 +32,7 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 2], + ... "subject_id": [1, 2], ... "time_delta_days": [[float("nan"), 12.0], [float("nan")]], ... "code": [[[101.0, 102.0], [103.0]], [[201.0, 202.0]]], ... "numeric_value": [[[2.0, 3.0], [4.0]], [[6.0, 7.0]]] @@ -40,7 +40,7 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: >>> df shape: (2, 4) ┌────────────┬─────────────────┬───────────────────────────┬─────────────────────┐ - │ patient_id ┆ time_delta_days ┆ code ┆ numeric_value │ + │ subject_id ┆ time_delta_days ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[f64] ┆ list[list[f64]] ┆ list[list[f64]] │ ╞════════════╪═════════════════╪═══════════════════════════╪═════════════════════╡ diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index d6f5003f..31cb2643 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -6,7 +6,7 @@ All these functions take in _normalized_ data -- meaning data where there are _no longer_ any code modifiers, as those have been normalized alongside codes into integer indices (in the output code column). The only -columns of concern here thus are `patient_id`, `time`, `code`, `numeric_value`. +columns of concern here thus are `subject_id`, `time`, `code`, `numeric_value`. """ from pathlib import Path @@ -69,7 +69,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "time": [None, datetime(2021, 1, 1), None, datetime(2021, 1, 2)], ... "code": [100, 101, 200, 201], ... 
"numeric_value": [1.0, 2.0, 3.0, 4.0] @@ -78,7 +78,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra >>> static.collect() shape: (2, 3) ┌────────────┬──────┬───────────────┐ - │ patient_id ┆ code ┆ numeric_value │ + │ subject_id ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ f64 │ ╞════════════╪══════╪═══════════════╡ @@ -88,7 +88,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra >>> dynamic.collect() shape: (2, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ datetime[μs] ┆ i64 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -103,19 +103,19 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: - """This function extracts static data and schema information (sequence of patient unique times). + """This function extracts static data and schema information (sequence of subject unique times). Args: df: The input data. Returns: - A `pl.LazyFrame` object containing the static data and the unique times of the patient, grouped - by patient as lists, in the same order as the patient IDs occurred in the original file. + A `pl.LazyFrame` object containing the static data and the unique times of the subject, grouped + by subject as lists, in the same order as the subject IDs occurred in the original file. Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 1, 2, 2, 2], + ... "subject_id": [1, 1, 1, 1, 2, 2, 2], ... "time": [ ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), ... None, datetime(2021, 1, 2), datetime(2021, 1, 2)], @@ -126,17 +126,17 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: >>> df.drop("time") shape: (2, 4) ┌────────────┬───────────┬───────────────┬─────────────────────┐ - │ patient_id ┆ code ┆ numeric_value ┆ start_time │ + │ subject_id ┆ code ┆ numeric_value ┆ start_time │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[i64] ┆ list[f64] ┆ datetime[μs] │ ╞════════════╪═══════════╪═══════════════╪═════════════════════╡ │ 1 ┆ [100] ┆ [1.0] ┆ 2021-01-01 00:00:00 │ │ 2 ┆ [200] ┆ [5.0] ┆ 2021-01-02 00:00:00 │ └────────────┴───────────┴───────────────┴─────────────────────┘ - >>> df.select("patient_id", "time").explode("time") + >>> df.select("subject_id", "time").explode("time") shape: (3, 2) ┌────────────┬─────────────────────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ datetime[μs] │ ╞════════════╪═════════════════════╡ @@ -148,21 +148,21 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: static, dynamic = split_static_and_dynamic(df) - # This collects static data by patient ID and stores only (as a list) the codes and numeric values. - static_by_patient = static.group_by("patient_id", maintain_order=True).agg("code", "numeric_value") + # This collects static data by subject ID and stores only (as a list) the codes and numeric values. + static_by_subject = static.group_by("subject_id", maintain_order=True).agg("code", "numeric_value") - # This collects the unique times for each patient. - schema_by_patient = dynamic.group_by("patient_id", maintain_order=True).agg( + # This collects the unique times for each subject. 
+ schema_by_subject = dynamic.group_by("subject_id", maintain_order=True).agg( pl.col("time").min().alias("start_time"), pl.col("time").unique(maintain_order=True) ) - # TODO(mmd): Consider tracking patient offset explicitly here. + # TODO(mmd): Consider tracking subject offset explicitly here. - return static_by_patient.join(schema_by_patient, on="patient_id", how="inner") + return static_by_subject.join(schema_by_subject, on="subject_id", how="inner") -def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: - """This function extracts sequences of patient events, which are sequences of measurements. +def extract_seq_of_subject_events(df: pl.LazyFrame) -> pl.LazyFrame: + """This function extracts sequences of subject events, which are sequences of measurements. The result of this can be naturally tensorized into a `JointNestedRaggedTensorDict` object. @@ -170,8 +170,8 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: df: The input data. Returns: - A `pl.LazyFrame` object containing the sequences of patient events, with the following columns: - - `patient_id`: The patient ID. + A `pl.LazyFrame` object containing the sequences of subject events, with the following columns: + - `subject_id`: The subject ID. - `time_delta_days`: The time delta in days, as a list of floats (ragged). - `code`: The code, as a list of lists of ints (ragged in both levels). - `numeric_value`: The numeric value as a list of lists of floats (ragged in both levels). @@ -179,17 +179,17 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 1, 2, 2, 2], + ... "subject_id": [1, 1, 1, 1, 2, 2, 2], ... "time": [ ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), ... None, datetime(2021, 1, 2), datetime(2021, 1, 2)], ... "code": [100, 101, 102, 103, 200, 201, 202], ... "numeric_value": pl.Series([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=pl.Float32) ... 
}).lazy() - >>> extract_seq_of_patient_events(df).collect() + >>> extract_seq_of_subject_events(df).collect() shape: (2, 4) ┌────────────┬─────────────────┬─────────────────────┬─────────────────────┐ - │ patient_id ┆ time_delta_days ┆ code ┆ numeric_value │ + │ subject_id ┆ time_delta_days ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[f32] ┆ list[list[i64]] ┆ list[list[f32]] │ ╞════════════╪═════════════════╪═════════════════════╪═════════════════════╡ @@ -203,9 +203,9 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: time_delta_days_expr = (pl.col("time").diff().dt.total_seconds() / SECONDS_PER_DAY).cast(pl.Float32) return ( - dynamic.group_by("patient_id", "time", maintain_order=True) + dynamic.group_by("subject_id", "time", maintain_order=True) .agg(pl.col("code").name.keep(), fill_to_nans("numeric_value").name.keep()) - .group_by("patient_id", maintain_order=True) + .group_by("subject_id", maintain_order=True) .agg( fill_to_nans(time_delta_days_expr).alias("time_delta_days"), "code", @@ -258,7 +258,7 @@ def main(cfg: DictConfig): event_seq_out_fp, pl.scan_parquet, write_lazyframe, - extract_seq_of_patient_events, + extract_seq_of_subject_events, do_overwrite=cfg.do_overwrite, ) diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index 59a7cd69..b62f7d12 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -412,15 +412,15 @@ def is_col_field(field: str | None) -> bool: bool: True if the field is formatted as "col(column_name)", False otherwise. Examples: - >>> is_col_field("col(patient_id)") + >>> is_col_field("col(subject_id)") True - >>> is_col_field("col(patient_id") + >>> is_col_field("col(subject_id") False - >>> is_col_field("patient_id)") + >>> is_col_field("subject_id)") False - >>> is_col_field("column(patient_id)") + >>> is_col_field("column(subject_id)") False - >>> is_col_field("patient_id") + >>> is_col_field("subject_id") False >>> is_col_field(None) False @@ -440,16 +440,16 @@ def parse_col_field(field: str) -> str: ValueError: If the input string does not match the expected format. Examples: - >>> parse_col_field("col(patient_id)") - 'patient_id' - >>> parse_col_field("col(patient_id") + >>> parse_col_field("col(subject_id)") + 'subject_id' + >>> parse_col_field("col(subject_id") Traceback (most recent call last): ... - ValueError: Invalid column field: col(patient_id - >>> parse_col_field("column(patient_id)") + ValueError: Invalid column field: col(subject_id + >>> parse_col_field("column(subject_id)") Traceback (most recent call last): ... - ValueError: Invalid column field: column(patient_id) + ValueError: Invalid column field: column(subject_id) """ if not is_col_field(field): raise ValueError(f"Invalid column field: {field}") diff --git a/tests/test_add_time_derived_measurements.py b/tests/test_add_time_derived_measurements.py index e5653a18..964cad9c 100644 --- a/tests/test_add_time_derived_measurements.py +++ b/tests/test_add_time_derived_measurements.py @@ -4,6 +4,7 @@ scripts. 
""" +from meds import subject_id_field from .transform_tester_base import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs @@ -96,8 +97,8 @@ ``` """ -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", @@ -156,9 +157,9 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -# All patients in this shard had only 4 events. -WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00","TIME_OF_DAY//[00,06)", @@ -185,8 +186,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00","TIME_OF_DAY//[00,06)", @@ -201,8 +202,8 @@ 754281,"01/03/2010, 08:22:13",DISCHARGE, """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", diff --git a/tests/test_aggregate_code_metadata.py b/tests/test_aggregate_code_metadata.py index 21698cb6..7d3d2a4d 100644 --- a/tests/test_aggregate_code_metadata.py +++ b/tests/test_aggregate_code_metadata.py @@ -13,7 +13,7 @@ ) WANT_OUTPUT_CODE_METADATA_FILE = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/n_patients,values/sum,values/sum_sqd,values/n_ints,values/min,values/max,description,parent_codes +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/n_subjects,values/sum,values/sum_sqd,values/n_ints,values/min,values/max,description,parent_codes ,44,4,28,4,3198.8389005974336,382968.28937288234,6,86.0,175.271118,, ADMISSION//CARDIAC,2,2,0,0,0,0,0,,,, ADMISSION//ORTHOPEDIC,1,1,0,0,0,0,0,,,, @@ -45,9 +45,9 @@ "TEMP", ], "code/n_occurrences": [44, 2, 1, 1, 4, 4, 1, 1, 2, 4, 12, 12], - "code/n_patients": [4, 2, 1, 1, 4, 4, 1, 1, 2, 4, 4, 4], + "code/n_subjects": [4, 2, 1, 1, 4, 4, 1, 1, 2, 4, 4, 4], "values/n_occurrences": [28, 0, 0, 0, 0, 0, 0, 0, 0, 4, 12, 12], - "values/n_patients": [4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4], + "values/n_subjects": [4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4], "values/sum": [ 3198.8389005974336, 0, @@ -163,9 +163,9 @@ AGGREGATIONS = [ "code/n_occurrences", - "code/n_patients", + "code/n_subjects", "values/n_occurrences", - "values/n_patients", + "values/n_subjects", "values/sum", "values/sum_sqd", "values/n_ints", diff --git a/tests/test_extract.py b/tests/test_extract.py index d8a3c3eb..787d5d83 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -15,7 +15,7 @@ if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_patients.py" + SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" @@ -23,7 +23,7 @@ FINALIZE_METADATA_SCRIPT = 
extraction_root / "finalize_MEDS_metadata.py" else: SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_patients" + SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" @@ -53,7 +53,7 @@ """ ADMIT_VITALS_CSV = """ -patient_id,admit_date,disch_date,department,vitals_date,HR,temp +subject_id,admit_date,disch_date,department,vitals_date,HR,temp 239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 @@ -88,7 +88,7 @@ EVENT_CFGS_YAML = """ subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -144,9 +144,9 @@ "held_out/0": [1500733], } -PATIENT_SPLITS_DF = pl.DataFrame( +SUBJECT_SPLITS_DF = pl.DataFrame( { - "patient_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], "split": ["train", "train", "train", "train", "tuning", "held_out"], } ) @@ -156,17 +156,17 @@ def get_expected_output(df: str) -> pl.DataFrame: return ( pl.read_csv(source=StringIO(df)) .select( - "patient_id", + "subject_id", pl.col("time").str.strptime(pl.Datetime, "%m/%d/%Y, %H:%M:%S").alias("time"), pl.col("code"), "numeric_value", ) - .sort(by=["patient_id", "time"]) + .sort(by=["subject_id", "time"]) ) MEDS_OUTPUT_TRAIN_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -176,7 +176,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, 239684,"05/11/2010, 17:41:51",HR,102.6 239684,"05/11/2010, 17:41:51",TEMP,96.0 @@ -204,7 +204,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -214,7 +214,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, 68729,"05/26/2010, 02:30:56",HR,86.0 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -226,14 +226,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TUNING_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, """ MEDS_OUTPUT_TUNING_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, 754281,"01/03/2010, 06:27:59",HR,142.0 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -241,14 +241,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_HELD_OUT_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 
1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, """ MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, 1500733,"06/03/2010, 14:54:38",HR,91.4 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -338,19 +338,19 @@ def test_extraction(): # Run the extraction script # 1. Sub-shard the data (this will be a null operation in this case, but it is worth doing just in # case. - # 2. Collect the patient splits. - # 3. Extract the events and sub-shard by patient. + # 2. Collect the subject splits. + # 3. Extract the events and sub-shard by subject. # 4. Merge to the final output. extraction_config_kwargs = { "input_dir": str(raw_cohort_dir.resolve()), "cohort_dir": str(MEDS_cohort_dir.resolve()), "event_conversion_config_fp": str(event_cfgs_yaml.resolve()), - "stage_configs.split_and_shard_patients.split_fracs.train": 4 / 6, - "stage_configs.split_and_shard_patients.split_fracs.tuning": 1 / 6, - "stage_configs.split_and_shard_patients.split_fracs.held_out": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.train": 4 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.tuning": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.held_out": 1 / 6, "stage_configs.shard_events.row_chunksize": 10, - "stage_configs.split_and_shard_patients.n_patients_per_shard": 2, + "stage_configs.split_and_shard_subjects.n_subjects_per_shard": 2, "hydra.verbose": True, "etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0", @@ -405,11 +405,11 @@ def test_extraction(): check_row_order=False, ) - # Stage 2: Collect the patient splits + # Stage 2: Collect the subject splits stderr, stdout = run_command( SPLIT_AND_SHARD_SCRIPT, extraction_config_kwargs, - "split_and_shard_patients", + "split_and_shard_subjects", ) all_stderrs.append(stderr) @@ -435,12 +435,12 @@ def test_extraction(): "NEEDING TO BE UPDATED." ) except AssertionError as e: - print("Failed to split patients") + print("Failed to split subjects") print(f"stderr:\n{stderr}") print(f"stdout:\n{stdout}") raise e - # Stage 3: Extract the events and sub-shard by patient + # Stage 3: Extract the events and sub-shard by subject stderr, stdout = run_command( CONVERT_TO_SHARDED_EVENTS_SCRIPT, extraction_config_kwargs, @@ -449,8 +449,8 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - patient_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" - assert patient_subsharded_folder.is_dir(), f"Expected {patient_subsharded_folder} to be a directory." + subject_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" + assert subject_subsharded_folder.is_dir(), f"Expected {subject_subsharded_folder} to be a directory." for split, expected_outputs in SUB_SHARDED_OUTPUTS.items(): for prefix, expected_df_L in expected_outputs.items(): @@ -459,7 +459,7 @@ def test_extraction(): expected_df = pl.concat([get_expected_output(df) for df in expected_df_L]) - fps = list((patient_subsharded_folder / split / prefix).glob("*.parquet")) + fps = list((subject_subsharded_folder / split / prefix).glob("*.parquet")) assert len(fps) > 0 # We add a "unique" here as there may be some duplicates across the row-group sub-shards. @@ -511,12 +511,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." 
+ assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -593,12 +593,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -651,14 +651,14 @@ def test_extraction(): assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" # Check the splits parquet - output_file = MEDS_cohort_dir / "metadata" / "patient_splits.parquet" + output_file = MEDS_cohort_dir / "metadata" / "subject_splits.parquet" assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) assert_df_equal( - PATIENT_SPLITS_DF, + SUBJECT_SPLITS_DF, got_df, - "Patient splits should be equal to the expected splits.", + "Subject splits should be equal to the expected splits.", check_column_order=False, check_row_order=False, ) diff --git a/tests/test_extract_no_metadata.py b/tests/test_extract_no_metadata.py index f1945af0..2391a977 100644 --- a/tests/test_extract_no_metadata.py +++ b/tests/test_extract_no_metadata.py @@ -15,7 +15,7 @@ if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_patients.py" + SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" @@ -23,7 +23,7 @@ FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" else: SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_patients" + SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" @@ -53,7 +53,7 @@ """ ADMIT_VITALS_CSV = """ -patient_id,admit_date,disch_date,department,vitals_date,HR,temp +subject_id,admit_date,disch_date,department,vitals_date,HR,temp 239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 @@ -88,7 +88,7 @@ 
EVENT_CFGS_YAML = """ subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -133,9 +133,9 @@ "held_out/0": [1500733], } -PATIENT_SPLITS_DF = pl.DataFrame( +SUBJECT_SPLITS_DF = pl.DataFrame( { - "patient_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], "split": ["train", "train", "train", "train", "tuning", "held_out"], } ) @@ -145,17 +145,17 @@ def get_expected_output(df: str) -> pl.DataFrame: return ( pl.read_csv(source=StringIO(df)) .select( - "patient_id", + "subject_id", pl.col("time").str.strptime(pl.Datetime, "%m/%d/%Y, %H:%M:%S").alias("time"), pl.col("code"), "numeric_value", ) - .sort(by=["patient_id", "time"]) + .sort(by=["subject_id", "time"]) ) MEDS_OUTPUT_TRAIN_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -165,7 +165,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, 239684,"05/11/2010, 17:41:51",HR,102.6 239684,"05/11/2010, 17:41:51",TEMP,96.0 @@ -193,7 +193,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -203,7 +203,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, 68729,"05/26/2010, 02:30:56",HR,86.0 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -215,14 +215,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TUNING_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, """ MEDS_OUTPUT_TUNING_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, 754281,"01/03/2010, 06:27:59",HR,142.0 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -230,14 +230,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_HELD_OUT_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, """ MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, 1500733,"06/03/2010, 14:54:38",HR,91.4 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -322,19 +322,19 @@ def test_extraction(): # Run the extraction script # 1. Sub-shard the data (this will be a null operation in this case, but it is worth doing just in # case. - # 2. Collect the patient splits. - # 3. Extract the events and sub-shard by patient. + # 2. Collect the subject splits. + # 3. Extract the events and sub-shard by subject. # 4. Merge to the final output. 
extraction_config_kwargs = { "input_dir": str(raw_cohort_dir.resolve()), "cohort_dir": str(MEDS_cohort_dir.resolve()), "event_conversion_config_fp": str(event_cfgs_yaml.resolve()), - "stage_configs.split_and_shard_patients.split_fracs.train": 4 / 6, - "stage_configs.split_and_shard_patients.split_fracs.tuning": 1 / 6, - "stage_configs.split_and_shard_patients.split_fracs.held_out": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.train": 4 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.tuning": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.held_out": 1 / 6, "stage_configs.shard_events.row_chunksize": 10, - "stage_configs.split_and_shard_patients.n_patients_per_shard": 2, + "stage_configs.split_and_shard_subjects.n_subjects_per_shard": 2, "hydra.verbose": True, "etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0", @@ -389,11 +389,11 @@ def test_extraction(): check_row_order=False, ) - # Stage 2: Collect the patient splits + # Stage 2: Collect the subject splits stderr, stdout = run_command( SPLIT_AND_SHARD_SCRIPT, extraction_config_kwargs, - "split_and_shard_patients", + "split_and_shard_subjects", ) all_stderrs.append(stderr) @@ -419,12 +419,12 @@ def test_extraction(): "NEEDING TO BE UPDATED." ) except AssertionError as e: - print("Failed to split patients") + print("Failed to split subjects") print(f"stderr:\n{stderr}") print(f"stdout:\n{stdout}") raise e - # Stage 3: Extract the events and sub-shard by patient + # Stage 3: Extract the events and sub-shard by subject stderr, stdout = run_command( CONVERT_TO_SHARDED_EVENTS_SCRIPT, extraction_config_kwargs, @@ -433,8 +433,8 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - patient_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" - assert patient_subsharded_folder.is_dir(), f"Expected {patient_subsharded_folder} to be a directory." + subject_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" + assert subject_subsharded_folder.is_dir(), f"Expected {subject_subsharded_folder} to be a directory." for split, expected_outputs in SUB_SHARDED_OUTPUTS.items(): for prefix, expected_df_L in expected_outputs.items(): @@ -443,7 +443,7 @@ def test_extraction(): expected_df = pl.concat([get_expected_output(df) for df in expected_df_L]) - fps = list((patient_subsharded_folder / split / prefix).glob("*.parquet")) + fps = list((subject_subsharded_folder / split / prefix).glob("*.parquet")) assert len(fps) > 0 # We add a "unique" here as there may be some duplicates across the row-group sub-shards. @@ -495,12 +495,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -560,12 +560,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." 
for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -618,14 +618,14 @@ def test_extraction(): assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" # Check the splits parquet - output_file = MEDS_cohort_dir / "metadata" / "patient_splits.parquet" + output_file = MEDS_cohort_dir / "metadata" / "subject_splits.parquet" assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) assert_df_equal( - PATIENT_SPLITS_DF, + SUBJECT_SPLITS_DF, got_df, - "Patient splits should be equal to the expected splits.", + "Subject splits should be equal to the expected splits.", check_column_order=False, check_row_order=False, ) diff --git a/tests/test_filter_measurements.py b/tests/test_filter_measurements.py index cb919d1b..cb5b4eec 100644 --- a/tests/test_filter_measurements.py +++ b/tests/test_filter_measurements.py @@ -10,7 +10,7 @@ # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # ,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -25,11 +25,11 @@ # TEMP,12,4,12,1181.4999999999998,116373.38999999998,"Body Temperature",LOINC/8310-5 # """ # -# We'll keep only the codes that occur for at least 2 patients, which are: ADMISSION//CARDIAC, DISCHARGE, DOB, +# We'll keep only the codes that occur for at least 2 subjects, which are: ADMISSION//CARDIAC, DISCHARGE, DOB, # EYE_COLOR//HAZEL, HEIGHT, HR, TEMP WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, @@ -61,7 +61,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -77,7 +77,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, 754281,"01/03/2010, 06:27:59",HR,142.0 @@ -86,7 +86,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, 1500733,"06/03/2010, 14:54:38",HR,91.4 @@ -112,14 +112,14 @@ def test_filter_measurements(): single_stage_transform_tester( transform_script=FILTER_MEASUREMENTS_SCRIPT, stage_name="filter_measurements", - transform_stage_kwargs={"min_patients_per_code": 2}, + transform_stage_kwargs={"min_subjects_per_code": 2}, want_data=WANT_SHARDS, ) # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # 
,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -144,7 +144,7 @@ def test_filter_measurements(): # - Other codes won't be filtered, so we will retain HEIGHT, DISCHARGE, DOB, TEMP MR_WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, @@ -166,7 +166,7 @@ def test_filter_measurements(): """ MR_WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -178,7 +178,7 @@ def test_filter_measurements(): """ MR_WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -186,7 +186,7 @@ def test_filter_measurements(): """ MR_WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -214,13 +214,13 @@ def test_match_revise_filter_measurements(): stage_name="filter_measurements", transform_stage_kwargs={ "_match_revise": [ - {"_matcher": {"code": "ADMISSION//CARDIAC"}, "min_patients_per_code": 2}, - {"_matcher": {"code": "ADMISSION//ORTHOPEDIC"}, "min_patients_per_code": 2}, - {"_matcher": {"code": "ADMISSION//PULMONARY"}, "min_patients_per_code": 2}, - {"_matcher": {"code": "HR"}, "min_patients_per_code": 15}, - {"_matcher": {"code": "EYE_COLOR//BLUE"}, "min_patients_per_code": 4}, - {"_matcher": {"code": "EYE_COLOR//BROWN"}, "min_patients_per_code": 4}, - {"_matcher": {"code": "EYE_COLOR//HAZEL"}, "min_patients_per_code": 4}, + {"_matcher": {"code": "ADMISSION//CARDIAC"}, "min_subjects_per_code": 2}, + {"_matcher": {"code": "ADMISSION//ORTHOPEDIC"}, "min_subjects_per_code": 2}, + {"_matcher": {"code": "ADMISSION//PULMONARY"}, "min_subjects_per_code": 2}, + {"_matcher": {"code": "HR"}, "min_subjects_per_code": 15}, + {"_matcher": {"code": "EYE_COLOR//BLUE"}, "min_subjects_per_code": 4}, + {"_matcher": {"code": "EYE_COLOR//BROWN"}, "min_subjects_per_code": 4}, + {"_matcher": {"code": "EYE_COLOR//HAZEL"}, "min_subjects_per_code": 4}, ], }, want_data=MR_WANT_SHARDS, diff --git a/tests/test_filter_patients.py b/tests/test_filter_subjects.py similarity index 75% rename from tests/test_filter_patients.py rename to tests/test_filter_subjects.py index 0b07836a..1defee47 100644 --- a/tests/test_filter_patients.py +++ b/tests/test_filter_subjects.py @@ -1,15 +1,16 @@ -"""Tests the filter patients script. +"""Tests the filter subjects script. Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +from meds import subject_id_field -from .transform_tester_base import FILTER_PATIENTS_SCRIPT, single_stage_transform_tester +from .transform_tester_base import FILTER_SUBJECTS_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -42,18 +43,18 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -# All patients in this shard had only 4 events. 
-WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value """ -# All patients in this shard had only 4 events. -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. +WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -77,10 +78,10 @@ ) -def test_filter_patients(): +def test_filter_subjects(): single_stage_transform_tester( - transform_script=FILTER_PATIENTS_SCRIPT, - stage_name="filter_patients", - transform_stage_kwargs={"min_events_per_patient": 5}, + transform_script=FILTER_SUBJECTS_SCRIPT, + stage_name="filter_subjects", + transform_stage_kwargs={"min_events_per_subject": 5}, want_data=WANT_SHARDS, ) diff --git a/tests/test_fit_vocabulary_indices.py b/tests/test_fit_vocabulary_indices.py index ce7c40a6..f67ebe1f 100644 --- a/tests/test_fit_vocabulary_indices.py +++ b/tests/test_fit_vocabulary_indices.py @@ -12,7 +12,7 @@ ) WANT_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes,code/vocab_index +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes,code/vocab_index ,44,4,28,3198.8389005974336,382968.28937288234,,,1 ADMISSION//CARDIAC,2,2,0,,,,,2 ADMISSION//ORTHOPEDIC,1,1,0,,,,,3 diff --git a/tests/test_multi_stage_preprocess_pipeline.py b/tests/test_multi_stage_preprocess_pipeline.py index eda9060e..32c4a277 100644 --- a/tests/test_multi_stage_preprocess_pipeline.py +++ b/tests/test_multi_stage_preprocess_pipeline.py @@ -4,7 +4,7 @@ scripts. 
In this test, the following stages are run: - - filter_patients + - filter_subjects - add_time_derived_measurements - fit_outlier_detection - occlude_outliers @@ -20,12 +20,13 @@ from datetime import datetime import polars as pl +from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict from .transform_tester_base import ( ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, AGGREGATE_CODE_METADATA_SCRIPT, - FILTER_PATIENTS_SCRIPT, + FILTER_SUBJECTS_SCRIPT, FIT_VOCABULARY_INDICES_SCRIPT, NORMALIZATION_SCRIPT, OCCLUDE_OUTLIERS_SCRIPT, @@ -51,8 +52,8 @@ ) STAGE_CONFIG_YAML = """ -filter_patients: - min_events_per_patient: 5 +filter_subjects: + min_events_per_subject: 5 add_time_derived_measurements: age: DOB_code: "DOB" # This is the MEDS official code for BIRTH @@ -71,17 +72,17 @@ fit_normalization: aggregations: - "code/n_occurrences" - - "code/n_patients" + - "code/n_subjects" - "values/n_occurrences" - "values/sum" - "values/sum_sqd" """ -# After filtering out patients with fewer than 5 events: +# After filtering out subjects with fewer than 5 events: WANT_FILTER = parse_shards_yaml( - """ - "filter_patients/train/0": |-2 - patient_id,time,code,numeric_value + f""" + "filter_subjects/train/0": |-2 + {subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -113,14 +114,14 @@ 1195293,"06/20/2010, 20:41:33",TEMP,100.4 1195293,"06/20/2010, 20:50:04",DISCHARGE, - "filter_patients/train/1": |-2 - patient_id,time,code,numeric_value + "filter_subjects/train/1": |-2 + {subject_id_field},time,code,numeric_value - "filter_patients/tuning/0": |-2 - patient_id,time,code,numeric_value + "filter_subjects/tuning/0": |-2 + {subject_id_field},time,code,numeric_value - "filter_patients/held_out/0": |-2 - patient_id,time,code,numeric_value + "filter_subjects/held_out/0": |-2 + {subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -136,9 +137,9 @@ ) WANT_TIME_DERIVED = parse_shards_yaml( - """ + f""" "add_time_derived_measurements/train/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", @@ -197,13 +198,13 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, "add_time_derived_measurements/train/1": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "add_time_derived_measurements/tuning/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "add_time_derived_measurements/held_out/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", @@ -387,9 +388,9 @@ """ WANT_OCCLUDE_OUTLIERS = parse_shards_yaml( - """ + f""" "occlude_outliers/train/0": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier 239684,,EYE_COLOR//BROWN,, 239684,,HEIGHT,,false 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)",, @@ -448,13 +449,13 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE,, "occlude_outliers/train/1": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier "occlude_outliers/tuning/0": |-2 - 
patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier "occlude_outliers/held_out/0": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier 1500733,,EYE_COLOR//BROWN,, 1500733,,HEIGHT,,false 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)",, @@ -488,7 +489,7 @@ ... .group_by("code") ... .agg( ... pl.len().alias("code/n_occurrences"), -... pl.col("patient_id").n_unique().alias("code/n_patients"), +... pl.col("subject_id").n_unique().alias("code/n_subjects"), ... VALS.len().alias("values/n_occurrences"), ... VALS.sum().alias("values/sum"), ... (VALS**2).sum().alias("values/sum_sqd") @@ -497,7 +498,7 @@ >>> post_transform.filter(pl.col("values/n_occurrences") > 0) shape: (3, 6) ┌──────┬────────────────────┬─────────────────┬──────────────────────┬────────────┬────────────────┐ -│ code ┆ code/n_occurrences ┆ code/n_patients ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │ +│ code ┆ code/n_occurrences ┆ code/n_subjects ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 ┆ u32 ┆ f32 ┆ f32 │ ╞══════╪════════════════════╪═════════════════╪══════════════════════╪════════════╪════════════════╡ @@ -508,7 +509,7 @@ >>> print(post_transform.filter(pl.col("values/n_occurrences") > 0).to_dict(as_series=False)) {'code': ['HR', 'TEMP', 'AGE'], 'code/n_occurrences': [10, 10, 12], - 'code/n_patients': [2, 2, 2], + 'code/n_subjects': [2, 2, 2], 'values/n_occurrences': [7, 6, 7], 'values/sum': [776.7999877929688, 600.1000366210938, 224.02084350585938], 'values/sum_sqd': [86249.921875, 60020.21484375, 7169.33349609375]} @@ -533,7 +534,7 @@ "DOB", ], "code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2], - "code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], + "code/n_subjects": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], "values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0], "values/sum": [ 0.0, @@ -597,7 +598,7 @@ "description": pl.String, "parent_codes": pl.List(pl.String), "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, + "code/n_subjects": pl.UInt8, "values/n_occurrences": pl.UInt8, # In the real stage, this is shrunk, so it differs from the ex. "values/sum": pl.Float32, "values/sum_sqd": pl.Float32, @@ -625,7 +626,7 @@ ], "code/vocab_index": [5, 6, 8, 9, 2, 7, 12, 11, 10, 1, 3, 4], "code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2], - "code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], + "code/n_subjects": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], "values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0], "values/sum": [ 0.0, @@ -689,7 +690,7 @@ "description": pl.String, "parent_codes": pl.List(pl.String), "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, + "code/n_subjects": pl.UInt8, "code/vocab_index": pl.UInt8, "values/n_occurrences": pl.UInt8, "values/sum": pl.Float32, @@ -772,9 +773,9 @@ # Note we have dropped the row in the held out shard that doesn't have a code in the vocabulary! 
WANT_NORMALIZATION = parse_shards_yaml( - """ + f""" "normalization/train/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 239684,,6, 239684,,7, 239684,"12/28/1980, 00:00:00",10, @@ -833,13 +834,13 @@ 1195293,"06/20/2010, 20:50:04",3, "normalization/train/1": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "normalization/tuning/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "normalization/held_out/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 1500733,,6, 1500733,,7, 1500733,"07/20/1986, 00:00:00",10, @@ -864,7 +865,7 @@ ) TOKENIZATION_SCHEMA_DF_SCHEMA = { - "patient_id": pl.UInt32, + subject_id_field: pl.Int64, "code": pl.List(pl.UInt8), "numeric_value": pl.List(pl.Float32), "start_time": pl.Datetime("us"), @@ -873,7 +874,7 @@ WANT_TOKENIZATION_SCHEMAS = { "tokenization/schemas/train/0": pl.DataFrame( { - "patient_id": [239684, 1195293], + subject_id_field: [239684, 1195293], "code": [[6, 7], [5, 7]], "numeric_value": [[None, None], [None, None]], "start_time": [datetime(1980, 12, 28), datetime(1978, 6, 20)], @@ -901,16 +902,16 @@ schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/train/1": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "start_time", "time"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "start_time", "time"]}, schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/tuning/0": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "start_time", "time"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "start_time", "time"]}, schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/held_out/0": pl.DataFrame( { - "patient_id": [1500733], + subject_id_field: [1500733], "code": [[6, 7]], "numeric_value": [[None, None]], "start_time": [datetime(1986, 7, 20)], @@ -928,18 +929,9 @@ ), } -TOKENIZATION_CODE = """ -```python - ->>> import polars as pl ->>> from tests.test_multi_stage_preprocess_pipeline import WANT_NORMALIZATION as dfs ->>> - -``` -""" TOKENIZATION_EVENT_SEQS_DF_SCHEMA = { - "patient_id": pl.UInt32, + subject_id_field: pl.Int64, "code": pl.List(pl.List(pl.UInt8)), "numeric_value": pl.List(pl.List(pl.Float32)), "time_delta_days": pl.List(pl.Float32), @@ -948,7 +940,7 @@ WANT_TOKENIZATION_EVENT_SEQS = { "tokenization/event_seqs/train/0": pl.DataFrame( { - "patient_id": [239684, 1195293], + subject_id_field: [239684, 1195293], "code": [ [[10, 4], [11, 2, 1, 8, 9], [11, 2, 8, 9], [12, 2, 8, 9], [12, 2, 8, 9], [12, 2, 3]], [ @@ -995,16 +987,16 @@ schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/train/1": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "time_delta_days"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "time_delta_days"]}, schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/tuning/0": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "time_delta_days"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "time_delta_days"]}, schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/held_out/0": pl.DataFrame( { - "patient_id": [1500733], + subject_id_field: [1500733], "code": [ [ [10, 4], @@ -1057,7 +1049,7 @@ def test_pipeline(): multi_stage_transform_tester( transform_scripts=[ - FILTER_PATIENTS_SCRIPT, + FILTER_SUBJECTS_SCRIPT, ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, 
AGGREGATE_CODE_METADATA_SCRIPT, OCCLUDE_OUTLIERS_SCRIPT, @@ -1068,7 +1060,7 @@ def test_pipeline(): TENSORIZATION_SCRIPT, ], stage_names=[ - "filter_patients", + "filter_subjects", "add_time_derived_measurements", "fit_outlier_detection", "occlude_outliers", diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 46992eda..14207c4c 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -12,7 +12,7 @@ # This is the code metadata file we'll use in this transform test. It is different than the default as we need # a code/vocab_index MEDS_CODE_METADATA_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,code/vocab_index +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,code/vocab_index ADMISSION//CARDIAC,2,2,0,,,1 ADMISSION//ORTHOPEDIC,1,1,0,,,2 ADMISSION//PULMONARY,1,1,0,,,3 @@ -129,7 +129,7 @@ # TEMP: 11 WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,7, 239684,,9,1.5770268440246582 239684,"12/28/1980, 00:00:00",5, @@ -163,7 +163,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,8, 68729,,9,-0.5438239574432373 68729,"03/09/1978, 00:00:00",5, @@ -181,7 +181,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,7, 754281,,9,0.28697699308395386 754281,"12/19/1988, 00:00:00",5, @@ -192,7 +192,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,7, 1500733,,9,-0.7995940446853638 1500733,"07/20/1986, 00:00:00",5, diff --git a/tests/test_occlude_outliers.py b/tests/test_occlude_outliers.py index 63e9376b..f13a4fa0 100644 --- a/tests/test_occlude_outliers.py +++ b/tests/test_occlude_outliers.py @@ -12,7 +12,7 @@ # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # ,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -75,7 +75,7 @@ """ # noqa: E501 WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 239684,,EYE_COLOR//BROWN,, 239684,,HEIGHT,,false 239684,"12/28/1980, 00:00:00",DOB,, @@ -109,7 +109,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 68729,,EYE_COLOR//HAZEL,, 68729,,HEIGHT,160.3953106166676,true 68729,"03/09/1978, 00:00:00",DOB,, @@ -127,7 +127,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 754281,,EYE_COLOR//BROWN,, 754281,,HEIGHT,166.22261567137025,true 754281,"12/19/1988, 00:00:00",DOB,, @@ -138,7 +138,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 1500733,,EYE_COLOR//BROWN,, 1500733,,HEIGHT,158.60131573580904,true 1500733,"07/20/1986, 00:00:00",DOB,, diff --git a/tests/test_reorder_measurements.py b/tests/test_reorder_measurements.py index c90dee49..7cc7aaa3 100644 --- a/tests/test_reorder_measurements.py +++ b/tests/test_reorder_measurements.py @@ -19,7 +19,7 @@ 
WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -53,7 +53,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,HEIGHT,160.3953106166676 68729,,EYE_COLOR//HAZEL, 68729,"03/09/1978, 00:00:00",DOB, @@ -71,7 +71,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -82,7 +82,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py index 65056e5a..79751039 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/test_reshard_to_split.py @@ -5,6 +5,8 @@ """ +from meds import subject_id_field + from .transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs @@ -14,8 +16,8 @@ "2": [239684, 1500733], } -IN_SHARD_0 = """ -patient_id,time,code,numeric_value +IN_SHARD_0 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -42,8 +44,8 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -IN_SHARD_1 = """ -patient_id,time,code,numeric_value +IN_SHARD_1 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -60,8 +62,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -IN_SHARD_2 = """ -patient_id,time,code,numeric_value +IN_SHARD_2 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -94,8 +96,8 @@ "held_out": [1500733], } -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -128,8 +130,8 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -146,8 +148,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -157,8 +159,8 @@ 754281,"01/03/2010, 08:22:13",DISCHARGE, """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -194,7 +196,7 @@ def test_reshard_to_split(): single_stage_transform_tester( transform_script=RESHARD_TO_SPLIT_SCRIPT, stage_name="reshard_to_split", - transform_stage_kwargs={"n_patients_per_shard": 2}, + transform_stage_kwargs={"n_subjects_per_shard": 2}, want_data=WANT_SHARDS, input_shards=IN_SHARDS, input_shards_map=IN_SHARDS_MAP, diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py index 693add18..cf5883ed 100644 --- 
a/tests/test_tokenization.py +++ b/tests/test_tokenization.py @@ -20,17 +20,17 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: """TODO: Doctests""" out = [] - for patient_ts in ts: + for subject_ts in ts: out.append([float("nan")]) - for i in range(1, len(patient_ts)): - out[-1].append((patient_ts[i] - patient_ts[i - 1]).total_seconds() / SECONDS_PER_DAY) + for i in range(1, len(subject_ts)): + out[-1].append((subject_ts[i] - subject_ts[i - 1]).total_seconds() / SECONDS_PER_DAY) return out # TODO: Make these schemas exportable, maybe??? # TODO: Why is the code getting converted to a float? SCHEMAS_SCHEMA = { - "patient_id": NORMALIZED_MEDS_SCHEMA["patient_id"], + "subject_id": NORMALIZED_MEDS_SCHEMA["subject_id"], "code": pl.List(NORMALIZED_MEDS_SCHEMA["code"]), "numeric_value": pl.List(NORMALIZED_MEDS_SCHEMA["numeric_value"]), "start_time": NORMALIZED_MEDS_SCHEMA["time"], @@ -38,7 +38,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: } SEQ_SCHEMA = { - "patient_id": NORMALIZED_MEDS_SCHEMA["patient_id"], + "subject_id": NORMALIZED_MEDS_SCHEMA["subject_id"], "code": pl.List(pl.List(pl.UInt8)), "numeric_value": pl.List(pl.List(NORMALIZED_MEDS_SCHEMA["numeric_value"])), "time_delta_days": pl.List(pl.Float32), @@ -66,7 +66,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: ] WANT_SCHEMAS_TRAIN_0 = pl.DataFrame( { - "patient_id": [239684, 1195293], + "subject_id": [239684, 1195293], "code": [[7, 9], [6, 9]], "numeric_value": [[None, 1.5770268440246582], [None, 0.06802856922149658]], "start_time": [ts[0] for ts in TRAIN_0_TIMES], @@ -77,7 +77,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TRAIN_0 = pl.DataFrame( { - "patient_id": [239684, 1195293], + "subject_id": [239684, 1195293], "time_delta_days": ts_to_time_delta_days(TRAIN_0_TIMES), "code": [ [[5], [1, 10, 11], [10, 11], [10, 11], [10, 11], [4]], @@ -114,7 +114,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_TRAIN_1 = pl.DataFrame( { - "patient_id": [68729, 814703], + "subject_id": [68729, 814703], "code": [[8, 9], [8, 9]], "numeric_value": [[None, -0.5438239574432373], [None, -1.1012336015701294]], "start_time": [ts[0] for ts in TRAIN_1_TIMES], @@ -125,7 +125,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TRAIN_1 = pl.DataFrame( { - "patient_id": [68729, 814703], + "subject_id": [68729, 814703], "time_delta_days": ts_to_time_delta_days(TRAIN_1_TIMES), "code": [[[5], [3, 10, 11], [4]], [[5], [2, 10, 11], [4]]], "numeric_value": [ @@ -140,7 +140,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_TUNING_0 = pl.DataFrame( { - "patient_id": [754281], + "subject_id": [754281], "code": [[7, 9]], "numeric_value": [[None, 0.28697699308395386]], "start_time": [ts[0] for ts in TUNING_0_TIMES], @@ -151,7 +151,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TUNING_0 = pl.DataFrame( { - "patient_id": [754281], + "subject_id": [754281], "time_delta_days": ts_to_time_delta_days(TUNING_0_TIMES), "code": [[[5], [3, 10, 11], [4]]], "numeric_value": [ @@ -174,7 +174,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_HELD_OUT_0 = pl.DataFrame( { - "patient_id": [1500733], + "subject_id": [1500733], "code": [[7, 9]], "numeric_value": [[None, -0.7995940446853638]], "start_time": [ts[0] for ts in HELD_OUT_0_TIMES], 
@@ -185,7 +185,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_HELD_OUT_0 = pl.DataFrame( { - "patient_id": [1500733], + "subject_id": [1500733], "time_delta_days": ts_to_time_delta_days(HELD_OUT_0_TIMES), "code": [[[5], [2, 10, 11], [10, 11], [10, 11], [4]]], "numeric_value": [ diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index 6845e0c5..2deddd48 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -22,6 +22,7 @@ import numpy as np import polars as pl import rootutils +from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict from .utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command @@ -40,7 +41,7 @@ # Filters FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" - FILTER_PATIENTS_SCRIPT = filters_root / "filter_patients.py" + FILTER_SUBJECTS_SCRIPT = filters_root / "filter_subjects.py" # Transforms ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py" @@ -57,7 +58,7 @@ # Filters FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" - FILTER_PATIENTS_SCRIPT = "MEDS_transform-filter_patients" + FILTER_SUBJECTS_SCRIPT = "MEDS_transform-filter_subjects" # Transforms ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = "MEDS_transform-add_time_derived_measurements" @@ -83,7 +84,7 @@ } MEDS_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -117,7 +118,7 @@ """ MEDS_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -135,7 +136,7 @@ """ MEDS_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -146,7 +147,7 @@ """ MEDS_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -171,7 +172,7 @@ MEDS_CODE_METADATA_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes ,44,4,28,3198.8389005974336,382968.28937288234,, ADMISSION//CARDIAC,2,2,0,,,, ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -189,9 +190,9 @@ MEDS_CODE_METADATA_SCHEMA = { "code": pl.Utf8, "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, + "code/n_subjects": pl.UInt8, "values/n_occurrences": pl.UInt8, - "values/n_patients": pl.UInt8, + "values/n_subjects": pl.UInt8, "values/sum": pl.Float32, "values/sum_sqd": pl.Float32, "values/n_ints": pl.UInt8, @@ -323,11 +324,11 @@ def input_MEDS_dataset( if input_splits_map is None: input_splits_map = SPLITS input_splits_as_df = defaultdict(list) - for split_name, patient_ids in input_splits_map.items(): - input_splits_as_df["patient_id"].extend(patient_ids) - input_splits_as_df["split"].extend([split_name] * len(patient_ids)) + for split_name, subject_ids in input_splits_map.items(): + input_splits_as_df[subject_id_field].extend(subject_ids) + input_splits_as_df["split"].extend([split_name] * len(subject_ids)) input_splits_df = pl.DataFrame(input_splits_as_df) - 
input_splits_fp = MEDS_metadata_dir / "patient_splits.parquet" + input_splits_fp = MEDS_metadata_dir / "subject_splits.parquet" input_splits_df.write_parquet(input_splits_fp, use_pyarrow=True) if input_shards is None: diff --git a/tests/utils.py b/tests/utils.py index e6c9d3fd..74efae50 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,7 +11,7 @@ # TODO: Make use meds library MEDS_PL_SCHEMA = { - "patient_id": pl.UInt32, + "subject_id": pl.Int64, "time": pl.Datetime("us"), "code": pl.Utf8, "numeric_value": pl.Float32, From dfd6aa907835a5b1a2f861ce250bdd95612b453e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 22 Aug 2024 16:01:46 -0400 Subject: [PATCH 02/62] Initial commit to save ideas. Not working --- src/MEDS_transforms/__init__.py | 1 + src/MEDS_transforms/configs/runner.yaml | 40 +++++ src/MEDS_transforms/runner.py | 203 ++++++++++++++++++++++++ src/MEDS_transforms/utils.py | 8 +- 4 files changed, 250 insertions(+), 2 deletions(-) create mode 100644 src/MEDS_transforms/configs/runner.yaml create mode 100644 src/MEDS_transforms/runner.py diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index e0aaaf3a..08ef7a69 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -11,6 +11,7 @@ PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/preprocess.yaml") EXTRACT_CONFIG_YAML = files(__package_name__).joinpath("configs/extract.yaml") +RUNNER_CONFIG_YAML = files(__package_name__).joinpath("configs/runner.yaml") MANDATORY_COLUMNS = ["patient_id", "time", "code", "numeric_value"] diff --git a/src/MEDS_transforms/configs/runner.yaml b/src/MEDS_transforms/configs/runner.yaml new file mode 100644 index 00000000..b4fd83e7 --- /dev/null +++ b/src/MEDS_transforms/configs/runner.yaml @@ -0,0 +1,40 @@ +# Global IO +pipeline_config_fp: ??? +stage_runner_fp: null + +_pipeline_config: ${oc.create:${load_yaml_file:${pipeline_config_fp}}} + +_default_name: "MEDS-transforms Pipeline" +_pipeline_name: ${oc.select:_pipeline_config.etl_metadata.pipeline_name, _default_name} +_pipeline_description: ${_pipeline_config.description} + +log_dir: "${_pipeline_config.cohort_dir}/.logs" + +_stage_runners: ${oc.create:${load_yaml_file:${stage_runner_fp}}} +stages: ${_pipeline_config.stages} + +do_profile: False + +_pipeline_help_block: |- + + **${_pipeline_name} description:** + + ${_pipeline_description} + +# Hydra #${oc.select:_pipeline_help_block,""} +hydra: + job: + name: "${fix_str_for_path:${_pipeline_name}}_runner_${now:%Y-%m-%d_%H-%M-%S}" + run: + dir: "${log_dir}" + help: + app_name: "MEDS-Transforms Pipeline Runner" + + template: |- + == ${hydra.help.app_name} == + ${hydra.help.app_name} is a command line tool for running entire MEDS-transform pipelines in a single + command. + + ${get_script_docstring:runner} + + ${oc.select:${oc.create:${load_yaml_file:${pipeline_config_fp}}},""} diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py new file mode 100644 index 00000000..0657289c --- /dev/null +++ b/src/MEDS_transforms/runner.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +"""This script is a helper utility to run entire pipelines from a single script. + +To do this effectively, this runner functionally takes a "meta configuration" file that contains: + 1. The path to the pipeline configuration file. + 2. 
Configuration details for how to run each stage of the pipeline, including mappings to the underlying + stage scripts and Hydra launcher configurations for each stage to control parallelism, resources, etc. +""" + +import hydra +from pathlib import Path +from omegaconf import DictConfig, OmegaConf +from MEDS_transforms import RUNNER_CONFIG_YAML +from MEDS_transforms.utils import hydra_loguru_init +import importlib +from typing import Any + +def get_script_from_name(stage_name: str) -> str | None: + """Returns the script name for the given stage name. + + Args: + stage_name: The name of the stage. + + Returns: + The script name for the given stage name. + + """ + + try: + _ = importlib.import_module(f"MEDS_transforms.extract.{stage_name}") + return f"MEDS_extract-{stage_name}" + except ImportError: + pass + + for pfx in ("MEDS_transforms.transforms", "MEDS_transforms.filters", "MEDS_transforms"): + try: + _ = importlib.import_module(f"{pfx}.{stage_name}") + return f"MEDS_transform-{stage_name}" + except ImportError: + pass + + return None + +def get_parallelization_args( + parallelization_cfg: dict | DictConfig | None, default_parallelization_cfg: dict | DictConfig +) -> list[str]: + """Gets the parallelization args.""" + + if parallelization_cfg is None: + return [] + + if "n_workers" in parallelization_cfg: + n_workers = parallelization_cfg["n_workers"] + elif "n_workers" in default_parallelization_cfg: + n_workers = default_parallelization_cfg["n_workers"] + else: + n_workers = 1 + + parallelization_args = [ + "--multirun", + f"worker=range(0,{n_workers})", + ] + + + if "launcher" in parallelization_cfg: + launcher = parallelization_cfg["launcher"] + elif "launcher" in default_parallelization_cfg: + launcher = default_parallelization_cfg["launcher"] + else: + launcher = None + + if launcher is None: + return parallelization_args + + if "launcher_params" in parallelization_cfg: + raise ValueError("If launcher_params is provided, launcher must also be provided.") + + parallelization_args.append(f"hydra/launcher={launcher}") + + if "launcher_params" in parallelization_cfg: + launcher_params = parallelization_cfg["launcher_params"] + elif "launcher_params" in default_parallelization_cfg: + launcher_params = default_parallelization_cfg["launcher_params"] + else: + launcher_params = {} + + for k, v in launcher_params.items(): + parallelization_args.append(f"hydra.launcher.{k}={v}") + + return parallelization_args + + +def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dict | DictConfig | None = None): + """Runs a single stage of the pipeline. + + Args: + cfg: The configuration for the entire pipeline. + stage_name: The name of the stage to run. 
+ + """ + + if default_parallelization_cfg is None: + default_parallelization_cfg = {} + + do_profile = cfg.get("do_profile", False) + pipeline_config_fp = Path(cfg.pipeline_config_fp) + stage_config = pipeline_config_fp.stage_configs.get("stage", {}) + stage_runner_config = cfg._stage_runners.get(stage_name, {}) + + script = None + if "script" in stage_runner_config: + script = stage_runner_config.script + elif "_script" in stage_config: + script = stage_config._script + elif get_script_from_name(stage_name): + script = get_script_from_name(stage_name) + else: + raise ValueError(f"Cannot determine script for {stage_name}") + + command_parts = [ + script, + f"--config-path={str(pipeline_config_fp.parent.resolve())}", + f"--config-name={pipeline_config_fp.stem}", + "'hydra.searchpath=[pkg://MEDS_transforms.configs]'", + f"stage={stage_name}", + ] + + command_parts.extend(get_parallelization_args( + stage_runner_config.get("parallelize", {}), default_parallelization_cfg + )) + + if do_profile: + command_parts.append("++hydra.callbacks.profiler._target_=hydra_profiler.profiler.ProfilerCallback") + + full_cmd = " ".join(command_parts) + logger.info(f"Running command: {full_cmd}") + command_out = subprocess.run(full_cmd, shell=True, capture_output=True) + + # https://stackoverflow.com/questions/21953835/run-subprocess-and-print-output-to-logging + # https://loguru.readthedocs.io/en/stable/api/logger.html#loguru._logger.Logger.parse + + stderr = command_out.stderr.decode() + stdout = command_out.stdout.decode() + + if command_out.returncode != 0: + raise ValueError(f"Stage {stage_name} failed with return code {command_out.returncode}.\n{stderr}") + +@hydra.main( + version_base=None, config_path=str(RUNNER_CONFIG_YAML.parent), config_name=RUNNER_CONFIG_YAML.stem +) +def main(cfg: DictConfig): + """Runs the entire pipeline, end-to-end, based on the configuration provided. + + This script will launch many subsidiary commands via `subprocess`, one for each stage of the specified + pipeline. + """ + + hydra_loguru_init() + + pipeline_config_fp = Path(cfg.pipeline_config_fp) + if not pipeline_config_fp.exists(): + raise FileNotFoundError(f"Pipeline configuration file {pipeline_config_fp} does not exist.") + if not pipeline_config_fp.suffix == ".yaml": + raise ValueError(f"Pipeline configuration file {pipeline_config_fp} must have a .yaml extension.") + + logs_dir = Path(cfg.logs_dir) + + if do_profile: + try: + from hydra_profiler.profiler import ProfilerCallback + except ImportError as e: + raise ValueError( + "You can't run in profiling mode without installing hydra-profiler. Try installing " + "MEDS-transforms with the 'profiler' optional dependency: " + "`pip install MEDS-transforms[profiler]`." + ) from e + + global_done_file = logs_dir / f"_all_stages.done" + if global_done_file.exists(): + logger.info("All stages are already complete. 
Exiting.") + return + + if "parallelize" in cfg: + default_parallelization_cfg = cfg.parallelize + else: + default_parallelization_cfg = None + + for stage in cfg.stages: + done_file = logs_dir / f"{stage}.done" + + if done_file.exists(): + logger.info(f"Skipping stage {stage} as it is already complete.") + else: + logger.info(f"Running stage: {stage}") + run_stage(cfg, stage, default_parallelization_cfg=default_parallelization_cfg) + done_file.touch() + + global_done_file.touch() + +if __name__ == "__main__": + OmegaConf.register_new_resolver("load_yaml_file", OmegaConf.load, replace=False) + + main() diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index 59a7cd69..70c97aaa 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -1,6 +1,7 @@ """Core utilities for MEDS pipelines built with these tools.""" import inspect +import importlib import os import sys from pathlib import Path @@ -108,10 +109,13 @@ def get_package_version() -> str: return package_version -def get_script_docstring() -> str: +def get_script_docstring(filename: str | None = None) -> str: """Returns the docstring of the main function of the script from which this function was called.""" - main_module = sys.modules["__main__"] + if filename is not None: + main_module = importlib.import_module(f"MEDS_transforms.{filename}") + else: + main_module = sys.modules["__main__"] func = getattr(main_module, "main", None) if func and callable(func): return inspect.getdoc(func) or "" From df73aae54f6421ec2d6ffbe4c8b49246cf3d6c0e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 22 Aug 2024 16:04:51 -0400 Subject: [PATCH 03/62] Initial commit to save ideas. Not working --- src/MEDS_transforms/configs/runner.yaml | 2 +- src/MEDS_transforms/runner.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/MEDS_transforms/configs/runner.yaml b/src/MEDS_transforms/configs/runner.yaml index b4fd83e7..527caf01 100644 --- a/src/MEDS_transforms/configs/runner.yaml +++ b/src/MEDS_transforms/configs/runner.yaml @@ -2,7 +2,7 @@ pipeline_config_fp: ??? stage_runner_fp: null -_pipeline_config: ${oc.create:${load_yaml_file:${pipeline_config_fp}}} +_pipeline_config: ${oc.create:${load_yaml_file:${oc.select:pipeline_config_fp,null}}} _default_name: "MEDS-transforms Pipeline" _pipeline_name: ${oc.select:_pipeline_config.etl_metadata.pipeline_name, _default_name} diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index 0657289c..8801cd76 100644 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -197,7 +197,15 @@ def main(cfg: DictConfig): global_done_file.touch() +def load_file(path: str) -> Any: + with open(path, "r") as f: + return f.read() + +def load_yaml_file(path: str | None) -> dict | DictConfig: + return {} if path is None else OmegaConf.load(path) + if __name__ == "__main__": OmegaConf.register_new_resolver("load_yaml_file", OmegaConf.load, replace=False) + OmegaConf.register_new_resolver("load_file", load_file, replace=False) main() From 1da3f85ae3d6c4b699b1bbfad0c565bb28448b87 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 25 Aug 2024 13:38:07 -0400 Subject: [PATCH 04/62] Initial stuff; not working yet. 
--- src/MEDS_transforms/configs/runner.yaml | 16 +++++-------- src/MEDS_transforms/runner.py | 31 ++++++++++++++----------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/MEDS_transforms/configs/runner.yaml b/src/MEDS_transforms/configs/runner.yaml index 527caf01..3cf8fa66 100644 --- a/src/MEDS_transforms/configs/runner.yaml +++ b/src/MEDS_transforms/configs/runner.yaml @@ -3,10 +3,10 @@ pipeline_config_fp: ??? stage_runner_fp: null _pipeline_config: ${oc.create:${load_yaml_file:${oc.select:pipeline_config_fp,null}}} +_etl_metadata: ${oc.select:_pipeline_config.etl_metadata,${oc.create:{}}} -_default_name: "MEDS-transforms Pipeline" -_pipeline_name: ${oc.select:_pipeline_config.etl_metadata.pipeline_name, _default_name} -_pipeline_description: ${_pipeline_config.description} +_pipeline_name: ${oc.select:_etl_metadata.pipeline_name, "MEDS-transforms Pipeline"} +_pipeline_description: ${oc.select:_pipeline_config.description, "No description provided."} log_dir: "${_pipeline_config.cohort_dir}/.logs" @@ -15,12 +15,6 @@ stages: ${_pipeline_config.stages} do_profile: False -_pipeline_help_block: |- - - **${_pipeline_name} description:** - - ${_pipeline_description} - # Hydra #${oc.select:_pipeline_help_block,""} hydra: job: @@ -37,4 +31,6 @@ hydra: ${get_script_docstring:runner} - ${oc.select:${oc.create:${load_yaml_file:${pipeline_config_fp}}},""} + **${_pipeline_name} description:** + + ${_pipeline_description} diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index 8801cd76..7d614c1d 100644 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -7,13 +7,16 @@ stage scripts and Hydra launcher configurations for each stage to control parallelism, resources, etc. """ -import hydra +import importlib from pathlib import Path +from typing import Any + +import hydra from omegaconf import DictConfig, OmegaConf + from MEDS_transforms import RUNNER_CONFIG_YAML from MEDS_transforms.utils import hydra_loguru_init -import importlib -from typing import Any + def get_script_from_name(stage_name: str) -> str | None: """Returns the script name for the given stage name. @@ -23,7 +26,6 @@ def get_script_from_name(stage_name: str) -> str | None: Returns: The script name for the given stage name. - """ try: @@ -41,6 +43,7 @@ def get_script_from_name(stage_name: str) -> str | None: return None + def get_parallelization_args( parallelization_cfg: dict | DictConfig | None, default_parallelization_cfg: dict | DictConfig ) -> list[str]: @@ -61,7 +64,6 @@ def get_parallelization_args( f"worker=range(0,{n_workers})", ] - if "launcher" in parallelization_cfg: launcher = parallelization_cfg["launcher"] elif "launcher" in default_parallelization_cfg: @@ -96,7 +98,6 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic Args: cfg: The configuration for the entire pipeline. stage_name: The name of the stage to run. 
- """ if default_parallelization_cfg is None: @@ -125,9 +126,9 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic f"stage={stage_name}", ] - command_parts.extend(get_parallelization_args( - stage_runner_config.get("parallelize", {}), default_parallelization_cfg - )) + command_parts.extend( + get_parallelization_args(stage_runner_config.get("parallelize", {}), default_parallelization_cfg) + ) if do_profile: command_parts.append("++hydra.callbacks.profiler._target_=hydra_profiler.profiler.ProfilerCallback") @@ -145,6 +146,7 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic if command_out.returncode != 0: raise ValueError(f"Stage {stage_name} failed with return code {command_out.returncode}.\n{stderr}") + @hydra.main( version_base=None, config_path=str(RUNNER_CONFIG_YAML.parent), config_name=RUNNER_CONFIG_YAML.stem ) @@ -167,7 +169,7 @@ def main(cfg: DictConfig): if do_profile: try: - from hydra_profiler.profiler import ProfilerCallback + pass except ImportError as e: raise ValueError( "You can't run in profiling mode without installing hydra-profiler. Try installing " @@ -197,15 +199,18 @@ def main(cfg: DictConfig): global_done_file.touch() + def load_file(path: str) -> Any: - with open(path, "r") as f: + with open(path) as f: return f.read() + def load_yaml_file(path: str | None) -> dict | DictConfig: - return {} if path is None else OmegaConf.load(path) + return OmegaConf.load(path) if path else {} + if __name__ == "__main__": - OmegaConf.register_new_resolver("load_yaml_file", OmegaConf.load, replace=False) + OmegaConf.register_new_resolver("load_yaml_file", load_yaml_file, replace=False) OmegaConf.register_new_resolver("load_file", load_file, replace=False) main() From 0a61775f44825c9e7a1f63bcfff0f71484b3c8ea Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 08:48:44 -0400 Subject: [PATCH 05/62] Added tests to reshard stage --- src/MEDS_transforms/mapreduce/utils.py | 2 +- src/MEDS_transforms/reshard_to_split.py | 48 ++++++++++++++++++++++--- tests/test_reshard_to_split.py | 11 ++++++ tests/transform_tester_base.py | 4 +++ tests/utils.py | 14 +++++--- 5 files changed, 69 insertions(+), 10 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 300e4c39..9a7151fd 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -483,7 +483,7 @@ def shard_iterator( shards = train_shards includes_only_train = True elif train_only: - logger.info( + logger.warning( f"train_only={train_only} requested but no dedicated train shards found; processing all shards " "and relying on `patient_splits.parquet` for filtering." ) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index f9361965..a4a0276d 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -19,7 +19,31 @@ def valid_json_file(fp: Path) -> bool: - """Check if a file is a valid JSON file.""" + """Check if a file is a valid JSON file. + + Args: + fp: Path to the file. + + Returns: + True if the file is a valid JSON file, False otherwise. + + Examples: + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.json" + ... valid_json_file(fp) + False + >>> with tempfile.NamedTemporaryFile(suffix=".json") as tmpfile: + ... fp = Path(tmpfile.name) + ... _ = fp.write_text("foobar not a json file.\tHello, world!") + ... 
valid_json_file(fp) + False + >>> with tempfile.NamedTemporaryFile(suffix=".json") as tmpfile: + ... fp = Path(tmpfile.name) + ... _ = fp.write_text('{"foo": "bar"}') + ... valid_json_file(fp) + True + """ if not fp.is_file(): return False try: @@ -30,6 +54,7 @@ def valid_json_file(fp: Path) -> bool: def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) -> dict[str, list[str]]: + """This function creates a new sharding scheme for the MEDS cohort.""" splits_map = defaultdict(list) for pt_id, sp in df.iter_rows(): splits_map[sp].append(pt_id) @@ -44,6 +69,20 @@ def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) def write_json(d: dict, fp: Path) -> None: + """Write a dictionary to a JSON file. + + Args: + d: Dictionary to write. + fp: Path to the file. + + Examples: + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.json" + ... write_json({"foo": "bar"}, fp) + ... fp.read_text() + '{"foo": "bar"}' + """ fp.write_text(json.dumps(d)) @@ -79,9 +118,10 @@ def main(cfg: DictConfig): new_sharded_splits = json.loads(shards_fp.read_text()) - orig_shards_iter, include_only_train = shard_iterator(cfg, out_suffix="") - if include_only_train: - raise ValueError("This stage does not support include_only_train=True") + if cfg.stage_cfg.get("train_only", False): + raise ValueError("This stage does not support train_only=True") + + orig_shards_iter, _ = shard_iterator(cfg, out_suffix="") orig_shards_iter = [(in_fp, out_fp.relative_to(output_dir)) for in_fp, out_fp in orig_shards_iter] diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py index 65056e5a..3af7f197 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/test_reshard_to_split.py @@ -200,3 +200,14 @@ def test_reshard_to_split(): input_shards_map=IN_SHARDS_MAP, input_splits_map=SPLITS, ) + + single_stage_transform_tester( + transform_script=RESHARD_TO_SPLIT_SCRIPT, + stage_name="reshard_to_split", + transform_stage_kwargs={"n_patients_per_shard": 2, "+train_only": True}, + want_data=WANT_SHARDS, + input_shards=IN_SHARDS, + input_shards_map=IN_SHARDS_MAP, + input_splits_map=SPLITS, + should_error=True, + ) diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index bca36ad9..ec061fcd 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -404,6 +404,7 @@ def single_stage_transform_tester( want_data: dict[str, pl.DataFrame] | None = None, want_metadata: pl.DataFrame | None = None, assert_no_other_outputs: bool = True, + should_error: bool = False, **input_data_kwargs, ): with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): @@ -421,6 +422,7 @@ def single_stage_transform_tester( "script": transform_script, "hydra_kwargs": pipeline_config_kwargs, "test_name": f"Single stage transform: {stage_name}", + "should_error": should_error, } if do_use_config_yaml: run_command_kwargs["do_use_config_yaml"] = True @@ -431,6 +433,8 @@ def single_stage_transform_tester( # Run the transform stderr, stdout = run_command(**run_command_kwargs) + if should_error: + return try: check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata) diff --git a/tests/utils.py b/tests/utils.py index e7220c94..f562bce3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -70,6 +70,8 @@ def dict_to_hydra_kwargs(d: dict[str, str]) -> str: ValueError: Unexpected type for value for key a: : 2021-11-01 00:00:00 """ + modifier_chars = ["~", "'", "++", "+"] + out = [] for k, 
v in d.items(): if not isinstance(k, str): @@ -86,11 +88,13 @@ def dict_to_hydra_kwargs(d: dict[str, str]) -> str: case dict(): inner_kwargs = dict_to_hydra_kwargs(v) for inner_kv in inner_kwargs: - if inner_kv.startswith("~"): - out.append(f"~{k}.{inner_kv[1:]}") - elif inner_kv.startswith("'"): - out.append(f"'{k}.{inner_kv[1:]}") - else: + handled = False + for mod in modifier_chars: + if inner_kv.startswith(mod): + out.append(f"{mod}{k}.{inner_kv[len(mod):]}") + handled = True + break + if not handled: out.append(f"{k}.{inner_kv}") case list() | tuple(): v = list(v) From 19502ebe07b945bef5b0ce773ef2abe38e0a493e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 08:58:39 -0400 Subject: [PATCH 06/62] Added error case test for fitting vocabulary indices. --- tests/test_fit_vocabulary_indices.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_fit_vocabulary_indices.py b/tests/test_fit_vocabulary_indices.py index ce7c40a6..c468050f 100644 --- a/tests/test_fit_vocabulary_indices.py +++ b/tests/test_fit_vocabulary_indices.py @@ -35,3 +35,11 @@ def test_fit_vocabulary_indices_with_default_stage_config(): transform_stage_kwargs=None, want_metadata=parse_code_metadata_csv(WANT_CSV), ) + + single_stage_transform_tester( + transform_script=FIT_VOCABULARY_INDICES_SCRIPT, + stage_name="fit_vocabulary_indices", + transform_stage_kwargs={"ordering_method": "file"}, + want_metadata=parse_code_metadata_csv(WANT_CSV), + should_error=True, + ) From 35500190d7f048a0d0bb26c26e48116759626c9d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:11:58 -0400 Subject: [PATCH 07/62] Added a bunch of no-cover lines for things that will get skipped in GitHub CI --- src/MEDS_transforms/__init__.py | 2 +- src/MEDS_transforms/aggregate_code_metadata.py | 2 +- src/MEDS_transforms/extract/convert_to_sharded_events.py | 2 +- src/MEDS_transforms/extract/extract_code_metadata.py | 2 +- src/MEDS_transforms/extract/finalize_MEDS_data.py | 2 +- src/MEDS_transforms/extract/finalize_MEDS_metadata.py | 2 +- src/MEDS_transforms/extract/merge_to_MEDS_cohort.py | 2 +- src/MEDS_transforms/extract/shard_events.py | 2 +- src/MEDS_transforms/extract/split_and_shard_patients.py | 2 +- src/MEDS_transforms/filters/filter_measurements.py | 2 +- src/MEDS_transforms/filters/filter_patients.py | 2 +- src/MEDS_transforms/fit_vocabulary_indices.py | 2 +- src/MEDS_transforms/reshard_to_split.py | 2 +- src/MEDS_transforms/transforms/add_time_derived_measurements.py | 2 +- src/MEDS_transforms/transforms/extract_values.py | 2 +- src/MEDS_transforms/transforms/normalization.py | 2 +- src/MEDS_transforms/transforms/occlude_outliers.py | 2 +- src/MEDS_transforms/transforms/reorder_measurements.py | 2 +- src/MEDS_transforms/transforms/tensorization.py | 2 +- src/MEDS_transforms/transforms/tokenization.py | 2 +- 20 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index 38f2ac92..c5aba54a 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -6,7 +6,7 @@ __package_name__ = "MEDS_transforms" try: __version__ = version(__package_name__) -except PackageNotFoundError: +except PackageNotFoundError: # pragma: no cover __version__ = "unknown" PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/preprocess.yaml") diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 1f6828b5..ddf2a4a5 100755 --- 
a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -730,5 +730,5 @@ def main(cfg: DictConfig): run_map_reduce(cfg) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index ee4e9d70..dc6a6b0c 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -744,5 +744,5 @@ def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: logger.info("Subsharded into converted events.") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py b/src/MEDS_transforms/extract/extract_code_metadata.py index e9133eb6..959d2ff2 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -449,5 +449,5 @@ def reducer_fn(*dfs): logger.info(f"Finished reduction in {datetime.now() - start}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/finalize_MEDS_data.py b/src/MEDS_transforms/extract/finalize_MEDS_data.py index f9d68730..7920e4c8 100644 --- a/src/MEDS_transforms/extract/finalize_MEDS_data.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_data.py @@ -134,5 +134,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=get_and_validate_data_schema, write_fn=pq.write_table) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index 366d89aa..0b215f12 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -218,5 +218,5 @@ def main(cfg: DictConfig): pq.write_table(patient_splits_tbl, patient_splits_fp) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py index 49a45e1d..f7de0814 100755 --- a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py +++ b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py @@ -242,5 +242,5 @@ def main(cfg: DictConfig): ) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/shard_events.py b/src/MEDS_transforms/extract/shard_events.py index 18450bbb..88f84389 100755 --- a/src/MEDS_transforms/extract/shard_events.py +++ b/src/MEDS_transforms/extract/shard_events.py @@ -429,5 +429,5 @@ def main(cfg: DictConfig): logger.info(f"Sub-sharding completed in {datetime.now() - start}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index a385c735..0cee8365 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -276,5 +276,5 @@ def main(cfg: DictConfig): logger.info("Done writing sharded patients") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index 36a69387..9f301856 
100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -147,5 +147,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=filter_measurements_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/filters/filter_patients.py b/src/MEDS_transforms/filters/filter_patients.py index 36dc3985..c5630b27 100644 --- a/src/MEDS_transforms/filters/filter_patients.py +++ b/src/MEDS_transforms/filters/filter_patients.py @@ -233,5 +233,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=filter_patients_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/fit_vocabulary_indices.py b/src/MEDS_transforms/fit_vocabulary_indices.py index 0fb249b1..b5e43275 100644 --- a/src/MEDS_transforms/fit_vocabulary_indices.py +++ b/src/MEDS_transforms/fit_vocabulary_indices.py @@ -236,5 +236,5 @@ def main(cfg: DictConfig): logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index a4a0276d..d74ba867 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -165,5 +165,5 @@ def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/add_time_derived_measurements.py b/src/MEDS_transforms/transforms/add_time_derived_measurements.py index 01ec7f98..c0423c21 100644 --- a/src/MEDS_transforms/transforms/add_time_derived_measurements.py +++ b/src/MEDS_transforms/transforms/add_time_derived_measurements.py @@ -398,5 +398,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=add_time_derived_measurements_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/extract_values.py b/src/MEDS_transforms/transforms/extract_values.py index a1d42b65..f99eb6df 100644 --- a/src/MEDS_transforms/transforms/extract_values.py +++ b/src/MEDS_transforms/transforms/extract_values.py @@ -130,5 +130,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=extract_values_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/normalization.py b/src/MEDS_transforms/transforms/normalization.py index fbad9acc..ef43dda9 100644 --- a/src/MEDS_transforms/transforms/normalization.py +++ b/src/MEDS_transforms/transforms/normalization.py @@ -220,5 +220,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=normalize) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/occlude_outliers.py b/src/MEDS_transforms/transforms/occlude_outliers.py index 107407de..f9095e20 100644 --- a/src/MEDS_transforms/transforms/occlude_outliers.py +++ b/src/MEDS_transforms/transforms/occlude_outliers.py @@ -110,5 +110,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=occlude_outliers_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/reorder_measurements.py b/src/MEDS_transforms/transforms/reorder_measurements.py index 1205f771..32f2857d 100644 --- 
a/src/MEDS_transforms/transforms/reorder_measurements.py +++ b/src/MEDS_transforms/transforms/reorder_measurements.py @@ -184,5 +184,5 @@ def main(cfg: DictConfig): map_over(cfg, reorder_by_code_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/tensorization.py b/src/MEDS_transforms/transforms/tensorization.py index 0266ce21..bfadb7ad 100644 --- a/src/MEDS_transforms/transforms/tensorization.py +++ b/src/MEDS_transforms/transforms/tensorization.py @@ -115,5 +115,5 @@ def main(cfg: DictConfig): ) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index d6f5003f..8965cf1f 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -265,5 +265,5 @@ def main(cfg: DictConfig): logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() From a2de58e88499f6492c3992459ba939886e6f8767 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:13:51 -0400 Subject: [PATCH 08/62] Added an error case test for tokenization. --- src/MEDS_transforms/transforms/tokenization.py | 5 ++--- tests/test_tokenization.py | 9 +++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index 8965cf1f..09109822 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -229,11 +229,10 @@ def main(cfg: DictConfig): ) output_dir = Path(cfg.stage_cfg.output_dir) + if train_only := cfg.stage_cfg.get("train_only", False): + raise ValueError(f"train_only={train_only} is not supported for this stage.") shards_single_output, include_only_train = shard_iterator(cfg) - if include_only_train: - raise ValueError("Not supported for this stage.") - for in_fp, out_fp in shards_single_output: sharded_path = out_fp.relative_to(output_dir) diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py index 693add18..0945c7c7 100644 --- a/tests/test_tokenization.py +++ b/tests/test_tokenization.py @@ -225,3 +225,12 @@ def test_tokenization(): input_shards=NORMALIZED_SHARDS, want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, ) + + single_stage_transform_tester( + transform_script=TOKENIZATION_SCRIPT, + stage_name="tokenization", + transform_stage_kwargs={"train_only": True}, + input_shards=NORMALIZED_SHARDS, + want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, + should_error=True, + ) From 214d0e9ca5496ebf5758df8ad2705855554b20ba Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:25:17 -0400 Subject: [PATCH 09/62] Added tests to filter patients --- .../filters/filter_patients.py | 45 ++++++++++++++++--- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/src/MEDS_transforms/filters/filter_patients.py b/src/MEDS_transforms/filters/filter_patients.py index c5630b27..0682257c 100644 --- a/src/MEDS_transforms/filters/filter_patients.py +++ b/src/MEDS_transforms/filters/filter_patients.py @@ -13,7 +13,7 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_patient: int) -> pl.LazyFrame: - """Filters patients by the number of measurements they have. + """Filters patients by the number of dynamic (timestamp non-null) measurements they have. Args: df: The input DataFrame. 
@@ -24,11 +24,11 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 3], - ... "time": [1, 2, 1, 1, 2, 1], + ... "patient_id": [1, 1, 1, 2, 2, 3, 3, 4], + ... "time": [1, 2, 1, 1, 2, 1, None, None], ... }) >>> filter_patients_by_num_measurements(df, 1) - shape: (6, 2) + shape: (7, 2) ┌────────────┬──────┐ │ patient_id ┆ time │ │ --- ┆ --- │ @@ -40,6 +40,7 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 2 ┆ 1 │ │ 2 ┆ 2 │ │ 3 ┆ 1 │ + │ 3 ┆ null │ └────────────┴──────┘ >>> filter_patients_by_num_measurements(df, 2) shape: (5, 2) @@ -102,7 +103,8 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) ... "patient_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4], ... "time": [1, 1, 1, 1, 2, 1, 1, 2, 3, None, None, 1, 2, 3], ... }) - >>> filter_patients_by_num_events(df, 1) + >>> with pl.Config(tbl_rows=15): + ... filter_patients_by_num_events(df, 1) shape: (14, 2) ┌────────────┬──────┐ │ patient_id ┆ time │ @@ -124,7 +126,8 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 2) + >>> with pl.Config(tbl_rows=15): + ... filter_patients_by_num_events(df, 2) shape: (11, 2) ┌────────────┬──────┐ │ patient_id ┆ time │ @@ -195,6 +198,36 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) def filter_patients_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + """Returns a function that filters patients by the number of measurements and events they have. + + Args: + stage_cfg: The stage configuration. Arguments include: min_measurements_per_patient, + min_events_per_patient, both of which should be integers or None which specify the minimum number + of measurements and events a patient must have to be included, respectively. + + Returns: + The function that filters patients by the number of measurements and/or events they have. + + Examples: + >>> df = pl.DataFrame({ + ... "patient_id": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5], + ... "time": [1, 1, 1, 1, 1, 1, 2, 3, None, None, 1, 2, 2, None, 1, 2, 3, 1], + ... }) + >>> stage_cfg = DictConfig({"min_measurements_per_patient": 4, "min_events_per_patient": 2}) + >>> filter_patients_fntr(stage_cfg)(df) + shape: (4, 2) + ┌────────────┬──────┐ + │ patient_id ┆ time │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞════════════╪══════╡ + │ 5 ┆ 1 │ + │ 5 ┆ 2 │ + │ 5 ┆ 3 │ + │ 5 ┆ 1 │ + └────────────┴──────┘ + """ + compute_fns = [] if stage_cfg.min_measurements_per_patient: logger.info( From 7d58386a408b4cb04efde4b39dd8021851d1191d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:27:19 -0400 Subject: [PATCH 10/62] Added tests to filter measurements --- .../filters/filter_measurements.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index 9f301856..24529560 100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -97,6 +97,32 @@ def filter_measurements_fntr( ╞════════════╪══════╪═══════════╡ │ 2 ┆ A ┆ 2 │ └────────────┴──────┴───────────┘ + + This stage works even if the default row index column exists: + >>> code_metadata_df = pl.DataFrame({ + ... "code": ["A", "A", "B", "C"], + ... 
"modifier1": [1, 2, 1, 2], + ... "code/n_patients": [2, 1, 3, 2], + ... "code/n_occurrences": [4, 5, 3, 2], + ... }) + >>> data = pl.DataFrame({ + ... "patient_id": [1, 1, 2, 2], + ... "code": ["A", "B", "A", "C"], + ... "modifier1": [1, 1, 2, 2], + ... "_row_idx": [1, 1, 1, 1], + ... }).lazy() + >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3}) + >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (2, 4) + ┌────────────┬──────┬───────────┬──────────┐ + │ patient_id ┆ code ┆ modifier1 ┆ _row_idx │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ i64 ┆ i64 │ + ╞════════════╪══════╪═══════════╪══════════╡ + │ 1 ┆ A ┆ 1 ┆ 1 │ + │ 1 ┆ B ┆ 1 ┆ 1 │ + └────────────┴──────┴───────────┴──────────┘ """ min_patients_per_code = stage_cfg.get("min_patients_per_code", None) From fde006713cc845404a37294f04ca89d36f1ec64f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:57:53 -0400 Subject: [PATCH 11/62] Corrected typo in filter measurements tests. --- src/MEDS_transforms/filters/filter_measurements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index b0d9021f..979dc0ab 100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -102,21 +102,21 @@ def filter_measurements_fntr( >>> code_metadata_df = pl.DataFrame({ ... "code": ["A", "A", "B", "C"], ... "modifier1": [1, 2, 1, 2], - ... "code/n_patients": [2, 1, 3, 2], + ... "code/n_subjects": [2, 1, 3, 2], ... "code/n_occurrences": [4, 5, 3, 2], ... }) >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "A", "C"], ... "modifier1": [1, 1, 2, 2], ... "_row_idx": [1, 1, 1, 1], ... }).lazy() - >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 2, "min_occurrences_per_code": 3}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (2, 4) ┌────────────┬──────┬───────────┬──────────┐ - │ patient_id ┆ code ┆ modifier1 ┆ _row_idx │ + │ subject_id ┆ code ┆ modifier1 ┆ _row_idx │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 ┆ i64 │ ╞════════════╪══════╪═══════════╪══════════╡ From 6bfae894c1d46a1f65f6953a68c31c9bef97b0eb Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 15:07:17 -0400 Subject: [PATCH 12/62] Re-organized tests. 
--- tests/MEDS_Extract/__init__.py | 0 tests/{ => MEDS_Extract}/test_extract.py | 2 +- tests/{ => MEDS_Extract}/test_extract_no_metadata.py | 2 +- tests/MEDS_Transforms/__init__.py | 0 .../test_add_time_derived_measurements.py | 10 ++++++++-- .../test_aggregate_code_metadata.py | 5 ++++- tests/{ => MEDS_Transforms}/test_extract_values.py | 10 +++++++++- .../{ => MEDS_Transforms}/test_filter_measurements.py | 10 ++++++++-- tests/{ => MEDS_Transforms}/test_filter_subjects.py | 7 +++++-- .../test_fit_vocabulary_indices.py | 6 +++++- .../test_multi_stage_preprocess_pipeline.py | 6 +++++- tests/{ => MEDS_Transforms}/test_normalization.py | 7 +++++-- tests/{ => MEDS_Transforms}/test_occlude_outliers.py | 7 +++++-- .../{ => MEDS_Transforms}/test_reorder_measurements.py | 10 ++++++++-- tests/{ => MEDS_Transforms}/test_reshard_to_split.py | 7 +++++-- tests/{ => MEDS_Transforms}/test_tensorization.py | 7 +++++-- tests/{ => MEDS_Transforms}/test_tokenization.py | 0 tests/{ => MEDS_Transforms}/transform_tester_base.py | 2 +- 18 files changed, 75 insertions(+), 23 deletions(-) create mode 100644 tests/MEDS_Extract/__init__.py rename tests/{ => MEDS_Extract}/test_extract.py (99%) rename tests/{ => MEDS_Extract}/test_extract_no_metadata.py (99%) create mode 100644 tests/MEDS_Transforms/__init__.py rename tests/{ => MEDS_Transforms}/test_add_time_derived_measurements.py (96%) rename tests/{ => MEDS_Transforms}/test_aggregate_code_metadata.py (97%) rename tests/{ => MEDS_Transforms}/test_extract_values.py (94%) rename tests/{ => MEDS_Transforms}/test_filter_measurements.py (96%) rename tests/{ => MEDS_Transforms}/test_filter_subjects.py (91%) rename tests/{ => MEDS_Transforms}/test_fit_vocabulary_indices.py (91%) rename tests/{ => MEDS_Transforms}/test_multi_stage_preprocess_pipeline.py (99%) rename tests/{ => MEDS_Transforms}/test_normalization.py (96%) rename tests/{ => MEDS_Transforms}/test_occlude_outliers.py (95%) rename tests/{ => MEDS_Transforms}/test_reorder_measurements.py (92%) rename tests/{ => MEDS_Transforms}/test_reshard_to_split.py (96%) rename tests/{ => MEDS_Transforms}/test_tensorization.py (73%) rename tests/{ => MEDS_Transforms}/test_tokenization.py (100%) rename tests/{ => MEDS_Transforms}/transform_tester_base.py (99%) diff --git a/tests/MEDS_Extract/__init__.py b/tests/MEDS_Extract/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_extract.py b/tests/MEDS_Extract/test_extract.py similarity index 99% rename from tests/test_extract.py rename to tests/MEDS_Extract/test_extract.py index 787d5d83..b1b50a33 100644 --- a/tests/test_extract.py +++ b/tests/MEDS_Extract/test_extract.py @@ -38,7 +38,7 @@ import polars as pl from meds import __version__ as MEDS_VERSION -from .utils import assert_df_equal, run_command +from tests.utils import assert_df_equal, run_command # Test data (inputs) diff --git a/tests/test_extract_no_metadata.py b/tests/MEDS_Extract/test_extract_no_metadata.py similarity index 99% rename from tests/test_extract_no_metadata.py rename to tests/MEDS_Extract/test_extract_no_metadata.py index 2391a977..ed783f4f 100644 --- a/tests/test_extract_no_metadata.py +++ b/tests/MEDS_Extract/test_extract_no_metadata.py @@ -38,7 +38,7 @@ import polars as pl from meds import __version__ as MEDS_VERSION -from .utils import assert_df_equal, run_command +from tests.utils import assert_df_equal, run_command # Test data (inputs) diff --git a/tests/MEDS_Transforms/__init__.py b/tests/MEDS_Transforms/__init__.py new file mode 100644 index 00000000..e69de29b diff 
--git a/tests/test_add_time_derived_measurements.py b/tests/MEDS_Transforms/test_add_time_derived_measurements.py similarity index 96% rename from tests/test_add_time_derived_measurements.py rename to tests/MEDS_Transforms/test_add_time_derived_measurements.py index 964cad9c..ed7bbba9 100644 --- a/tests/test_add_time_derived_measurements.py +++ b/tests/MEDS_Transforms/test_add_time_derived_measurements.py @@ -3,11 +3,17 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from meds import subject_id_field -from .transform_tester_base import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from tests.MEDS_Transforms.transform_tester_base import ( + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, + single_stage_transform_tester, +) +from tests.utils import parse_meds_csvs AGE_CALCULATION_STR = """ See `add_time_derived_measurements.py` for the source of the constant value. diff --git a/tests/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py similarity index 97% rename from tests/test_aggregate_code_metadata.py rename to tests/MEDS_Transforms/test_aggregate_code_metadata.py index 7d3d2a4d..c62bb94b 100644 --- a/tests/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -3,10 +3,13 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import polars as pl -from .transform_tester_base import ( +from tests.MEDS_Transforms.transform_tester_base import ( AGGREGATE_CODE_METADATA_SCRIPT, MEDS_CODE_METADATA_SCHEMA, single_stage_transform_tester, diff --git a/tests/test_extract_values.py b/tests/MEDS_Transforms/test_extract_values.py similarity index 94% rename from tests/test_extract_values.py rename to tests/MEDS_Transforms/test_extract_values.py index 0368e5b7..d273c99b 100644 --- a/tests/test_extract_values.py +++ b/tests/MEDS_Transforms/test_extract_values.py @@ -4,7 +4,15 @@ scripts. """ -from .transform_tester_base import EXTRACT_VALUES_SCRIPT, parse_shards_yaml, single_stage_transform_tester +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +from tests.MEDS_Transforms.transform_tester_base import ( + EXTRACT_VALUES_SCRIPT, + parse_shards_yaml, + single_stage_transform_tester, +) INPUT_SHARDS = parse_shards_yaml( """ diff --git a/tests/test_filter_measurements.py b/tests/MEDS_Transforms/test_filter_measurements.py similarity index 96% rename from tests/test_filter_measurements.py rename to tests/MEDS_Transforms/test_filter_measurements.py index 3a34835f..a3e53d9a 100644 --- a/tests/test_filter_measurements.py +++ b/tests/MEDS_Transforms/test_filter_measurements.py @@ -3,10 +3,16 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. 
""" +import rootutils +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) -from .transform_tester_base import FILTER_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs + +from tests.MEDS_Transforms.transform_tester_base import ( + FILTER_MEASUREMENTS_SCRIPT, + single_stage_transform_tester, +) +from tests.utils import parse_meds_csvs # This is the code metadata # MEDS_CODE_METADATA_CSV = """ diff --git a/tests/test_filter_subjects.py b/tests/MEDS_Transforms/test_filter_subjects.py similarity index 91% rename from tests/test_filter_subjects.py rename to tests/MEDS_Transforms/test_filter_subjects.py index 1defee47..4d4f2ca1 100644 --- a/tests/test_filter_subjects.py +++ b/tests/MEDS_Transforms/test_filter_subjects.py @@ -4,10 +4,13 @@ scripts. """ +import rootutils from meds import subject_id_field -from .transform_tester_base import FILTER_SUBJECTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +from tests.MEDS_Transforms.transform_tester_base import FILTER_SUBJECTS_SCRIPT, single_stage_transform_tester +from tests.utils import parse_meds_csvs WANT_TRAIN_0 = f""" {subject_id_field},time,code,numeric_value diff --git a/tests/test_fit_vocabulary_indices.py b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py similarity index 91% rename from tests/test_fit_vocabulary_indices.py rename to tests/MEDS_Transforms/test_fit_vocabulary_indices.py index 607b41ee..ea6c1c5d 100644 --- a/tests/test_fit_vocabulary_indices.py +++ b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py @@ -4,8 +4,12 @@ scripts. """ +import rootutils -from .transform_tester_base import ( +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + + +from tests.MEDS_Transforms.transform_tester_base import ( FIT_VOCABULARY_INDICES_SCRIPT, parse_code_metadata_csv, single_stage_transform_tester, diff --git a/tests/test_multi_stage_preprocess_pipeline.py b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py similarity index 99% rename from tests/test_multi_stage_preprocess_pipeline.py rename to tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py index 32c4a277..15c4b96d 100644 --- a/tests/test_multi_stage_preprocess_pipeline.py +++ b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py @@ -17,13 +17,17 @@ The stage configuration arguments will be as given in the yaml block below: """ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + from datetime import datetime import polars as pl from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .transform_tester_base import ( +from tests.MEDS_Transforms.transform_tester_base import ( ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, AGGREGATE_CODE_METADATA_SCRIPT, FILTER_SUBJECTS_SCRIPT, diff --git a/tests/test_normalization.py b/tests/MEDS_Transforms/test_normalization.py similarity index 96% rename from tests/test_normalization.py rename to tests/MEDS_Transforms/test_normalization.py index 14207c4c..b6f386f2 100644 --- a/tests/test_normalization.py +++ b/tests/MEDS_Transforms/test_normalization.py @@ -5,9 +5,12 @@ """ import polars as pl +import rootutils -from .transform_tester_base import NORMALIZATION_SCRIPT, single_stage_transform_tester -from .utils import MEDS_PL_SCHEMA, parse_meds_csvs +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + 
+from tests.MEDS_Transforms.transform_tester_base import NORMALIZATION_SCRIPT, single_stage_transform_tester +from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata file we'll use in this transform test. It is different than the default as we need # a code/vocab_index diff --git a/tests/test_occlude_outliers.py b/tests/MEDS_Transforms/test_occlude_outliers.py similarity index 95% rename from tests/test_occlude_outliers.py rename to tests/MEDS_Transforms/test_occlude_outliers.py index f13a4fa0..ad3d3213 100644 --- a/tests/test_occlude_outliers.py +++ b/tests/MEDS_Transforms/test_occlude_outliers.py @@ -4,11 +4,14 @@ scripts. """ +import rootutils + +rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import polars as pl -from .transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, single_stage_transform_tester -from .utils import MEDS_PL_SCHEMA, parse_meds_csvs +from tests.MEDS_Transforms.transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, single_stage_transform_tester +from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata # MEDS_CODE_METADATA_CSV = """ diff --git a/tests/test_reorder_measurements.py b/tests/MEDS_Transforms/test_reorder_measurements.py similarity index 92% rename from tests/test_reorder_measurements.py rename to tests/MEDS_Transforms/test_reorder_measurements.py index 7cc7aaa3..c4a2a549 100644 --- a/tests/test_reorder_measurements.py +++ b/tests/MEDS_Transforms/test_reorder_measurements.py @@ -4,9 +4,15 @@ scripts. """ +import rootutils -from .transform_tester_base import REORDER_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +from tests.MEDS_Transforms.transform_tester_base import ( + REORDER_MEASUREMENTS_SCRIPT, + single_stage_transform_tester, +) +from tests.utils import parse_meds_csvs ORDERED_CODE_PATTERNS = [ "ADMISSION.*", diff --git a/tests/test_reshard_to_split.py b/tests/MEDS_Transforms/test_reshard_to_split.py similarity index 96% rename from tests/test_reshard_to_split.py rename to tests/MEDS_Transforms/test_reshard_to_split.py index b479e5a0..19008bca 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/MEDS_Transforms/test_reshard_to_split.py @@ -3,12 +3,15 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from meds import subject_id_field -from .transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from tests.MEDS_Transforms.transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester +from tests.utils import parse_meds_csvs IN_SHARDS_MAP = { "0": [68729, 1195293], diff --git a/tests/test_tensorization.py b/tests/MEDS_Transforms/test_tensorization.py similarity index 73% rename from tests/test_tensorization.py rename to tests/MEDS_Transforms/test_tensorization.py index 03371558..b648e6ea 100644 --- a/tests/test_tensorization.py +++ b/tests/MEDS_Transforms/test_tensorization.py @@ -6,11 +6,14 @@ scripts. 
""" +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .test_tokenization import WANT_EVENT_SEQS as TOKENIZED_SHARDS -from .transform_tester_base import TENSORIZATION_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms.test_tokenization import WANT_EVENT_SEQS as TOKENIZED_SHARDS +from tests.MEDS_Transforms.transform_tester_base import TENSORIZATION_SCRIPT, single_stage_transform_tester WANT_NRTS = { f'{k.replace("event_seqs/", "")}.nrt': JointNestedRaggedTensorDict( diff --git a/tests/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py similarity index 100% rename from tests/test_tokenization.py rename to tests/MEDS_Transforms/test_tokenization.py diff --git a/tests/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py similarity index 99% rename from tests/transform_tester_base.py rename to tests/MEDS_Transforms/transform_tester_base.py index 21fc2310..9b1184fd 100644 --- a/tests/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -25,7 +25,7 @@ from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command +from tests.utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) From 26665c9f686ec8e45b211dd1ed1f82ceba782e68 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 16:21:17 -0400 Subject: [PATCH 13/62] Moved single and multi-stage helpers out to be general in preparation for extraction test refactoring. 
--- .../test_aggregate_code_metadata.py | 1 + .../test_multi_stage_preprocess_pipeline.py | 1 - .../MEDS_Transforms/transform_tester_base.py | 328 ++++-------------- tests/utils.py | 264 ++++++++++++++ 4 files changed, 334 insertions(+), 260 deletions(-) diff --git a/tests/MEDS_Transforms/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py index c62bb94b..48ff79bb 100644 --- a/tests/MEDS_Transforms/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -186,4 +186,5 @@ def test_aggregate_code_metadata(): want_metadata=WANT_OUTPUT_CODE_METADATA_FILE, input_code_metadata=MEDS_CODE_METADATA_FILE, do_use_config_yaml=True, + assert_no_other_outputs=False, ) diff --git a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py index 15c4b96d..0deade2a 100644 --- a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py +++ b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py @@ -1089,6 +1089,5 @@ def test_pipeline(): **WANT_TOKENIZATION_EVENT_SEQS, **WANT_NRTs, }, - outputs_from_cohort_dir=True, input_code_metadata=MEDS_CODE_METADATA, ) diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 9b1184fd..1e692e45 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -11,21 +11,16 @@ except ImportError: from yaml import Loader -import json import os -import tempfile from collections import defaultdict -from contextlib import contextmanager from io import StringIO from pathlib import Path -import numpy as np import polars as pl import rootutils from meds import subject_id_field -from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from tests.utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command +from tests.utils import FILE_T, MEDS_PL_SCHEMA, multi_stage_tester, parse_meds_csvs, single_stage_tester root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) @@ -223,177 +218,45 @@ def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame: MEDS_CODE_METADATA = parse_code_metadata_csv(MEDS_CODE_METADATA_CSV) -def check_NRT_output( - output_fp: Path, - want_nrt: JointNestedRaggedTensorDict, -): - assert output_fp.is_file(), f"Expected {output_fp} to exist." 
- - got_nrt = JointNestedRaggedTensorDict.load(output_fp) - - # assert got_nrt.schema == want_nrt.schema, ( - # f"Expected the schema of the NRT at {output_fp} to be equal to the target.\n" - # f"Wanted:\n{want_nrt.schema}\n" - # f"Got:\n{got_nrt.schema}" - # ) - - want_tensors = want_nrt.tensors - got_tensors = got_nrt.tensors - - assert got_tensors.keys() == want_tensors.keys(), ( - f"Expected the keys of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{list(want_tensors.keys())}\n" - f"Got:\n{list(got_tensors.keys())}" - ) - - for k in want_tensors.keys(): - want_v = want_tensors[k] - got_v = got_tensors[k] - - assert type(want_v) is type(got_v), ( - f"Expected tensor {k} of the NRT at {output_fp} to be of the same type as the target.\n" - f"Wanted:\n{type(want_v)}\n" - f"Got:\n{type(got_v)}" - ) - - if isinstance(want_v, list): - assert len(want_v) == len(got_v), ( - f"Expected list {k} of the NRT at {output_fp} to be of the same length as the target.\n" - f"Wanted:\n{len(want_v)}\n" - f"Got:\n{len(got_v)}" - ) - for i, (want_i, got_i) in enumerate(zip(want_v, got_v)): - assert np.array_equal(want_i, got_i, equal_nan=True), ( - f"Expected tensor {k}[{i}] of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{want_i}\n" - f"Got:\n{got_i}" - ) - else: - assert np.array_equal(want_v, got_v, equal_nan=True), ( - f"Expected tensor {k} of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{want_v}\n" - f"Got:\n{got_v}" - ) - - -def check_df_output( - output_fp: Path, - want_df: pl.DataFrame, - check_column_order: bool = False, - check_row_order: bool = True, - **kwargs, -): - assert output_fp.is_file(), f"Expected {output_fp} to exist." - - got_df = pl.read_parquet(output_fp, glob=False) - assert_df_equal( - want_df, - got_df, - (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), - check_column_order=check_column_order, - check_row_order=check_row_order, - **kwargs, - ) - - -@contextmanager -def input_MEDS_dataset( +def remap_inputs_for_transform( input_code_metadata: pl.DataFrame | str | None = None, input_shards: dict[str, pl.DataFrame] | None = None, input_shards_map: dict[str, list[int]] | None = None, input_splits_map: dict[str, list[int]] | None = None, -): - with tempfile.TemporaryDirectory() as d: - MEDS_dir = Path(d) / "MEDS_cohort" - cohort_dir = Path(d) / "output_cohort" - - MEDS_data_dir = MEDS_dir / "data" - MEDS_metadata_dir = MEDS_dir / "metadata" - - # Create the directories - MEDS_data_dir.mkdir(parents=True) - MEDS_metadata_dir.mkdir(parents=True) - cohort_dir.mkdir(parents=True) - - # Write the shards map - if input_shards_map is None: - input_shards_map = SHARDS - - shards_fp = MEDS_metadata_dir / ".shards.json" - shards_fp.write_text(json.dumps(input_shards_map)) - - # Write the splits parquet file - if input_splits_map is None: - input_splits_map = SPLITS - input_splits_as_df = defaultdict(list) - for split_name, subject_ids in input_splits_map.items(): - input_splits_as_df[subject_id_field].extend(subject_ids) - input_splits_as_df["split"].extend([split_name] * len(subject_ids)) - input_splits_df = pl.DataFrame(input_splits_as_df) - input_splits_fp = MEDS_metadata_dir / "subject_splits.parquet" - input_splits_df.write_parquet(input_splits_fp, use_pyarrow=True) - - if input_shards is None: - input_shards = MEDS_SHARDS - - # Write the shards - for shard_name, df in input_shards.items(): - fp = MEDS_data_dir / f"{shard_name}.parquet" - fp.parent.mkdir(parents=True, exist_ok=True) - df.write_parquet(fp, 
use_pyarrow=True) - - code_metadata_fp = MEDS_metadata_dir / "codes.parquet" - if input_code_metadata is None: - input_code_metadata = MEDS_CODE_METADATA - elif isinstance(input_code_metadata, str): - input_code_metadata = parse_code_metadata_csv(input_code_metadata) - input_code_metadata.write_parquet(code_metadata_fp, use_pyarrow=True) - - yield MEDS_dir, cohort_dir - - -def check_outputs( - cohort_dir: Path, - want_data: dict[str, pl.DataFrame] | None = None, - want_metadata: dict[str, pl.DataFrame] | pl.DataFrame | None = None, - assert_no_other_outputs: bool = True, - outputs_from_cohort_dir: bool = False, -): - if want_metadata is not None: - if isinstance(want_metadata, pl.DataFrame): - want_metadata = {"codes.parquet": want_metadata} - metadata_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "metadata" - for shard_name, want in want_metadata.items(): - if Path(shard_name).suffix == "": - shard_name = f"{shard_name}.parquet" - check_df_output(metadata_root / shard_name, want) +) -> dict[str, FILE_T]: + unified_inputs = {} - if want_data: - data_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "data" - all_file_suffixes = set() - for shard_name, want in want_data.items(): - if Path(shard_name).suffix == "": - shard_name = f"{shard_name}.parquet" - - file_suffix = Path(shard_name).suffix - all_file_suffixes.add(file_suffix) - - output_fp = data_root / f"{shard_name}" - if file_suffix == ".parquet": - check_df_output(output_fp, want) - elif file_suffix == ".nrt": - check_NRT_output(output_fp, want) - else: - raise ValueError(f"Unknown file suffix: {file_suffix}") - - if assert_no_other_outputs: - all_outputs = [] - for suffix in all_file_suffixes: - all_outputs.extend(list((data_root).glob(f"**/*{suffix}"))) - assert len(want_data) == len(all_outputs), ( - f"Want {len(want_data)} outputs, but found {len(all_outputs)}.\n" - f"Found outputs: {[fp.relative_to(data_root) for fp in all_outputs]}\n" - ) + if input_code_metadata is None: + input_code_metadata = MEDS_CODE_METADATA + elif isinstance(input_code_metadata, str): + input_code_metadata = parse_code_metadata_csv(input_code_metadata) + + unified_inputs["metadata/codes.parquet"] = input_code_metadata + + if input_shards is None: + input_shards = MEDS_SHARDS + + for shard_name, df in input_shards.items(): + unified_inputs[f"data/{shard_name}.parquet"] = df + + if input_shards_map is None: + input_shards_map = SHARDS + + unified_inputs["metadata/.shards.json"] = input_shards_map + + if input_splits_map is None: + input_splits_map = SPLITS + + input_splits_as_df = defaultdict(list) + for split_name, subject_ids in input_splits_map.items(): + input_splits_as_df[subject_id_field].extend(subject_ids) + input_splits_as_df["split"].extend([split_name] * len(subject_ids)) + + input_splits_df = pl.DataFrame(input_splits_as_df) + + unified_inputs["metadata/subject_splits.parquet"] = input_splits_df + + return unified_inputs def single_stage_transform_tester( @@ -408,43 +271,28 @@ def single_stage_transform_tester( should_error: bool = False, **input_data_kwargs, ): - with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): - pipeline_config_kwargs = { - "input_dir": str(MEDS_dir.resolve()), - "cohort_dir": str(cohort_dir.resolve()), - "stages": [stage_name], - "hydra.verbose": True, - } - - if transform_stage_kwargs: - pipeline_config_kwargs["stage_configs"] = {stage_name: transform_stage_kwargs} - - run_command_kwargs = { - "script": transform_script, - "hydra_kwargs": pipeline_config_kwargs, - 
"test_name": f"Single stage transform: {stage_name}", - "should_error": should_error, - } - if do_use_config_yaml: - run_command_kwargs["do_use_config_yaml"] = True - run_command_kwargs["config_name"] = "preprocess" - if do_pass_stage_name: - run_command_kwargs["stage"] = stage_name - run_command_kwargs["do_pass_stage_name"] = True - - # Run the transform - stderr, stdout = run_command(**run_command_kwargs) - if should_error: - return - - try: - check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata) - except Exception as e: - raise AssertionError( - f"Single stage transform {stage_name} failed.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}" - ) from e + base_kwargs = { + "script": transform_script, + "stage_name": stage_name, + "stage_kwargs": transform_stage_kwargs, + "do_pass_stage_name": do_pass_stage_name, + "do_use_config_yaml": do_use_config_yaml, + "assert_no_other_outputs": assert_no_other_outputs, + "should_error": should_error, + "config_name": "preprocess", + "input_files": remap_inputs_for_transform(**input_data_kwargs), + } + + want_outputs = {} + if want_data: + for data_fn, want in want_data.items(): + want_outputs[f"data/{data_fn}"] = want + if want_metadata is not None: + want_outputs["metadata/codes.parquet"] = want_metadata + + base_kwargs["want_outputs"] = want_outputs + + single_stage_tester(**base_kwargs) def multi_stage_transform_tester( @@ -454,55 +302,17 @@ def multi_stage_transform_tester( do_pass_stage_name: bool | dict[str, bool] = True, want_data: dict[str, pl.DataFrame] | None = None, want_metadata: pl.DataFrame | None = None, - outputs_from_cohort_dir: bool = True, **input_data_kwargs, ): - with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): - match stage_configs: - case None: - stage_configs = {} - case str(): - stage_configs = load_yaml(stage_configs, Loader=Loader) - case dict(): - pass - case _: - raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}") - - match do_pass_stage_name: - case True: - do_pass_stage_name = {stage_name: True for stage_name in stage_names} - case False: - do_pass_stage_name = {stage_name: False for stage_name in stage_names} - case dict(): - pass - case _: - raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}") - - pipeline_config_kwargs = { - "input_dir": str(MEDS_dir.resolve()), - "cohort_dir": str(cohort_dir.resolve()), - "stages": stage_names, - "stage_configs": stage_configs, - "hydra.verbose": True, - } - - script_outputs = {} - n_stages = len(stage_names) - for i, (stage, script) in enumerate(zip(stage_names, transform_scripts)): - script_outputs[stage] = run_command( - script=script, - hydra_kwargs=pipeline_config_kwargs, - do_use_config_yaml=True, - config_name="preprocess", - test_name=f"Multi stage transform {i}/{n_stages}: {stage}", - stage_name=stage, - do_pass_stage_name=do_pass_stage_name[stage], - ) - - check_outputs( - cohort_dir, - want_data=want_data, - want_metadata=want_metadata, - outputs_from_cohort_dir=outputs_from_cohort_dir, - assert_no_other_outputs=False, # this currently doesn't work due to metadata / data confusions. 
- ) + base_kwargs = { + "scripts": transform_scripts, + "stage_names": stage_names, + "stage_configs": stage_configs, + "do_pass_stage_name": do_pass_stage_name, + "assert_no_other_outputs": False, # TODO(mmd): eventually fix + "config_name": "preprocess", + "input_files": remap_inputs_for_transform(**input_data_kwargs), + "want_outputs": {**want_data, **want_metadata}, + } + + multi_stage_tester(**base_kwargs) diff --git a/tests/utils.py b/tests/utils.py index 7cb7c180..c2585133 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,11 +1,22 @@ +import json import subprocess import tempfile +from contextlib import contextmanager from io import StringIO from pathlib import Path +from typing import Any +import numpy as np import polars as pl +from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict from omegaconf import OmegaConf from polars.testing import assert_frame_equal +from yaml import load as load_yaml + +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader DEFAULT_CSV_TS_FORMAT = "%m/%d/%Y, %H:%M:%S" @@ -192,3 +203,256 @@ def assert_df_equal(want: pl.DataFrame, got: pl.DataFrame, msg: str = None, **kw print("got:") print(got) raise AssertionError(f"{msg}\n{e}") from e + + +def check_NRT_output( + output_fp: Path, + want_nrt: JointNestedRaggedTensorDict, +): + assert output_fp.is_file(), f"Expected {output_fp} to exist." + + got_nrt = JointNestedRaggedTensorDict.load(output_fp) + + # assert got_nrt.schema == want_nrt.schema, ( + # f"Expected the schema of the NRT at {output_fp} to be equal to the target.\n" + # f"Wanted:\n{want_nrt.schema}\n" + # f"Got:\n{got_nrt.schema}" + # ) + + want_tensors = want_nrt.tensors + got_tensors = got_nrt.tensors + + assert got_tensors.keys() == want_tensors.keys(), ( + f"Expected the keys of the NRT at {output_fp} to be equal to the target.\n" + f"Wanted:\n{list(want_tensors.keys())}\n" + f"Got:\n{list(got_tensors.keys())}" + ) + + for k in want_tensors.keys(): + want_v = want_tensors[k] + got_v = got_tensors[k] + + assert type(want_v) is type(got_v), ( + f"Expected tensor {k} of the NRT at {output_fp} to be of the same type as the target.\n" + f"Wanted:\n{type(want_v)}\n" + f"Got:\n{type(got_v)}" + ) + + if isinstance(want_v, list): + assert len(want_v) == len(got_v), ( + f"Expected list {k} of the NRT at {output_fp} to be of the same length as the target.\n" + f"Wanted:\n{len(want_v)}\n" + f"Got:\n{len(got_v)}" + ) + for i, (want_i, got_i) in enumerate(zip(want_v, got_v)): + assert np.array_equal(want_i, got_i, equal_nan=True), ( + f"Expected tensor {k}[{i}] of the NRT at {output_fp} to be equal to the target.\n" + f"Wanted:\n{want_i}\n" + f"Got:\n{got_i}" + ) + else: + assert np.array_equal(want_v, got_v, equal_nan=True), ( + f"Expected tensor {k} of the NRT at {output_fp} to be equal to the target.\n" + f"Wanted:\n{want_v}\n" + f"Got:\n{got_v}" + ) + + +def check_df_output( + output_fp: Path, + want_df: pl.DataFrame, + check_column_order: bool = False, + check_row_order: bool = True, + **kwargs, +): + assert output_fp.is_file(), f"Expected {output_fp} to exist." 
+ + got_df = pl.read_parquet(output_fp, glob=False) + assert_df_equal( + want_df, + got_df, + (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), + check_column_order=check_column_order, + check_row_order=check_row_order, + **kwargs, + ) + + +FILE_T = pl.DataFrame | dict[str, Any] + + +@contextmanager +def input_dataset(input_files: dict[str, FILE_T] | None = None): + with tempfile.TemporaryDirectory() as d: + input_dir = Path(d) / "input_cohort" + cohort_dir = Path(d) / "output_cohort" + + for filename, data in input_files.items(): + fp = input_dir / filename + fp.parent.mkdir(parents=True, exist_ok=True) + + match data: + case pl.DataFrame() if fp.suffix == "": + data.write_parquet(fp.with_suffix(".parquet"), use_pyarrow=True) + case pl.DataFrame() if fp.suffix == ".parquet": + data.write_parquet(fp, use_pyarrow=True) + case dict() if fp.suffix == "": + fp.with_suffix(".json").write_text(json.dumps(data)) + case dict() if fp.suffix.endswith(".json"): + fp.write_text(json.dumps(data)) + case _: + raise ValueError(f"Unknown data type {type(data)} for file {fp.relative_to(input_dir)}") + + yield input_dir, cohort_dir + + +def check_outputs( + cohort_dir: Path, + want_outputs: dict[str, pl.DataFrame], + assert_no_other_outputs: bool = True, +): + all_file_suffixes = set() + + for output_name, want in want_outputs.items(): + if Path(output_name).suffix == "": + output_name = f"{output_name}.parquet" + + file_suffix = Path(output_name).suffix + all_file_suffixes.add(file_suffix) + + output_fp = cohort_dir / output_name + + if not output_fp.is_file(): + raise AssertionError(f"Expected {output_fp} to exist.") + + match file_suffix: + case ".parquet": + check_df_output(output_fp, want) + case ".nrt": + check_NRT_output(output_fp, want) + case _: + raise ValueError(f"Unknown file suffix: {file_suffix}") + + if assert_no_other_outputs: + all_outputs = [] + for suffix in all_file_suffixes: + all_outputs.extend(list(cohort_dir.glob(f"**/*{suffix}"))) + assert len(want_outputs) == len(all_outputs), ( + f"Want {len(want_outputs)} outputs, but found {len(all_outputs)}.\n" + f"Found outputs: {[fp.relative_to(cohort_dir) for fp in all_outputs]}\n" + ) + + +def single_stage_tester( + script: str | Path, + stage_name: str, + stage_kwargs: dict[str, str] | None, + do_pass_stage_name: bool = False, + do_use_config_yaml: bool = False, + want_outputs: dict[str, pl.DataFrame] | None = None, + assert_no_other_outputs: bool = True, + should_error: bool = False, + config_name: str = "preprocess", + input_files: dict[str, FILE_T] | None = None, +): + with input_dataset(input_files) as (input_dir, cohort_dir): + pipeline_config_kwargs = { + "input_dir": str(input_dir.resolve()), + "cohort_dir": str(cohort_dir.resolve()), + "stages": [stage_name], + "hydra.verbose": True, + } + + if stage_kwargs: + pipeline_config_kwargs["stage_configs"] = {stage_name: stage_kwargs} + + run_command_kwargs = { + "script": script, + "hydra_kwargs": pipeline_config_kwargs, + "test_name": f"Single stage transform: {stage_name}", + "should_error": should_error, + "config_name": config_name, + } + if do_use_config_yaml: + run_command_kwargs["do_use_config_yaml"] = True + + if do_pass_stage_name: + run_command_kwargs["stage"] = stage_name + run_command_kwargs["do_pass_stage_name"] = True + + # Run the transform + stderr, stdout = run_command(**run_command_kwargs) + if should_error: + return + + try: + check_outputs( + cohort_dir, want_outputs=want_outputs, assert_no_other_outputs=assert_no_other_outputs + ) + except 
Exception as e: + raise AssertionError( + f"Single stage transform {stage_name} failed.\n" + f"Script stdout:\n{stdout}\n" + f"Script stderr:\n{stderr}" + ) from e + + +def multi_stage_tester( + scripts: list[str | Path], + stage_names: list[str], + stage_configs: dict[str, str] | str | None, + do_pass_stage_name: bool | dict[str, bool] = True, + want_outputs: dict[str, pl.DataFrame] | None = None, + assert_no_other_outputs: bool = False, + config_name: str = "preprocess", + input_files: dict[str, FILE_T] | None = None, + **pipeline_kwargs, +): + with input_dataset(input_files) as (input_dir, cohort_dir): + match stage_configs: + case None: + stage_configs = {} + case str(): + stage_configs = load_yaml(stage_configs, Loader=Loader) + case dict(): + pass + case _: + raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}") + + match do_pass_stage_name: + case True: + do_pass_stage_name = {stage_name: True for stage_name in stage_names} + case False: + do_pass_stage_name = {stage_name: False for stage_name in stage_names} + case dict(): + pass + case _: + raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}") + + pipeline_config_kwargs = { + "input_dir": str(input_dir.resolve()), + "cohort_dir": str(cohort_dir.resolve()), + "stages": stage_names, + "stage_configs": stage_configs, + "hydra.verbose": True, + **pipeline_kwargs, + } + + script_outputs = {} + n_stages = len(stage_names) + for i, (stage, script) in enumerate(zip(stage_names, scripts)): + script_outputs[stage] = run_command( + script=script, + hydra_kwargs=pipeline_config_kwargs, + do_use_config_yaml=True, + config_name=config_name, + test_name=f"Multi stage transform {i}/{n_stages}: {stage}", + stage_name=stage, + do_pass_stage_name=do_pass_stage_name[stage], + ) + + check_outputs( + cohort_dir, + want_outputs=want_outputs, + assert_no_other_outputs=assert_no_other_outputs, + ) From e111ef64ace283a2f07af69a6c7ca8d27a289581 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 16:39:15 -0400 Subject: [PATCH 14/62] Re-organized imports further to fix some issues. 
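Concretely, this centralizes the "which script do I run" logic that the re-organized tests were still duplicating. The DO_USE_LOCAL_SCRIPTS switch and the per-stage script constants now live in the test packages themselves: tests/MEDS_Extract/__init__.py (shown in full below) holds the extraction entry points, and the matching 46-line addition to tests/MEDS_Transforms/__init__.py, paired with the removal from transform_tester_base.py, presumably moves the transform-side constants the same way (only the diffstat for that file is visible here). Individual test modules then import the constants from their package, as in the test_extract.py hunk below; for example (taken from that hunk, trimmed):

    from tests.MEDS_Extract import SHARD_EVENTS_SCRIPT, SPLIT_AND_SHARD_SCRIPT

Setting the environment variable DO_USE_LOCAL_SCRIPTS=1 still switches these constants from the installed MEDS_extract-* console scripts to the local .py files under src/MEDS_transforms/extract/.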
--- tests/MEDS_Extract/__init__.py | 24 ++++++++++ tests/MEDS_Extract/test_extract.py | 35 ++++---------- .../MEDS_Extract/test_extract_no_metadata.py | 35 ++++---------- tests/MEDS_Transforms/__init__.py | 46 +++++++++++++++++++ .../test_add_time_derived_measurements.py | 9 +--- .../test_aggregate_code_metadata.py | 5 +- tests/MEDS_Transforms/test_extract_values.py | 11 +---- .../test_filter_measurements.py | 10 +--- tests/MEDS_Transforms/test_filter_subjects.py | 6 +-- .../test_fit_vocabulary_indices.py | 11 +---- .../test_multi_stage_preprocess_pipeline.py | 8 +--- tests/MEDS_Transforms/test_normalization.py | 6 +-- .../MEDS_Transforms/test_occlude_outliers.py | 6 +-- .../test_reorder_measurements.py | 9 +--- .../MEDS_Transforms/test_reshard_to_split.py | 6 +-- tests/MEDS_Transforms/test_tensorization.py | 6 +-- tests/MEDS_Transforms/test_tokenization.py | 4 +- .../MEDS_Transforms/transform_tester_base.py | 45 ------------------ 18 files changed, 114 insertions(+), 168 deletions(-) diff --git a/tests/MEDS_Extract/__init__.py b/tests/MEDS_Extract/__init__.py index e69de29b..14ddbce2 100644 --- a/tests/MEDS_Extract/__init__.py +++ b/tests/MEDS_Extract/__init__.py @@ -0,0 +1,24 @@ +import os + +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +extraction_root = root / "src" / "MEDS_transforms" / "extract" + +if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": + SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" + SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" + CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" + MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" + EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" + FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" + FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" +else: + SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" + SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" + CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" + MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" + EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" + FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" + FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" diff --git a/tests/MEDS_Extract/test_extract.py b/tests/MEDS_Extract/test_extract.py index b1b50a33..be6b2461 100644 --- a/tests/MEDS_Extract/test_extract.py +++ b/tests/MEDS_Extract/test_extract.py @@ -4,32 +4,6 @@ scripts. 
""" -import os - -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -extraction_root = code_root / "extract" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" - MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" - EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" - FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" - FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" -else: - SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" - MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" - EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" - FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" - FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" - import json import tempfile from io import StringIO @@ -38,6 +12,15 @@ import polars as pl from meds import __version__ as MEDS_VERSION +from tests.MEDS_Extract import ( + CONVERT_TO_SHARDED_EVENTS_SCRIPT, + EXTRACT_CODE_METADATA_SCRIPT, + FINALIZE_DATA_SCRIPT, + FINALIZE_METADATA_SCRIPT, + MERGE_TO_MEDS_COHORT_SCRIPT, + SHARD_EVENTS_SCRIPT, + SPLIT_AND_SHARD_SCRIPT, +) from tests.utils import assert_df_equal, run_command # Test data (inputs) diff --git a/tests/MEDS_Extract/test_extract_no_metadata.py b/tests/MEDS_Extract/test_extract_no_metadata.py index ed783f4f..0fa8eec8 100644 --- a/tests/MEDS_Extract/test_extract_no_metadata.py +++ b/tests/MEDS_Extract/test_extract_no_metadata.py @@ -4,32 +4,6 @@ scripts. 
""" -import os - -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -extraction_root = code_root / "extract" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" - MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" - EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" - FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" - FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" -else: - SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" - MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" - EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" - FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" - FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" - import json import tempfile from io import StringIO @@ -38,6 +12,15 @@ import polars as pl from meds import __version__ as MEDS_VERSION +from tests.MEDS_Extract import ( + CONVERT_TO_SHARDED_EVENTS_SCRIPT, + EXTRACT_CODE_METADATA_SCRIPT, + FINALIZE_DATA_SCRIPT, + FINALIZE_METADATA_SCRIPT, + MERGE_TO_MEDS_COHORT_SCRIPT, + SHARD_EVENTS_SCRIPT, + SPLIT_AND_SHARD_SCRIPT, +) from tests.utils import assert_df_equal, run_command # Test data (inputs) diff --git a/tests/MEDS_Transforms/__init__.py b/tests/MEDS_Transforms/__init__.py index e69de29b..a2d3d56f 100644 --- a/tests/MEDS_Transforms/__init__.py +++ b/tests/MEDS_Transforms/__init__.py @@ -0,0 +1,46 @@ +import os + +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +code_root = root / "src" / "MEDS_transforms" +transforms_root = code_root / "transforms" +filters_root = code_root / "filters" + +if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": + # Root Source + AGGREGATE_CODE_METADATA_SCRIPT = code_root / "aggregate_code_metadata.py" + FIT_VOCABULARY_INDICES_SCRIPT = code_root / "fit_vocabulary_indices.py" + RESHARD_TO_SPLIT_SCRIPT = code_root / "reshard_to_split.py" + + # Filters + FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" + FILTER_SUBJECTS_SCRIPT = filters_root / "filter_subjects.py" + + # Transforms + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py" + REORDER_MEASUREMENTS_SCRIPT = transforms_root / "reorder_measurements.py" + EXTRACT_VALUES_SCRIPT = transforms_root / "extract_values.py" + NORMALIZATION_SCRIPT = transforms_root / "normalization.py" + OCCLUDE_OUTLIERS_SCRIPT = transforms_root / "occlude_outliers.py" + TENSORIZATION_SCRIPT = transforms_root / "tensorization.py" + TOKENIZATION_SCRIPT = transforms_root / "tokenization.py" +else: + # Root Source + AGGREGATE_CODE_METADATA_SCRIPT = "MEDS_transform-aggregate_code_metadata" + FIT_VOCABULARY_INDICES_SCRIPT = "MEDS_transform-fit_vocabulary_indices" + RESHARD_TO_SPLIT_SCRIPT = "MEDS_transform-reshard_to_split" + + # Filters + FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" + FILTER_SUBJECTS_SCRIPT = "MEDS_transform-filter_subjects" + + # Transforms + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = 
"MEDS_transform-add_time_derived_measurements" + REORDER_MEASUREMENTS_SCRIPT = "MEDS_transform-reorder_measurements" + EXTRACT_VALUES_SCRIPT = "MEDS_transform-extract_values" + NORMALIZATION_SCRIPT = "MEDS_transform-normalization" + OCCLUDE_OUTLIERS_SCRIPT = "MEDS_transform-occlude_outliers" + TENSORIZATION_SCRIPT = "MEDS_transform-tensorization" + TOKENIZATION_SCRIPT = "MEDS_transform-tokenization" diff --git a/tests/MEDS_Transforms/test_add_time_derived_measurements.py b/tests/MEDS_Transforms/test_add_time_derived_measurements.py index ed7bbba9..ff2131de 100644 --- a/tests/MEDS_Transforms/test_add_time_derived_measurements.py +++ b/tests/MEDS_Transforms/test_add_time_derived_measurements.py @@ -3,16 +3,11 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from meds import subject_id_field -from tests.MEDS_Transforms.transform_tester_base import ( - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs AGE_CALCULATION_STR = """ diff --git a/tests/MEDS_Transforms/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py index 48ff79bb..acf00995 100644 --- a/tests/MEDS_Transforms/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -3,14 +3,11 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import polars as pl +from tests.MEDS_Transforms import AGGREGATE_CODE_METADATA_SCRIPT from tests.MEDS_Transforms.transform_tester_base import ( - AGGREGATE_CODE_METADATA_SCRIPT, MEDS_CODE_METADATA_SCHEMA, single_stage_transform_tester, ) diff --git a/tests/MEDS_Transforms/test_extract_values.py b/tests/MEDS_Transforms/test_extract_values.py index d273c99b..4114b3ba 100644 --- a/tests/MEDS_Transforms/test_extract_values.py +++ b/tests/MEDS_Transforms/test_extract_values.py @@ -4,15 +4,8 @@ scripts. """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -from tests.MEDS_Transforms.transform_tester_base import ( - EXTRACT_VALUES_SCRIPT, - parse_shards_yaml, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import EXTRACT_VALUES_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import parse_shards_yaml, single_stage_transform_tester INPUT_SHARDS = parse_shards_yaml( """ diff --git a/tests/MEDS_Transforms/test_filter_measurements.py b/tests/MEDS_Transforms/test_filter_measurements.py index a3e53d9a..9991a265 100644 --- a/tests/MEDS_Transforms/test_filter_measurements.py +++ b/tests/MEDS_Transforms/test_filter_measurements.py @@ -3,15 +3,9 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. 
""" -import rootutils -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - - -from tests.MEDS_Transforms.transform_tester_base import ( - FILTER_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import FILTER_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs # This is the code metadata diff --git a/tests/MEDS_Transforms/test_filter_subjects.py b/tests/MEDS_Transforms/test_filter_subjects.py index 4d4f2ca1..83f40689 100644 --- a/tests/MEDS_Transforms/test_filter_subjects.py +++ b/tests/MEDS_Transforms/test_filter_subjects.py @@ -4,12 +4,10 @@ scripts. """ -import rootutils from meds import subject_id_field -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -from tests.MEDS_Transforms.transform_tester_base import FILTER_SUBJECTS_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import FILTER_SUBJECTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs WANT_TRAIN_0 = f""" diff --git a/tests/MEDS_Transforms/test_fit_vocabulary_indices.py b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py index ea6c1c5d..78f637a5 100644 --- a/tests/MEDS_Transforms/test_fit_vocabulary_indices.py +++ b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py @@ -4,16 +4,9 @@ scripts. """ -import rootutils -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - - -from tests.MEDS_Transforms.transform_tester_base import ( - FIT_VOCABULARY_INDICES_SCRIPT, - parse_code_metadata_csv, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import FIT_VOCABULARY_INDICES_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import parse_code_metadata_csv, single_stage_transform_tester WANT_CSV = """ code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes,code/vocab_index diff --git a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py index 0deade2a..6667313f 100644 --- a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py +++ b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py @@ -17,9 +17,6 @@ The stage configuration arguments will be as given in the yaml block below: """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from datetime import datetime @@ -27,7 +24,7 @@ from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from tests.MEDS_Transforms.transform_tester_base import ( +from tests.MEDS_Transforms import ( ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, AGGREGATE_CODE_METADATA_SCRIPT, FILTER_SUBJECTS_SCRIPT, @@ -36,9 +33,8 @@ OCCLUDE_OUTLIERS_SCRIPT, TENSORIZATION_SCRIPT, TOKENIZATION_SCRIPT, - multi_stage_transform_tester, - parse_shards_yaml, ) +from tests.MEDS_Transforms.transform_tester_base import multi_stage_transform_tester, parse_shards_yaml MEDS_CODE_METADATA = pl.DataFrame( { diff --git a/tests/MEDS_Transforms/test_normalization.py b/tests/MEDS_Transforms/test_normalization.py index b6f386f2..4cc21ae6 100644 --- a/tests/MEDS_Transforms/test_normalization.py +++ b/tests/MEDS_Transforms/test_normalization.py @@ -5,11 +5,9 @@ """ import polars as pl -import rootutils -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) 
- -from tests.MEDS_Transforms.transform_tester_base import NORMALIZATION_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import NORMALIZATION_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata file we'll use in this transform test. It is different than the default as we need diff --git a/tests/MEDS_Transforms/test_occlude_outliers.py b/tests/MEDS_Transforms/test_occlude_outliers.py index ad3d3213..8e30db67 100644 --- a/tests/MEDS_Transforms/test_occlude_outliers.py +++ b/tests/MEDS_Transforms/test_occlude_outliers.py @@ -4,13 +4,11 @@ scripts. """ -import rootutils - -rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import polars as pl -from tests.MEDS_Transforms.transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import OCCLUDE_OUTLIERS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata diff --git a/tests/MEDS_Transforms/test_reorder_measurements.py b/tests/MEDS_Transforms/test_reorder_measurements.py index c4a2a549..782c3945 100644 --- a/tests/MEDS_Transforms/test_reorder_measurements.py +++ b/tests/MEDS_Transforms/test_reorder_measurements.py @@ -4,14 +4,9 @@ scripts. """ -import rootutils -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -from tests.MEDS_Transforms.transform_tester_base import ( - REORDER_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import REORDER_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs ORDERED_CODE_PATTERNS = [ diff --git a/tests/MEDS_Transforms/test_reshard_to_split.py b/tests/MEDS_Transforms/test_reshard_to_split.py index 19008bca..d0094a96 100644 --- a/tests/MEDS_Transforms/test_reshard_to_split.py +++ b/tests/MEDS_Transforms/test_reshard_to_split.py @@ -3,14 +3,12 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from meds import subject_id_field -from tests.MEDS_Transforms.transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import RESHARD_TO_SPLIT_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs IN_SHARDS_MAP = { diff --git a/tests/MEDS_Transforms/test_tensorization.py b/tests/MEDS_Transforms/test_tensorization.py index b648e6ea..f56d064e 100644 --- a/tests/MEDS_Transforms/test_tensorization.py +++ b/tests/MEDS_Transforms/test_tensorization.py @@ -6,14 +6,12 @@ scripts. 
""" -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict +from tests.MEDS_Transforms import TENSORIZATION_SCRIPT from tests.MEDS_Transforms.test_tokenization import WANT_EVENT_SEQS as TOKENIZED_SHARDS -from tests.MEDS_Transforms.transform_tester_base import TENSORIZATION_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester WANT_NRTS = { f'{k.replace("event_seqs/", "")}.nrt': JointNestedRaggedTensorDict( diff --git a/tests/MEDS_Transforms/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py index 811b1c97..d12cdb0e 100644 --- a/tests/MEDS_Transforms/test_tokenization.py +++ b/tests/MEDS_Transforms/test_tokenization.py @@ -10,9 +10,11 @@ import polars as pl +from tests.MEDS_Transforms import TOKENIZATION_SCRIPT + from .test_normalization import NORMALIZED_MEDS_SCHEMA from .test_normalization import WANT_SHARDS as NORMALIZED_SHARDS -from .transform_tester_base import TOKENIZATION_SCRIPT, single_stage_transform_tester +from .transform_tester_base import single_stage_transform_tester SECONDS_PER_DAY = 60 * 60 * 24 diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 1e692e45..90b2ae73 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -11,60 +11,15 @@ except ImportError: from yaml import Loader -import os from collections import defaultdict from io import StringIO from pathlib import Path import polars as pl -import rootutils from meds import subject_id_field from tests.utils import FILE_T, MEDS_PL_SCHEMA, multi_stage_tester, parse_meds_csvs, single_stage_tester -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -transforms_root = code_root / "transforms" -filters_root = code_root / "filters" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - # Root Source - AGGREGATE_CODE_METADATA_SCRIPT = code_root / "aggregate_code_metadata.py" - FIT_VOCABULARY_INDICES_SCRIPT = code_root / "fit_vocabulary_indices.py" - RESHARD_TO_SPLIT_SCRIPT = code_root / "reshard_to_split.py" - - # Filters - FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" - FILTER_SUBJECTS_SCRIPT = filters_root / "filter_subjects.py" - - # Transforms - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py" - REORDER_MEASUREMENTS_SCRIPT = transforms_root / "reorder_measurements.py" - EXTRACT_VALUES_SCRIPT = transforms_root / "extract_values.py" - NORMALIZATION_SCRIPT = transforms_root / "normalization.py" - OCCLUDE_OUTLIERS_SCRIPT = transforms_root / "occlude_outliers.py" - TENSORIZATION_SCRIPT = transforms_root / "tensorization.py" - TOKENIZATION_SCRIPT = transforms_root / "tokenization.py" -else: - # Root Source - AGGREGATE_CODE_METADATA_SCRIPT = "MEDS_transform-aggregate_code_metadata" - FIT_VOCABULARY_INDICES_SCRIPT = "MEDS_transform-fit_vocabulary_indices" - RESHARD_TO_SPLIT_SCRIPT = "MEDS_transform-reshard_to_split" - - # Filters - FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" - FILTER_SUBJECTS_SCRIPT = "MEDS_transform-filter_subjects" - - # Transforms - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = "MEDS_transform-add_time_derived_measurements" - REORDER_MEASUREMENTS_SCRIPT = "MEDS_transform-reorder_measurements" - EXTRACT_VALUES_SCRIPT = 
"MEDS_transform-extract_values" - NORMALIZATION_SCRIPT = "MEDS_transform-normalization" - OCCLUDE_OUTLIERS_SCRIPT = "MEDS_transform-occlude_outliers" - TENSORIZATION_SCRIPT = "MEDS_transform-tensorization" - TOKENIZATION_SCRIPT = "MEDS_transform-tokenization" - # Test MEDS data (inputs) SHARDS = { From 7675292b7e2971fe595a6520c34972d12d6ddf8e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 16:56:36 -0400 Subject: [PATCH 15/62] Added a shard_events test. --- tests/MEDS_Extract/test_shard_events.py | 113 ++++++++++++++++++++++++ tests/utils.py | 15 +++- 2 files changed, 125 insertions(+), 3 deletions(-) create mode 100644 tests/MEDS_Extract/test_shard_events.py diff --git a/tests/MEDS_Extract/test_shard_events.py b/tests/MEDS_Extract/test_shard_events.py new file mode 100644 index 00000000..c6f0d0fa --- /dev/null +++ b/tests/MEDS_Extract/test_shard_events.py @@ -0,0 +1,113 @@ +"""Tests the shard events stage in isolation. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import SHARD_EVENTS_SCRIPT +from tests.utils import single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - 
ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + + +def test_extraction(): + single_stage_tester( + script=SHARD_EVENTS_SCRIPT, + stage_name="shard_events", + stage_kwargs={"row_chunksize": 10}, + want_outputs={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], + }, + config_name="extract", + input_files={ + "subjects.csv": SUBJECTS_CSV, + "admit_vitals.csv": ADMIT_VITALS_CSV, + "admit_vitals.parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + ) diff --git a/tests/utils.py b/tests/utils.py index c2585133..4a00eb45 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -278,7 +278,7 @@ def check_df_output( ) -FILE_T = pl.DataFrame | dict[str, Any] +FILE_T = pl.DataFrame | dict[str, Any] | str @contextmanager @@ -296,10 +296,14 @@ def input_dataset(input_files: dict[str, FILE_T] | None = None): data.write_parquet(fp.with_suffix(".parquet"), use_pyarrow=True) case pl.DataFrame() if fp.suffix == ".parquet": data.write_parquet(fp, use_pyarrow=True) + case pl.DataFrame() if fp.suffix == ".csv": + data.write_csv(fp) case dict() if fp.suffix == "": fp.with_suffix(".json").write_text(json.dumps(data)) case dict() if fp.suffix.endswith(".json"): fp.write_text(json.dumps(data)) + case str(): + fp.write_text(data.strip()) case _: raise ValueError(f"Unknown data type {type(data)} for file {fp.relative_to(input_dir)}") @@ -354,13 +358,19 @@ def single_stage_tester( should_error: bool = False, config_name: str = "preprocess", input_files: dict[str, FILE_T] | None = None, + **pipeline_kwargs, ): with input_dataset(input_files) as (input_dir, cohort_dir): + for k, v in pipeline_kwargs.items(): + if type(v) is str and "{input_dir}" in v: + pipeline_kwargs[k] = v.format(input_dir=str(input_dir.resolve())) + pipeline_config_kwargs = { "input_dir": str(input_dir.resolve()), "cohort_dir": str(cohort_dir.resolve()), "stages": [stage_name], "hydra.verbose": True, + **pipeline_kwargs, } if stage_kwargs: @@ -372,9 +382,8 @@ def single_stage_tester( "test_name": f"Single stage transform: {stage_name}", "should_error": should_error, "config_name": config_name, + "do_use_config_yaml": do_use_config_yaml, } - if do_use_config_yaml: - run_command_kwargs["do_use_config_yaml"] = True if do_pass_stage_name: run_command_kwargs["stage"] = stage_name From 32939b212cf2d8e76161d68aef3a3b74789525b0 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 17:08:18 -0400 Subject: [PATCH 16/62] Added a split_and_shard test. 
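The new `tests/MEDS_Extract/test_split_and_shard_subjects.py` runs the stage over pre-sharded raw
tables and compares the `metadata/.shards.json` it writes against a hard-coded split map; to support
this, `check_outputs` in `tests/utils.py` gains a `.json` comparison branch. A rough sketch of that
comparison, for illustration only (the real check lives inside `check_outputs`):

```python
# Minimal sketch, assuming it mirrors the ".json" branch added to check_outputs below:
# the emitted shard map is loaded and compared directly against the expected
# split -> subject-ID mapping.
import json
from pathlib import Path


def check_shards_json(output_fp: Path, want: dict[str, list[int]]) -> None:
    got = json.loads(output_fp.read_text())
    assert got == want, f"Expected JSON at {output_fp} to equal the target.\nWanted:\n{want}\nGot:\n{got}"
```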
--- tests/MEDS_Extract/test_shard_events.py | 14 +- .../test_split_and_shard_subjects.py | 134 ++++++++++++++++++ tests/utils.py | 8 ++ 3 files changed, 149 insertions(+), 7 deletions(-) create mode 100644 tests/MEDS_Extract/test_split_and_shard_subjects.py diff --git a/tests/MEDS_Extract/test_shard_events.py b/tests/MEDS_Extract/test_shard_events.py index c6f0d0fa..88cd60d1 100644 --- a/tests/MEDS_Extract/test_shard_events.py +++ b/tests/MEDS_Extract/test_shard_events.py @@ -92,16 +92,11 @@ """ -def test_extraction(): +def test_shard_events(): single_stage_tester( script=SHARD_EVENTS_SCRIPT, stage_name="shard_events", stage_kwargs={"row_chunksize": 10}, - want_outputs={ - "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), - "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], - "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], - }, config_name="extract", input_files={ "subjects.csv": SUBJECTS_CSV, @@ -109,5 +104,10 @@ def test_extraction(): "admit_vitals.parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV)), "event_cfgs.yaml": EVENT_CFGS_YAML, }, - event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + want_outputs={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], + }, ) diff --git a/tests/MEDS_Extract/test_split_and_shard_subjects.py b/tests/MEDS_Extract/test_split_and_shard_subjects.py new file mode 100644 index 00000000..0a216d2c --- /dev/null +++ b/tests/MEDS_Extract/test_split_and_shard_subjects.py @@ -0,0 +1,134 @@ +"""Tests the full end-to-end extraction process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import SPLIT_AND_SHARD_SCRIPT +from tests.utils import single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_0_10_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +""" + +ADMIT_VITALS_10_16_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +INPUTS = { + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), +} + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +# Test data (expected outputs) -- ALL OF THIS MAY CHANGE 
IF THE SEED OR DATA CHANGES +EXPECTED_SPLITS = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +SUBJECT_SPLITS_DF = pl.DataFrame( + { + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "split": ["train", "train", "train", "train", "tuning", "held_out"], + } +) + + +def test_split_and_shard(): + single_stage_tester( + script=SPLIT_AND_SHARD_SCRIPT, + stage_name="split_and_shard_subjects", + stage_kwargs={ + "split_fracs.train": 4 / 6, + "split_fracs.tuning": 1 / 6, + "split_fracs.held_out": 1 / 6, + "n_subjects_per_shard": 2, + }, + config_name="extract", + input_files={**INPUTS, "event_cfgs.yaml": EVENT_CFGS_YAML}, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, + ) diff --git a/tests/utils.py b/tests/utils.py index 4a00eb45..ea64c1c5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -334,6 +334,14 @@ def check_outputs( check_df_output(output_fp, want) case ".nrt": check_NRT_output(output_fp, want) + case ".json": + with open(output_fp) as f: + got = json.load(f) + assert got == want, ( + f"Expected JSON at {output_fp} to be equal to the target.\n" + f"Wanted:\n{want}\n" + f"Got:\n{got}" + ) case _: raise ValueError(f"Unknown file suffix: {file_suffix}") From 1961478707b28517b7b6189a6f50e1dafb8d6d0d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 18:35:06 -0400 Subject: [PATCH 17/62] Added a split_and_shard test. --- .../MEDS_Extract/test_split_and_shard_subjects.py | 13 ++++++------- tests/MEDS_Transforms/transform_tester_base.py | 15 ++++++--------- tests/utils.py | 5 +++++ 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/MEDS_Extract/test_split_and_shard_subjects.py b/tests/MEDS_Extract/test_split_and_shard_subjects.py index 0a216d2c..b57cdd1e 100644 --- a/tests/MEDS_Extract/test_split_and_shard_subjects.py +++ b/tests/MEDS_Extract/test_split_and_shard_subjects.py @@ -45,12 +45,6 @@ 1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 """ -INPUTS = { - "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), - "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), - "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), -} - EVENT_CFGS_YAML = """ subjects: subject_id_col: MRN @@ -128,7 +122,12 @@ def test_split_and_shard(): "n_subjects_per_shard": 2, }, config_name="extract", - input_files={**INPUTS, "event_cfgs.yaml": EVENT_CFGS_YAML}, + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + }, event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, ) diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 90b2ae73..2599f11d 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -4,12 +4,11 @@ scripts. 
""" -from yaml import load as load_yaml try: - from yaml import CLoader as Loader + pass except ImportError: - from yaml import Loader + pass from collections import defaultdict from io import StringIO @@ -18,7 +17,10 @@ import polars as pl from meds import subject_id_field -from tests.utils import FILE_T, MEDS_PL_SCHEMA, multi_stage_tester, parse_meds_csvs, single_stage_tester +from tests.utils import FILE_T, multi_stage_tester, parse_meds_csvs, parse_shards_yaml, single_stage_tester + +# So it can be imported from here +parse_shards_yaml = parse_shards_yaml # Test MEDS data (inputs) @@ -156,11 +158,6 @@ } -def parse_shards_yaml(yaml_str: str, **schema_updates) -> pl.DataFrame: - schema = {**MEDS_PL_SCHEMA, **schema_updates} - return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema) - - def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame: cols = csv_str.strip().split("\n")[0].split(",") schema = {col: dt for col, dt in MEDS_CODE_METADATA_SCHEMA.items() if col in cols} diff --git a/tests/utils.py b/tests/utils.py index ea64c1c5..bf48137e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -55,6 +55,11 @@ def reader(csv_str: str) -> pl.DataFrame: return {k: reader(v) for k, v in csvs.items()} +def parse_shards_yaml(yaml_str: str, **schema_updates) -> pl.DataFrame: + schema = {**MEDS_PL_SCHEMA, **schema_updates} + return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema) + + def dict_to_hydra_kwargs(d: dict[str, str]) -> str: """Converts a dictionary to a hydra kwargs string for testing purposes. From 57280f04cc8af184ccdf5eb0a2226b6b457ee032 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 19:08:33 -0400 Subject: [PATCH 18/62] Made tests specific to row order and column order where appropriate. 
--- tests/MEDS_Extract/test_shard_events.py | 1 + .../test_aggregate_code_metadata.py | 1 + tests/MEDS_Transforms/test_tokenization.py | 1 + .../MEDS_Transforms/transform_tester_base.py | 5 ++ tests/utils.py | 85 +++++++++---------- 5 files changed, 46 insertions(+), 47 deletions(-) diff --git a/tests/MEDS_Extract/test_shard_events.py b/tests/MEDS_Extract/test_shard_events.py index 88cd60d1..f19746ec 100644 --- a/tests/MEDS_Extract/test_shard_events.py +++ b/tests/MEDS_Extract/test_shard_events.py @@ -110,4 +110,5 @@ def test_shard_events(): "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], }, + df_check_kwargs={"check_column_order": False}, ) diff --git a/tests/MEDS_Transforms/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py index acf00995..a2abce52 100644 --- a/tests/MEDS_Transforms/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -184,4 +184,5 @@ def test_aggregate_code_metadata(): input_code_metadata=MEDS_CODE_METADATA_FILE, do_use_config_yaml=True, assert_no_other_outputs=False, + df_check_kwargs={"check_column_order": False}, ) diff --git a/tests/MEDS_Transforms/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py index d12cdb0e..470b4250 100644 --- a/tests/MEDS_Transforms/test_tokenization.py +++ b/tests/MEDS_Transforms/test_tokenization.py @@ -226,6 +226,7 @@ def test_tokenization(): transform_stage_kwargs=None, input_shards=NORMALIZED_SHARDS, want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, + df_check_kwargs={"check_column_order": False}, ) single_stage_transform_tester( diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 2599f11d..6a1d4ab8 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -221,8 +221,12 @@ def single_stage_transform_tester( want_metadata: pl.DataFrame | None = None, assert_no_other_outputs: bool = True, should_error: bool = False, + df_check_kwargs: dict | None = None, **input_data_kwargs, ): + if df_check_kwargs is None: + df_check_kwargs = {} + base_kwargs = { "script": transform_script, "stage_name": stage_name, @@ -233,6 +237,7 @@ def single_stage_transform_tester( "should_error": should_error, "config_name": "preprocess", "input_files": remap_inputs_for_transform(**input_data_kwargs), + "df_check_kwargs": df_check_kwargs, } want_outputs = {} diff --git a/tests/utils.py b/tests/utils.py index bf48137e..f029e898 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -203,19 +203,14 @@ def assert_df_equal(want: pl.DataFrame, got: pl.DataFrame, msg: str = None, **kw assert_frame_equal(want, got, **kwargs) except AssertionError as e: pl.Config.set_tbl_rows(-1) - print(f"DFs are not equal: {msg}\nwant:") - print(want) - print("got:") - print(got) - raise AssertionError(f"{msg}\n{e}") from e + raise AssertionError(f"{msg}:\nWant:\n{want}\nGot:\n{got}\n{e}") from e def check_NRT_output( output_fp: Path, want_nrt: JointNestedRaggedTensorDict, + msg: str, ): - assert output_fp.is_file(), f"Expected {output_fp} to exist." 
- got_nrt = JointNestedRaggedTensorDict.load(output_fp) # assert got_nrt.schema == want_nrt.schema, ( @@ -228,20 +223,16 @@ def check_NRT_output( got_tensors = got_nrt.tensors assert got_tensors.keys() == want_tensors.keys(), ( - f"Expected the keys of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{list(want_tensors.keys())}\n" - f"Got:\n{list(got_tensors.keys())}" + f"{msg}:\n" f"Wanted:\n{list(want_tensors.keys())}\n" f"Got:\n{list(got_tensors.keys())}" ) for k in want_tensors.keys(): want_v = want_tensors[k] got_v = got_tensors[k] - assert type(want_v) is type(got_v), ( - f"Expected tensor {k} of the NRT at {output_fp} to be of the same type as the target.\n" - f"Wanted:\n{type(want_v)}\n" - f"Got:\n{type(got_v)}" - ) + assert type(want_v) is type( + got_v + ), f"{msg}: Wanted {k} to be of type {type(want_v)}, got {type(got_v)}." if isinstance(want_v, list): assert len(want_v) == len(got_v), ( @@ -263,26 +254,6 @@ def check_NRT_output( ) -def check_df_output( - output_fp: Path, - want_df: pl.DataFrame, - check_column_order: bool = False, - check_row_order: bool = True, - **kwargs, -): - assert output_fp.is_file(), f"Expected {output_fp} to exist." - - got_df = pl.read_parquet(output_fp, glob=False) - assert_df_equal( - want_df, - got_df, - (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), - check_column_order=check_column_order, - check_row_order=check_row_order, - **kwargs, - ) - - FILE_T = pl.DataFrame | dict[str, Any] | str @@ -319,6 +290,7 @@ def check_outputs( cohort_dir: Path, want_outputs: dict[str, pl.DataFrame], assert_no_other_outputs: bool = True, + **df_check_kwargs, ): all_file_suffixes = set() @@ -331,19 +303,27 @@ def check_outputs( output_fp = cohort_dir / output_name + files_found = [str(fp.relative_to(cohort_dir)) for fp in cohort_dir.glob("**/*{file_suffix}")] + if not output_fp.is_file(): - raise AssertionError(f"Expected {output_fp} to exist.") + raise AssertionError( + f"Wanted {output_fp.relative_to(cohort_dir)} to exist. 
" + f"{len(files_found)} {file_suffix} files found: {', '.join(files_found)}" + ) + + msg = f"Expected {output_fp.relative_to(cohort_dir)} to be equal to the target" match file_suffix: case ".parquet": - check_df_output(output_fp, want) + got_df = pl.read_parquet(output_fp, glob=False) + assert_df_equal(want, got_df, msg=msg, **df_check_kwargs) case ".nrt": - check_NRT_output(output_fp, want) + check_NRT_output(output_fp, want, msg=msg) case ".json": with open(output_fp) as f: got = json.load(f) assert got == want, ( - f"Expected JSON at {output_fp} to be equal to the target.\n" + f"Expected JSON at {output_fp.relative_to(cohort_dir)} to be equal to the target.\n" f"Wanted:\n{want}\n" f"Got:\n{got}" ) @@ -371,8 +351,12 @@ def single_stage_tester( should_error: bool = False, config_name: str = "preprocess", input_files: dict[str, FILE_T] | None = None, + df_check_kwargs: dict | None = None, **pipeline_kwargs, ): + if df_check_kwargs is None: + df_check_kwargs = {} + with input_dataset(input_files) as (input_dir, cohort_dir): for k, v in pipeline_kwargs.items(): if type(v) is str and "{input_dir}" in v: @@ -409,13 +393,16 @@ def single_stage_tester( try: check_outputs( - cohort_dir, want_outputs=want_outputs, assert_no_other_outputs=assert_no_other_outputs + cohort_dir, + want_outputs=want_outputs, + assert_no_other_outputs=assert_no_other_outputs, + **df_check_kwargs, ) except Exception as e: raise AssertionError( - f"Single stage transform {stage_name} failed.\n" + f"Single stage transform {stage_name} failed -- {e}:\n" f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}" + f"Script stderr:\n{stderr}\n" ) from e @@ -473,8 +460,12 @@ def multi_stage_tester( do_pass_stage_name=do_pass_stage_name[stage], ) - check_outputs( - cohort_dir, - want_outputs=want_outputs, - assert_no_other_outputs=assert_no_other_outputs, - ) + try: + check_outputs( + cohort_dir, + want_outputs=want_outputs, + assert_no_other_outputs=assert_no_other_outputs, + check_column_order=False, + ) + except Exception as e: + raise AssertionError(f"{n_stages}-stage pipeline ({stage_names}) failed--{e}") from e From 3d0ad006700905992a39bdec895482287ec9a49b Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 19:10:19 -0400 Subject: [PATCH 19/62] Added an event conversion test. --- .../test_convert_to_sharded_events.py | 229 ++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 tests/MEDS_Extract/test_convert_to_sharded_events.py diff --git a/tests/MEDS_Extract/test_convert_to_sharded_events.py b/tests/MEDS_Extract/test_convert_to_sharded_events.py new file mode 100644 index 00000000..9eae09ff --- /dev/null +++ b/tests/MEDS_Extract/test_convert_to_sharded_events.py @@ -0,0 +1,229 @@ +"""Tests the convert to sharded events process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import CONVERT_TO_SHARDED_EVENTS_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_0_10_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +""" + +ADMIT_VITALS_10_16_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +WANT_OUTPUTS = parse_shards_yaml( + """ + data/train/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 
239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + + data/train/1/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + + data/tuning/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + + data/held_out/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + + data/train/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/1/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/train/1/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + + data/tuning/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/tuning/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + + data/held_out/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + + data/held_out/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 
15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=CONVERT_TO_SHARDED_EVENTS_SCRIPT, + stage_name="convert_to_sharded_events", + stage_kwargs=None, + config_name="extract", + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": False}, + ) From ab31a070a7215d647878b2c36ea1962d3e8c5cec Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 19:49:36 -0400 Subject: [PATCH 20/62] Removing outdated comments. --- tests/MEDS_Extract/test_convert_to_sharded_events.py | 2 +- tests/MEDS_Extract/test_split_and_shard_subjects.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/MEDS_Extract/test_convert_to_sharded_events.py b/tests/MEDS_Extract/test_convert_to_sharded_events.py index 9eae09ff..074e897d 100644 --- a/tests/MEDS_Extract/test_convert_to_sharded_events.py +++ b/tests/MEDS_Extract/test_convert_to_sharded_events.py @@ -222,7 +222,7 @@ def test_convert_to_sharded_events(): "event_cfgs.yaml": EVENT_CFGS_YAML, "metadata/.shards.json": SHARDS_JSON, }, - event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", shards_map_fp="{input_dir}/metadata/.shards.json", want_outputs=WANT_OUTPUTS, df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": False}, diff --git a/tests/MEDS_Extract/test_split_and_shard_subjects.py b/tests/MEDS_Extract/test_split_and_shard_subjects.py index b57cdd1e..db74896d 100644 --- a/tests/MEDS_Extract/test_split_and_shard_subjects.py +++ b/tests/MEDS_Extract/test_split_and_shard_subjects.py @@ -128,6 +128,6 @@ def test_split_and_shard(): "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), "event_cfgs.yaml": EVENT_CFGS_YAML, }, - event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, ) From a812425467957a706607c63894f002e1c4c099b3 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:01:48 -0400 Subject: [PATCH 21/62] Adding merge test. --- .../MEDS_Extract/test_merge_to_MEDS_cohort.py | 268 ++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 tests/MEDS_Extract/test_merge_to_MEDS_cohort.py diff --git a/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py new file mode 100644 index 00000000..b9a21ffb --- /dev/null +++ b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py @@ -0,0 +1,268 @@ +"""Tests the merge to MEDS events process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from tests.MEDS_Extract import MERGE_TO_MEDS_COHORT_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + + data/train/1/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + + data/tuning/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + + data/held_out/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + + data/train/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/1/admit_vitals/[0-10): |-2 + 
subject_id,time,code,numeric_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/train/1/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + + data/tuning/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/tuning/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + + data/held_out/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + + data/held_out/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + +WANT_OUTPUTS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 
00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=MERGE_TO_MEDS_COHORT_SCRIPT, + stage_name="merge_to_MEDS_cohort", + stage_kwargs=None, + config_name="extract", + input_files={ + **INPUT_SHARDS, + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_column_order": False}, + ) From 5393d83bbfd382aa4ac363beb62a2905befa8588 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:10:28 -0400 Subject: [PATCH 22/62] Added data finalization test. --- tests/MEDS_Extract/test_finalize_MEDS_data.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/MEDS_Extract/test_finalize_MEDS_data.py diff --git a/tests/MEDS_Extract/test_finalize_MEDS_data.py b/tests/MEDS_Extract/test_finalize_MEDS_data.py new file mode 100644 index 00000000..d9a3e0ad --- /dev/null +++ b/tests/MEDS_Extract/test_finalize_MEDS_data.py @@ -0,0 +1,111 @@ +"""Tests the finalize MEDS data process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +import polars as pl + +from tests.MEDS_Extract import FINALIZE_DATA_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + +WANT_OUTPUTS = { + k: v.with_columns( + pl.col("subject_id").cast(pl.Int64), + pl.col("time").cast(pl.Datetime("us")), + pl.col("code").cast(pl.String), + pl.col("numeric_value").cast(pl.Float32), + ) + for k, v in INPUT_SHARDS.items() +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=FINALIZE_DATA_SCRIPT, + stage_name="finalize_MEDS_data", + stage_kwargs=None, + config_name="extract", + input_files=INPUT_SHARDS, + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_column_order": True, "check_dtypes": True, "check_row_order": True}, + ) From 
886817883136ece9e7e466efecc9cc4ccba1be0a Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:25:19 -0400 Subject: [PATCH 23/62] Added the ability to handle list columns to df checker. --- tests/MEDS_Extract/test_extract.py | 5 ----- tests/utils.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/MEDS_Extract/test_extract.py b/tests/MEDS_Extract/test_extract.py index be6b2461..954ca21e 100644 --- a/tests/MEDS_Extract/test_extract.py +++ b/tests/MEDS_Extract/test_extract.py @@ -525,14 +525,9 @@ def test_extraction(): got_df = pl.read_parquet(output_file, glob=False) want_df = pl.read_csv(source=StringIO(MEDS_OUTPUT_CODE_METADATA_FILE)).with_columns( - pl.col("code"), pl.col("parent_codes").cast(pl.List(pl.Utf8)), ) - # We collapse the list type as it throws an error in the assert_df_equal otherwise - got_df = got_df.with_columns(pl.col("parent_codes").list.join("||")) - want_df = want_df.with_columns(pl.col("parent_codes").list.join("||")) - assert_df_equal( want=want_df, got=got_df, diff --git a/tests/utils.py b/tests/utils.py index f029e898..450a86b4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -200,6 +200,20 @@ def run_command( def assert_df_equal(want: pl.DataFrame, got: pl.DataFrame, msg: str = None, **kwargs): try: + update_exprs = {} + for k, v in want.schema.items(): + assert k in got.schema, f"missing column {k}." + if kwargs.get("check_dtypes", False): + assert v == got.schema[k], f"column {k} has different types." + if v == pl.List(pl.String) and got.schema[k] == pl.List(pl.String): + update_exprs[k] = pl.col(k).list.join("||") + if update_exprs: + want_cols = want.columns + got_cols = got.columns + + want = want.with_columns(**update_exprs).select(want_cols) + got = got.with_columns(**update_exprs).select(got_cols) + assert_frame_equal(want, got, **kwargs) except AssertionError as e: pl.Config.set_tbl_rows(-1) From 3b77247cd7e54347b73a367dafcb1d1b276f414e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:26:12 -0400 Subject: [PATCH 24/62] Added a metadata extractor test. --- .../test_extract_code_metadata.py | 208 ++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 tests/MEDS_Extract/test_extract_code_metadata.py diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py new file mode 100644 index 00000000..af307ce0 --- /dev/null +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -0,0 +1,208 @@ +"""Tests the extract code metadata process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + + +import polars as pl + +from tests.MEDS_Extract import EXTRACT_CODE_METADATA_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +INPUT_METADATA_FILE = """ +lab_code,title,loinc +HR,Heart Rate,8867-4 +temp,Body Temperature,8310-5 +""" + +DEMO_METADATA_FILE = """ +eye_color,description +BROWN,"Brown Eyes. The most common eye color." +BLUE,"Blue Eyes. Less common than brown." +HAZEL,"Hazel eyes. These are uncommon" +GREEN,"Green eyes. These are rare." 
+""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +MEDS_OUTPUT_CODE_METADATA_FILE = """ +code,description,parent_codes +EYE_COLOR//BLUE,"Blue Eyes. Less common than brown.", +EYE_COLOR//BROWN,"Brown Eyes. The most common eye color.", +EYE_COLOR//HAZEL,"Hazel eyes. These are uncommon", +HR,"Heart Rate",LOINC/8867-4 +TEMP,"Body Temperature",LOINC/8310-5 +""" + +WANT_OUTPUTS = { + "metadata/codes": pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } + ), +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=EXTRACT_CODE_METADATA_SCRIPT, + stage_name="extract_code_metadata", + stage_kwargs=None, + config_name="extract", + input_files={ + **INPUT_SHARDS, + "demo_metadata.csv": DEMO_METADATA_FILE, + "input_metadata.csv": INPUT_METADATA_FILE, + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": True}, + assert_no_other_outputs=False, + ) From 04fc97af94189f88d8fbf974b01fadb0f2a5bea7 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:36:08 -0400 Subject: [PATCH 25/62] Removed unused constant in test. --- tests/MEDS_Extract/test_extract_code_metadata.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py index af307ce0..74a3d02e 100644 --- a/tests/MEDS_Extract/test_extract_code_metadata.py +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -161,15 +161,6 @@ "held_out/0": [1500733], } -MEDS_OUTPUT_CODE_METADATA_FILE = """ -code,description,parent_codes -EYE_COLOR//BLUE,"Blue Eyes. Less common than brown.", -EYE_COLOR//BROWN,"Brown Eyes. The most common eye color.", -EYE_COLOR//HAZEL,"Hazel eyes. 
These are uncommon", -HR,"Heart Rate",LOINC/8867-4 -TEMP,"Body Temperature",LOINC/8310-5 -""" - WANT_OUTPUTS = { "metadata/codes": pl.DataFrame( { From 2d848091ea49f8ede703c52a4dce069646db8a5f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:46:21 -0400 Subject: [PATCH 26/62] Added finalize MEDS metadata test. --- .../test_finalize_MEDS_metadata.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 tests/MEDS_Extract/test_finalize_MEDS_metadata.py diff --git a/tests/MEDS_Extract/test_finalize_MEDS_metadata.py b/tests/MEDS_Extract/test_finalize_MEDS_metadata.py new file mode 100644 index 00000000..274997f4 --- /dev/null +++ b/tests/MEDS_Extract/test_finalize_MEDS_metadata.py @@ -0,0 +1,91 @@ +"""Tests the finalize MEDS metadata process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + + +import polars as pl +from meds import __version__ as MEDS_VERSION + +from MEDS_transforms.utils import get_package_version as get_meds_transform_version +from tests.MEDS_Extract import FINALIZE_METADATA_SCRIPT +from tests.utils import single_stage_tester + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +WANT_OUTPUTS = { + "metadata/codes": pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } + ), +} + +METADATA_DF = pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. 
These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } +) + +WANT_OUTPUTS = { + "metadata/codes": ( + METADATA_DF.with_columns( + pl.col("code").cast(pl.String), + pl.col("description").cast(pl.String), + pl.col("parent_codes").cast(pl.List(pl.String)), + ).select(["code", "description", "parent_codes"]) + ), + "metadata/subject_splits": pl.DataFrame( + { + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "split": ["train", "train", "train", "train", "tuning", "held_out"], + } + ), + "metadata/dataset.json": { + "dataset_name": "TEST", + "dataset_version": "1.0", + "etl_name": "MEDS_transforms", + "etl_version": get_meds_transform_version(), + "meds_version": MEDS_VERSION, + }, +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=FINALIZE_METADATA_SCRIPT, + stage_name="finalize_MEDS_metadata", + stage_kwargs=None, + config_name="extract", + input_files={ + "metadata/codes": METADATA_DF, + "metadata/.shards.json": SHARDS_JSON, + }, + **{"etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0"}, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": True, "check_dtypes": True}, + ) From 9e3d7829392a82b73f0c23a6bf4723ebb41b0c21 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 12:01:12 -0400 Subject: [PATCH 27/62] Added more complex metadata extraction test that replicates part of #156 --- .../test_extract_code_metadata.py | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py index 74a3d02e..a493c207 100644 --- a/tests/MEDS_Extract/test_extract_code_metadata.py +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -18,30 +18,30 @@ 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",HR//bpm,102.6 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",HR//bpm,105.1 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",HR//bpm,113.4 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",HR//bpm,112.6 239684,"05/11/2010, 18:57:18",TEMP,95.5 239684,"05/11/2010, 19:27:19",DISCHARGE, 1195293,,EYE_COLOR//BLUE, 1195293,,HEIGHT,164.6868838269085 1195293,"06/20/1978, 00:00:00",DOB, 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",HR//bpm,109.0 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",HR//bpm,114.1 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",HR//bpm,119.8 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",HR//bpm,112.5 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",HR//bpm,107.7 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 
20:41:33",HR//bpm,107.5 1195293,"06/20/2010, 20:41:33",TEMP,100.4 1195293,"06/20/2010, 20:50:04",DISCHARGE, @@ -52,14 +52,14 @@ 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, - 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",HR//bpm,86.0 68729,"05/26/2010, 02:30:56",TEMP,97.8 68729,"05/26/2010, 04:51:52",DISCHARGE, 814703,,EYE_COLOR//HAZEL, 814703,,HEIGHT,156.48559093209357 814703,"03/28/1976, 00:00:00",DOB, 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, - 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",HR//bpm,170.2 814703,"02/05/2010, 05:55:39",TEMP,100.1 814703,"02/05/2010, 07:02:30",DISCHARGE, @@ -69,7 +69,7 @@ 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, - 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",HR//bpm,142.0 754281,"01/03/2010, 06:27:59",TEMP,99.8 754281,"01/03/2010, 08:22:13",DISCHARGE, @@ -79,11 +79,11 @@ 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",HR//bpm,91.4 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",HR//bpm,84.4 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",HR//bpm,90.1 1500733,"06/03/2010, 16:20:49",TEMP,100.1 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ @@ -91,9 +91,9 @@ INPUT_METADATA_FILE = """ -lab_code,title,loinc -HR,Heart Rate,8867-4 -temp,Body Temperature,8310-5 +lab_code,valueuom,title,loinc +HR,bpm,Heart Rate,8867-4 +temp,,Body Temperature,8310-5 """ DEMO_METADATA_FILE = """ @@ -135,7 +135,9 @@ time: col(disch_date) time_format: "%m/%d/%Y, %H:%M:%S" HR: - code: HR + code: + - HR + - col(valueuom) time: col(vitals_date) time_format: "%m/%d/%Y, %H:%M:%S" numeric_value: HR @@ -152,6 +154,7 @@ input_metadata: description: {"title": {"lab_code": "temp"}} parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} + valueuom: {"valueuom": {"lab_code": "temp"}} """ SHARDS_JSON = { @@ -164,7 +167,7 @@ WANT_OUTPUTS = { "metadata/codes": pl.DataFrame( { - "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR//bpm", "TEMP"], "description": [ "Blue Eyes. Less common than brown.", "Brown Eyes. The most common eye color.", @@ -173,6 +176,7 @@ "Body Temperature", ], "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + "valueuom": [None, None, None, "bpm", None], } ), } From 838fbeab6d6c28aa3889f595601c3cb20a585ae8 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 14:21:34 -0400 Subject: [PATCH 28/62] Correct config for the test case and add a comment. 
--- tests/MEDS_Extract/test_extract_code_metadata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py index a493c207..e93bc346 100644 --- a/tests/MEDS_Extract/test_extract_code_metadata.py +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -145,6 +145,7 @@ input_metadata: description: {"title": {"lab_code": "HR"}} parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + valueuom: {"valueuom": {"lab_code": "HR"}} # If we didn't have this valueuom would be null for HR rows temp: code: TEMP time: col(vitals_date) From 3cba71dc91de52330d64bd55253c6faf5bf8fe10 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 14:22:29 -0400 Subject: [PATCH 29/62] Manually added valueuom to MIMIC config. --- MIMIC-IV_Example/configs/event_configs.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml index 619d5a2e..2986a958 100644 --- a/MIMIC-IV_Example/configs/event_configs.yaml +++ b/MIMIC-IV_Example/configs/event_configs.yaml @@ -100,6 +100,7 @@ hosp/labevents: description: ["omop_concept_name", "label"] # List of strings are columns to be collated itemid: "itemid (omop_source_code)" parent_codes: "{omop_vocabulary_id}/{omop_concept_code}" + valueuom: "valueuom" hosp/omr: omr: @@ -218,6 +219,7 @@ icu/chartevents: description: ["omop_concept_name", "label"] # List of strings are columns to be collated itemid: "itemid (omop_source_code)" parent_codes: "{omop_vocabulary_id}/{omop_concept_code}" + valueuom: "valueuom" icu/procedureevents: start: From 6081992c23818264efa38910c06fdbd8f0de5954 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 16:23:59 -0400 Subject: [PATCH 30/62] Prepping tests for addition of runner. 
--- src/MEDS_transforms/utils.py | 2 +- .../test_convert_to_sharded_events.py | 202 +++--- .../test_extract_code_metadata.py | 148 ++--- tests/MEDS_Extract/test_finalize_MEDS_data.py | 142 ++--- .../MEDS_Extract/test_merge_to_MEDS_cohort.py | 320 +++++----- tests/MEDS_Transforms/test_extract_values.py | 120 ++-- .../test_multi_stage_preprocess_pipeline.py | 602 +++++++++--------- .../MEDS_Transforms/transform_tester_base.py | 167 +++-- tests/utils.py | 30 +- 9 files changed, 869 insertions(+), 864 deletions(-) diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index eff649b4..871b90a0 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -1,7 +1,7 @@ """Core utilities for MEDS pipelines built with these tools.""" -import inspect import importlib +import inspect import os import sys from pathlib import Path diff --git a/tests/MEDS_Extract/test_convert_to_sharded_events.py b/tests/MEDS_Extract/test_convert_to_sharded_events.py index 074e897d..6d8311d2 100644 --- a/tests/MEDS_Extract/test_convert_to_sharded_events.py +++ b/tests/MEDS_Extract/test_convert_to_sharded_events.py @@ -104,107 +104,107 @@ WANT_OUTPUTS = parse_shards_yaml( """ - data/train/0/subjects/[0-6).parquet: |-2 - subject_id,time,code,numeric_value - 239684,,EYE_COLOR//BROWN, - 239684,,HEIGHT,175.271115221764 - 239684,"12/28/1980, 00:00:00",DOB, - 1195293,,EYE_COLOR//BLUE, - 1195293,,HEIGHT,164.6868838269085 - 1195293,"06/20/1978, 00:00:00",DOB, - - data/train/1/subjects/[0-6).parquet: |-2 - subject_id,time,code,numeric_value - 68729,,EYE_COLOR//HAZEL, - 68729,,HEIGHT,160.3953106166676 - 68729,"03/09/1978, 00:00:00",DOB, - 814703,,EYE_COLOR//HAZEL, - 814703,,HEIGHT,156.48559093209357 - 814703,"03/28/1976, 00:00:00",DOB, - - data/tuning/0/subjects/[0-6).parquet: |-2 - subject_id,time,code,numeric_value - 754281,,EYE_COLOR//BROWN, - 754281,,HEIGHT,166.22261567137025 - 754281,"12/19/1988, 00:00:00",DOB, - - data/held_out/0/subjects/[0-6).parquet: |-2 - subject_id,time,code,numeric_value - 1500733,,EYE_COLOR//BROWN, - 1500733,,HEIGHT,158.60131573580904 - 1500733,"07/20/1986, 00:00:00",DOB, - - data/train/0/admit_vitals/[0-10).parquet: |-2 - subject_id,time,code,numeric_value - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR,102.6 - 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48",HR,105.1 - 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35",HR,113.4 - 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18",HR,112.6 - 239684,"05/11/2010, 18:57:18",TEMP,95.5 - 239684,"05/11/2010, 19:27:19",DISCHARGE, - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:25:32",HR,114.1 - 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 20:12:31",HR,112.5 - 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, - - data/train/0/admit_vitals/[10-16).parquet: |-2 - subject_id,time,code,numeric_value - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR,109.0 - 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19",HR,119.8 - 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:24:44",HR,107.7 - 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33",HR,107.5 - 1195293,"06/20/2010, 20:41:33",TEMP,100.4 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, - - data/train/1/admit_vitals/[0-10).parquet: |-2 - subject_id,time,code,numeric_value - 68729,"05/26/2010, 
02:30:56",ADMISSION//PULMONARY, - 68729,"05/26/2010, 02:30:56",HR,86.0 - 68729,"05/26/2010, 02:30:56",TEMP,97.8 - 68729,"05/26/2010, 04:51:52",DISCHARGE, - 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, - 814703,"02/05/2010, 05:55:39",HR,170.2 - 814703,"02/05/2010, 05:55:39",TEMP,100.1 - 814703,"02/05/2010, 07:02:30",DISCHARGE, - - data/train/1/admit_vitals/[10-16).parquet: |-2 - subject_id,time,code,numeric_value - - data/tuning/0/admit_vitals/[0-10).parquet: |-2 - subject_id,time,code,numeric_value - 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, - 754281,"01/03/2010, 06:27:59",HR,142.0 - 754281,"01/03/2010, 06:27:59",TEMP,99.8 - 754281,"01/03/2010, 08:22:13",DISCHARGE, - - data/tuning/0/admit_vitals/[10-16).parquet: |-2 - subject_id,time,code,numeric_value - - data/held_out/0/admit_vitals/[0-10).parquet: |-2 - subject_id,time,code,numeric_value - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 16:20:49",HR,90.1 - 1500733,"06/03/2010, 16:20:49",TEMP,100.1 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, - - data/held_out/0/admit_vitals/[10-16).parquet: |-2 - subject_id,time,code,numeric_value - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 14:54:38",HR,91.4 - 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49",HR,84.4 - 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, +data/train/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + +data/train/1/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + +data/tuning/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + +data/held_out/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + +data/train/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + +data/train/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 
20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + +data/train/1/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + +data/train/1/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + +data/tuning/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + +data/tuning/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + +data/held_out/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + +data/held_out/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ ) diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py index e93bc346..7700426f 100644 --- a/tests/MEDS_Extract/test_extract_code_metadata.py +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -12,80 +12,80 @@ INPUT_SHARDS = parse_shards_yaml( """ - data/train/0: |-2 - subject_id,time,code,numeric_value - 239684,,EYE_COLOR//BROWN, - 239684,,HEIGHT,175.271115221764 - 239684,"12/28/1980, 00:00:00",DOB, - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR//bpm,102.6 - 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48",HR//bpm,105.1 - 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35",HR//bpm,113.4 - 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18",HR//bpm,112.6 - 239684,"05/11/2010, 18:57:18",TEMP,95.5 - 239684,"05/11/2010, 19:27:19",DISCHARGE, - 1195293,,EYE_COLOR//BLUE, - 1195293,,HEIGHT,164.6868838269085 - 1195293,"06/20/1978, 00:00:00",DOB, - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR//bpm,109.0 - 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:25:32",HR//bpm,114.1 - 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19",HR//bpm,119.8 - 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:12:31",HR//bpm,112.5 - 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:24:44",HR//bpm,107.7 - 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33",HR//bpm,107.5 - 1195293,"06/20/2010, 20:41:33",TEMP,100.4 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, - - - data/train/1: |-2 - subject_id,time,code,numeric_value - 68729,,EYE_COLOR//HAZEL, - 68729,,HEIGHT,160.3953106166676 - 68729,"03/09/1978, 00:00:00",DOB, - 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, - 68729,"05/26/2010, 02:30:56",HR//bpm,86.0 - 
68729,"05/26/2010, 02:30:56",TEMP,97.8 - 68729,"05/26/2010, 04:51:52",DISCHARGE, - 814703,,EYE_COLOR//HAZEL, - 814703,,HEIGHT,156.48559093209357 - 814703,"03/28/1976, 00:00:00",DOB, - 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, - 814703,"02/05/2010, 05:55:39",HR//bpm,170.2 - 814703,"02/05/2010, 05:55:39",TEMP,100.1 - 814703,"02/05/2010, 07:02:30",DISCHARGE, - - data/tuning/0: |-2 - subject_id,time,code,numeric_value - 754281,,EYE_COLOR//BROWN, - 754281,,HEIGHT,166.22261567137025 - 754281,"12/19/1988, 00:00:00",DOB, - 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, - 754281,"01/03/2010, 06:27:59",HR//bpm,142.0 - 754281,"01/03/2010, 06:27:59",TEMP,99.8 - 754281,"01/03/2010, 08:22:13",DISCHARGE, - - data/held_out/0: |-2 - subject_id,time,code,numeric_value - 1500733,,EYE_COLOR//BROWN, - 1500733,,HEIGHT,158.60131573580904 - 1500733,"07/20/1986, 00:00:00",DOB, - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 14:54:38",HR//bpm,91.4 - 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49",HR//bpm,84.4 - 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 1500733,"06/03/2010, 16:20:49",HR//bpm,90.1 - 1500733,"06/03/2010, 16:20:49",TEMP,100.1 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, +data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR//bpm,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR//bpm,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR//bpm,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR//bpm,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR//bpm,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR//bpm,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR//bpm,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR//bpm,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR//bpm,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR//bpm,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + +data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR//bpm,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR//bpm,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + +data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR//bpm,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 
08:22:13",DISCHARGE, + +data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR//bpm,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR//bpm,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR//bpm,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ ) diff --git a/tests/MEDS_Extract/test_finalize_MEDS_data.py b/tests/MEDS_Extract/test_finalize_MEDS_data.py index d9a3e0ad..3348d121 100644 --- a/tests/MEDS_Extract/test_finalize_MEDS_data.py +++ b/tests/MEDS_Extract/test_finalize_MEDS_data.py @@ -11,81 +11,81 @@ INPUT_SHARDS = parse_shards_yaml( """ - data/train/0: |-2 - subject_id,time,code,numeric_value - 239684,,EYE_COLOR//BROWN, - 239684,,HEIGHT,175.271115221764 - 239684,"12/28/1980, 00:00:00",DOB, - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR,102.6 - 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48",HR,105.1 - 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35",HR,113.4 - 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18",HR,112.6 - 239684,"05/11/2010, 18:57:18",TEMP,95.5 - 239684,"05/11/2010, 19:27:19",DISCHARGE, - 1195293,,EYE_COLOR//BLUE, - 1195293,,HEIGHT,164.6868838269085 - 1195293,"06/20/1978, 00:00:00",DOB, - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR,109.0 - 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:25:32",HR,114.1 - 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19",HR,119.8 - 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:12:31",HR,112.5 - 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:24:44",HR,107.7 - 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33",HR,107.5 - 1195293,"06/20/2010, 20:41:33",TEMP,100.4 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, +data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, - data/train/1: |-2 - subject_id,time,code,numeric_value - 68729,,EYE_COLOR//HAZEL, - 
68729,,HEIGHT,160.3953106166676 - 68729,"03/09/1978, 00:00:00",DOB, - 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, - 68729,"05/26/2010, 02:30:56",HR,86.0 - 68729,"05/26/2010, 02:30:56",TEMP,97.8 - 68729,"05/26/2010, 04:51:52",DISCHARGE, - 814703,,EYE_COLOR//HAZEL, - 814703,,HEIGHT,156.48559093209357 - 814703,"03/28/1976, 00:00:00",DOB, - 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, - 814703,"02/05/2010, 05:55:39",HR,170.2 - 814703,"02/05/2010, 05:55:39",TEMP,100.1 - 814703,"02/05/2010, 07:02:30",DISCHARGE, +data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, - data/tuning/0: |-2 - subject_id,time,code,numeric_value - 754281,,EYE_COLOR//BROWN, - 754281,,HEIGHT,166.22261567137025 - 754281,"12/19/1988, 00:00:00",DOB, - 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, - 754281,"01/03/2010, 06:27:59",HR,142.0 - 754281,"01/03/2010, 06:27:59",TEMP,99.8 - 754281,"01/03/2010, 08:22:13",DISCHARGE, +data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, - data/held_out/0: |-2 - subject_id,time,code,numeric_value - 1500733,,EYE_COLOR//BROWN, - 1500733,,HEIGHT,158.60131573580904 - 1500733,"07/20/1986, 00:00:00",DOB, - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 14:54:38",HR,91.4 - 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49",HR,84.4 - 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 1500733,"06/03/2010, 16:20:49",HR,90.1 - 1500733,"06/03/2010, 16:20:49",TEMP,100.1 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, - """ +data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, +""" ) WANT_OUTPUTS = { diff --git a/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py index b9a21ffb..74688043 100644 --- a/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py +++ b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py @@ -66,186 +66,186 @@ INPUT_SHARDS = parse_shards_yaml( """ - data/train/0/subjects/[0-6): |-2 - subject_id,time,code,numeric_value - 239684,,EYE_COLOR//BROWN, - 239684,,HEIGHT,175.271115221764 - 239684,"12/28/1980, 00:00:00",DOB, - 1195293,,EYE_COLOR//BLUE, - 1195293,,HEIGHT,164.6868838269085 - 1195293,"06/20/1978, 00:00:00",DOB, +data/train/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 
239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, - data/train/1/subjects/[0-6): |-2 - subject_id,time,code,numeric_value - 68729,,EYE_COLOR//HAZEL, - 68729,,HEIGHT,160.3953106166676 - 68729,"03/09/1978, 00:00:00",DOB, - 814703,,EYE_COLOR//HAZEL, - 814703,,HEIGHT,156.48559093209357 - 814703,"03/28/1976, 00:00:00",DOB, +data/train/1/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, - data/tuning/0/subjects/[0-6): |-2 - subject_id,time,code,numeric_value - 754281,,EYE_COLOR//BROWN, - 754281,,HEIGHT,166.22261567137025 - 754281,"12/19/1988, 00:00:00",DOB, +data/tuning/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, - data/held_out/0/subjects/[0-6): |-2 - subject_id,time,code,numeric_value - 1500733,,EYE_COLOR//BROWN, - 1500733,,HEIGHT,158.60131573580904 - 1500733,"07/20/1986, 00:00:00",DOB, +data/held_out/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, - data/train/0/admit_vitals/[0-10): |-2 - subject_id,time,code,numeric_value - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR,102.6 - 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48",HR,105.1 - 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35",HR,113.4 - 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18",HR,112.6 - 239684,"05/11/2010, 18:57:18",TEMP,95.5 - 239684,"05/11/2010, 19:27:19",DISCHARGE, - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:25:32",HR,114.1 - 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 20:12:31",HR,112.5 - 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, +data/train/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, - data/train/0/admit_vitals/[10-16): |-2 - subject_id,time,code,numeric_value - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR,109.0 - 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19",HR,119.8 - 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:24:44",HR,107.7 - 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33",HR,107.5 - 1195293,"06/20/2010, 20:41:33",TEMP,100.4 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, +data/train/0/admit_vitals/[10-16): |-2 + 
subject_id,time,code,numeric_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, - data/train/1/admit_vitals/[0-10): |-2 - subject_id,time,code,numeric_value - 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, - 68729,"05/26/2010, 02:30:56",HR,86.0 - 68729,"05/26/2010, 02:30:56",TEMP,97.8 - 68729,"05/26/2010, 04:51:52",DISCHARGE, - 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, - 814703,"02/05/2010, 05:55:39",HR,170.2 - 814703,"02/05/2010, 05:55:39",TEMP,100.1 - 814703,"02/05/2010, 07:02:30",DISCHARGE, +data/train/1/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, - data/train/1/admit_vitals/[10-16): |-2 - subject_id,time,code,numeric_value +data/train/1/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value - data/tuning/0/admit_vitals/[0-10): |-2 - subject_id,time,code,numeric_value - 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, - 754281,"01/03/2010, 06:27:59",HR,142.0 - 754281,"01/03/2010, 06:27:59",TEMP,99.8 - 754281,"01/03/2010, 08:22:13",DISCHARGE, +data/tuning/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, - data/tuning/0/admit_vitals/[10-16): |-2 - subject_id,time,code,numeric_value +data/tuning/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value - data/held_out/0/admit_vitals/[0-10): |-2 - subject_id,time,code,numeric_value - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 16:20:49",HR,90.1 - 1500733,"06/03/2010, 16:20:49",TEMP,100.1 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, +data/held_out/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, - data/held_out/0/admit_vitals/[10-16): |-2 - subject_id,time,code,numeric_value - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 14:54:38",HR,91.4 - 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49",HR,84.4 - 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, +data/held_out/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ ) WANT_OUTPUTS = parse_shards_yaml( """ - data/train/0: |-2 - subject_id,time,code,numeric_value - 239684,,EYE_COLOR//BROWN, - 
239684,,HEIGHT,175.271115221764 - 239684,"12/28/1980, 00:00:00",DOB, - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR,102.6 - 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48",HR,105.1 - 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35",HR,113.4 - 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18",HR,112.6 - 239684,"05/11/2010, 18:57:18",TEMP,95.5 - 239684,"05/11/2010, 19:27:19",DISCHARGE, - 1195293,,EYE_COLOR//BLUE, - 1195293,,HEIGHT,164.6868838269085 - 1195293,"06/20/1978, 00:00:00",DOB, - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR,109.0 - 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:25:32",HR,114.1 - 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19",HR,119.8 - 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:12:31",HR,112.5 - 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:24:44",HR,107.7 - 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33",HR,107.5 - 1195293,"06/20/2010, 20:41:33",TEMP,100.4 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, +data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, - data/train/1: |-2 - subject_id,time,code,numeric_value - 68729,,EYE_COLOR//HAZEL, - 68729,,HEIGHT,160.3953106166676 - 68729,"03/09/1978, 00:00:00",DOB, - 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, - 68729,"05/26/2010, 02:30:56",HR,86.0 - 68729,"05/26/2010, 02:30:56",TEMP,97.8 - 68729,"05/26/2010, 04:51:52",DISCHARGE, - 814703,,EYE_COLOR//HAZEL, - 814703,,HEIGHT,156.48559093209357 - 814703,"03/28/1976, 00:00:00",DOB, - 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, - 814703,"02/05/2010, 05:55:39",HR,170.2 - 814703,"02/05/2010, 05:55:39",TEMP,100.1 - 814703,"02/05/2010, 07:02:30",DISCHARGE, +data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 
814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, - data/tuning/0: |-2 - subject_id,time,code,numeric_value - 754281,,EYE_COLOR//BROWN, - 754281,,HEIGHT,166.22261567137025 - 754281,"12/19/1988, 00:00:00",DOB, - 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, - 754281,"01/03/2010, 06:27:59",HR,142.0 - 754281,"01/03/2010, 06:27:59",TEMP,99.8 - 754281,"01/03/2010, 08:22:13",DISCHARGE, +data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, - data/held_out/0: |-2 - subject_id,time,code,numeric_value - 1500733,,EYE_COLOR//BROWN, - 1500733,,HEIGHT,158.60131573580904 - 1500733,"07/20/1986, 00:00:00",DOB, - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 14:54:38",HR,91.4 - 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49",HR,84.4 - 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 1500733,"06/03/2010, 16:20:49",HR,90.1 - 1500733,"06/03/2010, 16:20:49",TEMP,100.1 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, +data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ ) diff --git a/tests/MEDS_Transforms/test_extract_values.py b/tests/MEDS_Transforms/test_extract_values.py index 4114b3ba..83a2aa3d 100644 --- a/tests/MEDS_Transforms/test_extract_values.py +++ b/tests/MEDS_Transforms/test_extract_values.py @@ -9,71 +9,71 @@ INPUT_SHARDS = parse_shards_yaml( """ - train/0: |-2 - subject_id,time,code,numeric_value,text_value - 239684,,EYE_COLOR//BROWN,, - 239684,"12/28/1980, 00:00:00",DOB,, - 239684,"05/11/2010, 17:41:51",BP,,"120/80" - 1195293,,EYE_COLOR//BLUE,, - 1195293,"06/20/1978, 00:00:00",DOB,, - 1195293,"06/20/2010, 19:23:52",BP,,"144/96" - 1195293,"06/20/2010, 19:23:52",HR,80, - 1195293,"06/20/2010, 19:23:52",TEMP,,"100F" - train/1: |-2 - subject_id,time,code,numeric_value,text_value - 68729,,EYE_COLOR//HAZEL,, - 68729,"03/09/1978, 00:00:00",DOB,, - 814703,"02/05/2010, 05:55:39",HR,170.2, - tuning/0: |-2 - subject_id,time,code,numeric_value,text_value - 754281,,EYE_COLOR//BROWN,, - 754281,"12/19/1988, 00:00:00",DOB,, - 754281,"01/03/2010, 06:27:59",HR,142.0, - 754281,"06/20/2010, 20:23:50",BP,,"134/76" - 754281,"06/20/2010, 21:00:02",TEMP,,"36.2C" - held_out/0: |-2 - subject_id,time,code,numeric_value,text_value - 1500733,,EYE_COLOR//BROWN,, - 1500733,"07/20/1986, 00:00:00",DOB,, - 1500733,"06/03/2010, 14:54:38",HR,91.4 - 1500733,"06/03/2010, 14:54:38",BP,,"123/82" +train/0: |-2 + subject_id,time,code,numeric_value,text_value + 239684,,EYE_COLOR//BROWN,, + 239684,"12/28/1980, 00:00:00",DOB,, + 239684,"05/11/2010, 17:41:51",BP,,"120/80" + 1195293,,EYE_COLOR//BLUE,, + 1195293,"06/20/1978, 00:00:00",DOB,, + 1195293,"06/20/2010, 19:23:52",BP,,"144/96" + 1195293,"06/20/2010, 19:23:52",HR,80, 
+ 1195293,"06/20/2010, 19:23:52",TEMP,,"100F" +train/1: |-2 + subject_id,time,code,numeric_value,text_value + 68729,,EYE_COLOR//HAZEL,, + 68729,"03/09/1978, 00:00:00",DOB,, + 814703,"02/05/2010, 05:55:39",HR,170.2, +tuning/0: |-2 + subject_id,time,code,numeric_value,text_value + 754281,,EYE_COLOR//BROWN,, + 754281,"12/19/1988, 00:00:00",DOB,, + 754281,"01/03/2010, 06:27:59",HR,142.0, + 754281,"06/20/2010, 20:23:50",BP,,"134/76" + 754281,"06/20/2010, 21:00:02",TEMP,,"36.2C" +held_out/0: |-2 + subject_id,time,code,numeric_value,text_value + 1500733,,EYE_COLOR//BROWN,, + 1500733,"07/20/1986, 00:00:00",DOB,, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",BP,,"123/82" """ ) WANT_SHARDS = parse_shards_yaml( """ - train/0: |-2 - subject_id,time,code,numeric_value,text_value - 239684,,EYE_COLOR//BROWN,, - 239684,"12/28/1980, 00:00:00",DOB,, - 239684,"05/11/2010, 17:41:51",BP//SYSTOLIC,120, - 239684,"05/11/2010, 17:41:51",BP//DIASTOLIC,80, - 1195293,,EYE_COLOR//BLUE,, - 1195293,"06/20/1978, 00:00:00",DOB,, - 1195293,"06/20/2010, 19:23:52",BP//SYSTOLIC,144, - 1195293,"06/20/2010, 19:23:52",BP//DIASTOLIC,96, - 1195293,"06/20/2010, 19:23:52",TEMP//F,100, - 1195293,"06/20/2010, 19:23:52",HR,80, - train/1: |-2 - subject_id,time,code,numeric_value,text_value - 68729,,EYE_COLOR//HAZEL,, - 68729,"03/09/1978, 00:00:00",DOB,, - 814703,"02/05/2010, 05:55:39",HR,170.2, - tuning/0: |-2 - subject_id,time,code,numeric_value,text_value - 754281,,EYE_COLOR//BROWN,, - 754281,"12/19/1988, 00:00:00",DOB,, - 754281,"01/03/2010, 06:27:59",HR,142.0, - 754281,"06/20/2010, 20:23:50",BP//SYSTOLIC,134, - 754281,"06/20/2010, 20:23:50",BP//DIASTOLIC,76, - 754281,"06/20/2010, 21:00:02",TEMP//C,36.2, - held_out/0: |-2 - subject_id,time,code,numeric_value,text_value - 1500733,,EYE_COLOR//BROWN,, - 1500733,"07/20/1986, 00:00:00",DOB,, - 1500733,"06/03/2010, 14:54:38",BP//SYSTOLIC,123, - 1500733,"06/03/2010, 14:54:38",BP//DIASTOLIC,82, - 1500733,"06/03/2010, 14:54:38",HR,91.4, +train/0: |-2 + subject_id,time,code,numeric_value,text_value + 239684,,EYE_COLOR//BROWN,, + 239684,"12/28/1980, 00:00:00",DOB,, + 239684,"05/11/2010, 17:41:51",BP//SYSTOLIC,120, + 239684,"05/11/2010, 17:41:51",BP//DIASTOLIC,80, + 1195293,,EYE_COLOR//BLUE,, + 1195293,"06/20/1978, 00:00:00",DOB,, + 1195293,"06/20/2010, 19:23:52",BP//SYSTOLIC,144, + 1195293,"06/20/2010, 19:23:52",BP//DIASTOLIC,96, + 1195293,"06/20/2010, 19:23:52",TEMP//F,100, + 1195293,"06/20/2010, 19:23:52",HR,80, +train/1: |-2 + subject_id,time,code,numeric_value,text_value + 68729,,EYE_COLOR//HAZEL,, + 68729,"03/09/1978, 00:00:00",DOB,, + 814703,"02/05/2010, 05:55:39",HR,170.2, +tuning/0: |-2 + subject_id,time,code,numeric_value,text_value + 754281,,EYE_COLOR//BROWN,, + 754281,"12/19/1988, 00:00:00",DOB,, + 754281,"01/03/2010, 06:27:59",HR,142.0, + 754281,"06/20/2010, 20:23:50",BP//SYSTOLIC,134, + 754281,"06/20/2010, 20:23:50",BP//DIASTOLIC,76, + 754281,"06/20/2010, 21:00:02",TEMP//C,36.2, +held_out/0: |-2 + subject_id,time,code,numeric_value,text_value + 1500733,,EYE_COLOR//BROWN,, + 1500733,"07/20/1986, 00:00:00",DOB,, + 1500733,"06/03/2010, 14:54:38",BP//SYSTOLIC,123, + 1500733,"06/03/2010, 14:54:38",BP//DIASTOLIC,82, + 1500733,"06/03/2010, 14:54:38",HR,91.4, """ ) diff --git a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py index 6667313f..7a1b1a75 100644 --- a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py +++ 
b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py @@ -81,150 +81,150 @@ # After filtering out subjects with fewer than 5 events: WANT_FILTER = parse_shards_yaml( f""" - "filter_subjects/train/0": |-2 - {subject_id_field},time,code,numeric_value - 239684,,EYE_COLOR//BROWN, - 239684,,HEIGHT,175.271115221764 - 239684,"12/28/1980, 00:00:00",DOB, - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR,102.6 - 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48",HR,105.1 - 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35",HR,113.4 - 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18",HR,112.6 - 239684,"05/11/2010, 18:57:18",TEMP,95.5 - 239684,"05/11/2010, 19:27:19",DISCHARGE, - 1195293,,EYE_COLOR//BLUE, - 1195293,,HEIGHT,164.6868838269085 - 1195293,"06/20/1978, 00:00:00",DOB, - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR,109.0 - 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:25:32",HR,114.1 - 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19",HR,119.8 - 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:12:31",HR,112.5 - 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:24:44",HR,107.7 - 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33",HR,107.5 - 1195293,"06/20/2010, 20:41:33",TEMP,100.4 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, +"filter_subjects/train/0": |-2 + {subject_id_field},time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, - "filter_subjects/train/1": |-2 - {subject_id_field},time,code,numeric_value +"filter_subjects/train/1": |-2 + {subject_id_field},time,code,numeric_value - "filter_subjects/tuning/0": |-2 - {subject_id_field},time,code,numeric_value +"filter_subjects/tuning/0": |-2 + {subject_id_field},time,code,numeric_value - "filter_subjects/held_out/0": |-2 - {subject_id_field},time,code,numeric_value - 1500733,,EYE_COLOR//BROWN, - 1500733,,HEIGHT,158.60131573580904 - 1500733,"07/20/1986, 00:00:00",DOB, - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 14:54:38",HR,91.4 - 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49",HR,84.4 - 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 
1500733,"06/03/2010, 16:20:49",HR,90.1 - 1500733,"06/03/2010, 16:20:49",TEMP,100.1 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, +"filter_subjects/held_out/0": |-2 + {subject_id_field},time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ ) WANT_TIME_DERIVED = parse_shards_yaml( f""" - "add_time_derived_measurements/train/0": |-2 - {subject_id_field},time,code,numeric_value - 239684,,EYE_COLOR//BROWN, - 239684,,HEIGHT,175.271115221764 - 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", - 239684,"12/28/1980, 00:00:00",DOB, - 239684,"05/11/2010, 17:41:51","TIME_OF_DAY//[12,18)", - 239684,"05/11/2010, 17:41:51",AGE,29.36883360091833 - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR,102.6 - 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48","TIME_OF_DAY//[12,18)", - 239684,"05/11/2010, 17:48:48",AGE,29.36884681513314 - 239684,"05/11/2010, 17:48:48",HR,105.1 - 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35","TIME_OF_DAY//[18,24)", - 239684,"05/11/2010, 18:25:35",AGE,29.36891675223647 - 239684,"05/11/2010, 18:25:35",HR,113.4 - 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18","TIME_OF_DAY//[18,24)", - 239684,"05/11/2010, 18:57:18",AGE,29.36897705595538 - 239684,"05/11/2010, 18:57:18",HR,112.6 - 239684,"05/11/2010, 18:57:18",TEMP,95.5 - 239684,"05/11/2010, 19:27:19","TIME_OF_DAY//[18,24)", - 239684,"05/11/2010, 19:27:19",AGE,29.369034127420306 - 239684,"05/11/2010, 19:27:19",DISCHARGE, - 1195293,,EYE_COLOR//BLUE, - 1195293,,HEIGHT,164.6868838269085 - 1195293,"06/20/1978, 00:00:00","TIME_OF_DAY//[00,06)", - 1195293,"06/20/1978, 00:00:00",DOB, - 1195293,"06/20/2010, 19:23:52","TIME_OF_DAY//[18,24)", - 1195293,"06/20/2010, 19:23:52",AGE,32.002896271955265 - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR,109.0 - 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:25:32","TIME_OF_DAY//[18,24)", - 1195293,"06/20/2010, 19:25:32",AGE,32.00289944083172 - 1195293,"06/20/2010, 19:25:32",HR,114.1 - 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19","TIME_OF_DAY//[18,24)", - 1195293,"06/20/2010, 19:45:19",AGE,32.00293705539522 - 1195293,"06/20/2010, 19:45:19",HR,119.8 - 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:12:31","TIME_OF_DAY//[18,24)", - 1195293,"06/20/2010, 20:12:31",AGE,32.002988771458945 - 1195293,"06/20/2010, 20:12:31",HR,112.5 - 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:24:44","TIME_OF_DAY//[18,24)", - 1195293,"06/20/2010, 20:24:44",AGE,32.00301199932335 - 1195293,"06/20/2010, 20:24:44",HR,107.7 - 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33","TIME_OF_DAY//[18,24)", - 1195293,"06/20/2010, 20:41:33",AGE,32.003043973286765 - 1195293,"06/20/2010, 20:41:33",HR,107.5 - 1195293,"06/20/2010, 20:41:33",TEMP,100.4 - 1195293,"06/20/2010, 20:50:04","TIME_OF_DAY//[18,24)", - 1195293,"06/20/2010, 20:50:04",AGE,32.00306016624544 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, +"add_time_derived_measurements/train/0": |-2 + 
{subject_id_field},time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51","TIME_OF_DAY//[12,18)", + 239684,"05/11/2010, 17:41:51",AGE,29.36883360091833 + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48","TIME_OF_DAY//[12,18)", + 239684,"05/11/2010, 17:48:48",AGE,29.36884681513314 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35","TIME_OF_DAY//[18,24)", + 239684,"05/11/2010, 18:25:35",AGE,29.36891675223647 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18","TIME_OF_DAY//[18,24)", + 239684,"05/11/2010, 18:57:18",AGE,29.36897705595538 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19","TIME_OF_DAY//[18,24)", + 239684,"05/11/2010, 19:27:19",AGE,29.369034127420306 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00","TIME_OF_DAY//[00,06)", + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 19:23:52",AGE,32.002896271955265 + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 19:25:32",AGE,32.00289944083172 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 19:45:19",AGE,32.00293705539522 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:12:31",AGE,32.002988771458945 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:24:44",AGE,32.00301199932335 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:41:33",AGE,32.003043973286765 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:50:04",AGE,32.00306016624544 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, - "add_time_derived_measurements/train/1": |-2 - {subject_id_field},time,code,numeric_value +"add_time_derived_measurements/train/1": |-2 + {subject_id_field},time,code,numeric_value - "add_time_derived_measurements/tuning/0": |-2 - {subject_id_field},time,code,numeric_value +"add_time_derived_measurements/tuning/0": |-2 + {subject_id_field},time,code,numeric_value - "add_time_derived_measurements/held_out/0": |-2 - {subject_id_field},time,code,numeric_value - 1500733,,EYE_COLOR//BROWN, - 1500733,,HEIGHT,158.60131573580904 - 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", - 1500733,"07/20/1986, 00:00:00",DOB, - 1500733,"06/03/2010, 14:54:38","TIME_OF_DAY//[12,18)", - 1500733,"06/03/2010, 14:54:38",AGE,23.873531791091356 - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 
14:54:38",HR,91.4 - 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49","TIME_OF_DAY//[12,18)", - 1500733,"06/03/2010, 15:39:49",AGE,23.873617699332012 - 1500733,"06/03/2010, 15:39:49",HR,84.4 - 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 1500733,"06/03/2010, 16:20:49","TIME_OF_DAY//[12,18)", - 1500733,"06/03/2010, 16:20:49",AGE,23.873695653692767 - 1500733,"06/03/2010, 16:20:49",HR,90.1 - 1500733,"06/03/2010, 16:20:49",TEMP,100.1 - 1500733,"06/03/2010, 16:44:26","TIME_OF_DAY//[12,18)", - 1500733,"06/03/2010, 16:44:26",AGE,23.873740556672114 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, +"add_time_derived_measurements/held_out/0": |-2 + {subject_id_field},time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38","TIME_OF_DAY//[12,18)", + 1500733,"06/03/2010, 14:54:38",AGE,23.873531791091356 + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49","TIME_OF_DAY//[12,18)", + 1500733,"06/03/2010, 15:39:49",AGE,23.873617699332012 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49","TIME_OF_DAY//[12,18)", + 1500733,"06/03/2010, 16:20:49",AGE,23.873695653692767 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26","TIME_OF_DAY//[12,18)", + 1500733,"06/03/2010, 16:44:26",AGE,23.873740556672114 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ ) @@ -389,93 +389,93 @@ WANT_OCCLUDE_OUTLIERS = parse_shards_yaml( f""" - "occlude_outliers/train/0": |-2 - {subject_id_field},time,code,numeric_value,numeric_value/is_inlier - 239684,,EYE_COLOR//BROWN,, - 239684,,HEIGHT,,false - 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)",, - 239684,"12/28/1980, 00:00:00",DOB,, - 239684,"05/11/2010, 17:41:51","TIME_OF_DAY//[12,18)",, - 239684,"05/11/2010, 17:41:51",AGE,,false - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,, - 239684,"05/11/2010, 17:41:51",HR,,false - 239684,"05/11/2010, 17:41:51",TEMP,,false - 239684,"05/11/2010, 17:48:48","TIME_OF_DAY//[12,18)",, - 239684,"05/11/2010, 17:48:48",AGE,,false - 239684,"05/11/2010, 17:48:48",HR,,false - 239684,"05/11/2010, 17:48:48",TEMP,,false - 239684,"05/11/2010, 18:25:35","TIME_OF_DAY//[18,24)",, - 239684,"05/11/2010, 18:25:35",AGE,,false - 239684,"05/11/2010, 18:25:35",HR,113.4,true - 239684,"05/11/2010, 18:25:35",TEMP,,false - 239684,"05/11/2010, 18:57:18","TIME_OF_DAY//[18,24)",, - 239684,"05/11/2010, 18:57:18",AGE,,false - 239684,"05/11/2010, 18:57:18",HR,112.6,true - 239684,"05/11/2010, 18:57:18",TEMP,,false - 239684,"05/11/2010, 19:27:19","TIME_OF_DAY//[18,24)",, - 239684,"05/11/2010, 19:27:19",AGE,,false - 239684,"05/11/2010, 19:27:19",DISCHARGE,, - 1195293,,EYE_COLOR//BLUE,, - 1195293,,HEIGHT,,false - 1195293,"06/20/1978, 00:00:00","TIME_OF_DAY//[00,06)",, - 1195293,"06/20/1978, 00:00:00",DOB,, - 1195293,"06/20/2010, 19:23:52","TIME_OF_DAY//[18,24)",, - 1195293,"06/20/2010, 19:23:52",AGE,32.002896271955265,true - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,, - 1195293,"06/20/2010, 19:23:52",HR,109.0,true - 1195293,"06/20/2010, 19:23:52",TEMP,100.0,true - 1195293,"06/20/2010, 19:25:32","TIME_OF_DAY//[18,24)",, - 1195293,"06/20/2010, 19:25:32",AGE,32.00289944083172,true - 1195293,"06/20/2010, 19:25:32",HR,114.1,true - 
1195293,"06/20/2010, 19:25:32",TEMP,100.0,true - 1195293,"06/20/2010, 19:45:19","TIME_OF_DAY//[18,24)",, - 1195293,"06/20/2010, 19:45:19",AGE,32.00293705539522,true - 1195293,"06/20/2010, 19:45:19",HR,,false - 1195293,"06/20/2010, 19:45:19",TEMP,99.9,true - 1195293,"06/20/2010, 20:12:31","TIME_OF_DAY//[18,24)",, - 1195293,"06/20/2010, 20:12:31",AGE,32.002988771458945,true - 1195293,"06/20/2010, 20:12:31",HR,112.5,true - 1195293,"06/20/2010, 20:12:31",TEMP,99.8,true - 1195293,"06/20/2010, 20:24:44","TIME_OF_DAY//[18,24)", - 1195293,"06/20/2010, 20:24:44",AGE,32.00301199932335,true - 1195293,"06/20/2010, 20:24:44",HR,107.7,true - 1195293,"06/20/2010, 20:24:44",TEMP,100.0,true - 1195293,"06/20/2010, 20:41:33","TIME_OF_DAY//[18,24)",, - 1195293,"06/20/2010, 20:41:33",AGE,32.003043973286765,true - 1195293,"06/20/2010, 20:41:33",HR,107.5,true - 1195293,"06/20/2010, 20:41:33",TEMP,100.4,true - 1195293,"06/20/2010, 20:50:04","TIME_OF_DAY//[18,24)",, - 1195293,"06/20/2010, 20:50:04",AGE,32.00306016624544,true - 1195293,"06/20/2010, 20:50:04",DISCHARGE,, +"occlude_outliers/train/0": |-2 + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier + 239684,,EYE_COLOR//BROWN,, + 239684,,HEIGHT,,false + 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)",, + 239684,"12/28/1980, 00:00:00",DOB,, + 239684,"05/11/2010, 17:41:51","TIME_OF_DAY//[12,18)",, + 239684,"05/11/2010, 17:41:51",AGE,,false + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,, + 239684,"05/11/2010, 17:41:51",HR,,false + 239684,"05/11/2010, 17:41:51",TEMP,,false + 239684,"05/11/2010, 17:48:48","TIME_OF_DAY//[12,18)",, + 239684,"05/11/2010, 17:48:48",AGE,,false + 239684,"05/11/2010, 17:48:48",HR,,false + 239684,"05/11/2010, 17:48:48",TEMP,,false + 239684,"05/11/2010, 18:25:35","TIME_OF_DAY//[18,24)",, + 239684,"05/11/2010, 18:25:35",AGE,,false + 239684,"05/11/2010, 18:25:35",HR,113.4,true + 239684,"05/11/2010, 18:25:35",TEMP,,false + 239684,"05/11/2010, 18:57:18","TIME_OF_DAY//[18,24)",, + 239684,"05/11/2010, 18:57:18",AGE,,false + 239684,"05/11/2010, 18:57:18",HR,112.6,true + 239684,"05/11/2010, 18:57:18",TEMP,,false + 239684,"05/11/2010, 19:27:19","TIME_OF_DAY//[18,24)",, + 239684,"05/11/2010, 19:27:19",AGE,,false + 239684,"05/11/2010, 19:27:19",DISCHARGE,, + 1195293,,EYE_COLOR//BLUE,, + 1195293,,HEIGHT,,false + 1195293,"06/20/1978, 00:00:00","TIME_OF_DAY//[00,06)",, + 1195293,"06/20/1978, 00:00:00",DOB,, + 1195293,"06/20/2010, 19:23:52","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 19:23:52",AGE,32.002896271955265,true + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,, + 1195293,"06/20/2010, 19:23:52",HR,109.0,true + 1195293,"06/20/2010, 19:23:52",TEMP,100.0,true + 1195293,"06/20/2010, 19:25:32","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 19:25:32",AGE,32.00289944083172,true + 1195293,"06/20/2010, 19:25:32",HR,114.1,true + 1195293,"06/20/2010, 19:25:32",TEMP,100.0,true + 1195293,"06/20/2010, 19:45:19","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 19:45:19",AGE,32.00293705539522,true + 1195293,"06/20/2010, 19:45:19",HR,,false + 1195293,"06/20/2010, 19:45:19",TEMP,99.9,true + 1195293,"06/20/2010, 20:12:31","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 20:12:31",AGE,32.002988771458945,true + 1195293,"06/20/2010, 20:12:31",HR,112.5,true + 1195293,"06/20/2010, 20:12:31",TEMP,99.8,true + 1195293,"06/20/2010, 20:24:44","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:24:44",AGE,32.00301199932335,true + 1195293,"06/20/2010, 20:24:44",HR,107.7,true + 1195293,"06/20/2010, 20:24:44",TEMP,100.0,true + 1195293,"06/20/2010, 
20:41:33","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 20:41:33",AGE,32.003043973286765,true + 1195293,"06/20/2010, 20:41:33",HR,107.5,true + 1195293,"06/20/2010, 20:41:33",TEMP,100.4,true + 1195293,"06/20/2010, 20:50:04","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 20:50:04",AGE,32.00306016624544,true + 1195293,"06/20/2010, 20:50:04",DISCHARGE,, - "occlude_outliers/train/1": |-2 - {subject_id_field},time,code,numeric_value,numeric_value/is_inlier +"occlude_outliers/train/1": |-2 + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier - "occlude_outliers/tuning/0": |-2 - {subject_id_field},time,code,numeric_value,numeric_value/is_inlier +"occlude_outliers/tuning/0": |-2 + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier - "occlude_outliers/held_out/0": |-2 - {subject_id_field},time,code,numeric_value,numeric_value/is_inlier - 1500733,,EYE_COLOR//BROWN,, - 1500733,,HEIGHT,,false - 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)",, - 1500733,"07/20/1986, 00:00:00",DOB,, - 1500733,"06/03/2010, 14:54:38","TIME_OF_DAY//[12,18)",, - 1500733,"06/03/2010, 14:54:38",AGE,,false - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,, - 1500733,"06/03/2010, 14:54:38",HR,,false - 1500733,"06/03/2010, 14:54:38",TEMP,100.0,true - 1500733,"06/03/2010, 15:39:49","TIME_OF_DAY//[12,18)",, - 1500733,"06/03/2010, 15:39:49",AGE,,false - 1500733,"06/03/2010, 15:39:49",HR,,false - 1500733,"06/03/2010, 15:39:49",TEMP,100.3,true - 1500733,"06/03/2010, 16:20:49","TIME_OF_DAY//[12,18)",, - 1500733,"06/03/2010, 16:20:49",AGE,,false - 1500733,"06/03/2010, 16:20:49",HR,,false - 1500733,"06/03/2010, 16:20:49",TEMP,100.1,true - 1500733,"06/03/2010, 16:44:26","TIME_OF_DAY//[12,18)",, - 1500733,"06/03/2010, 16:44:26",AGE,,false - 1500733,"06/03/2010, 16:44:26",DISCHARGE,, +"occlude_outliers/held_out/0": |-2 + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier + 1500733,,EYE_COLOR//BROWN,, + 1500733,,HEIGHT,,false + 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)",, + 1500733,"07/20/1986, 00:00:00",DOB,, + 1500733,"06/03/2010, 14:54:38","TIME_OF_DAY//[12,18)",, + 1500733,"06/03/2010, 14:54:38",AGE,,false + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,, + 1500733,"06/03/2010, 14:54:38",HR,,false + 1500733,"06/03/2010, 14:54:38",TEMP,100.0,true + 1500733,"06/03/2010, 15:39:49","TIME_OF_DAY//[12,18)",, + 1500733,"06/03/2010, 15:39:49",AGE,,false + 1500733,"06/03/2010, 15:39:49",HR,,false + 1500733,"06/03/2010, 15:39:49",TEMP,100.3,true + 1500733,"06/03/2010, 16:20:49","TIME_OF_DAY//[12,18)",, + 1500733,"06/03/2010, 16:20:49",AGE,,false + 1500733,"06/03/2010, 16:20:49",HR,,false + 1500733,"06/03/2010, 16:20:49",TEMP,100.1,true + 1500733,"06/03/2010, 16:44:26","TIME_OF_DAY//[12,18)",, + 1500733,"06/03/2010, 16:44:26",AGE,,false + 1500733,"06/03/2010, 16:44:26",DISCHARGE,, """ ) @@ -774,93 +774,93 @@ # Note we have dropped the row in the held out shard that doesn't have a code in the vocabulary! 
WANT_NORMALIZATION = parse_shards_yaml( f""" - "normalization/train/0": |-2 - {subject_id_field},time,code,numeric_value - 239684,,6, - 239684,,7, - 239684,"12/28/1980, 00:00:00",10, - 239684,"12/28/1980, 00:00:00",4, - 239684,"05/11/2010, 17:41:51",11, - 239684,"05/11/2010, 17:41:51",2, - 239684,"05/11/2010, 17:41:51",1, - 239684,"05/11/2010, 17:41:51",8, - 239684,"05/11/2010, 17:41:51",9, - 239684,"05/11/2010, 17:48:48",11, - 239684,"05/11/2010, 17:48:48",2, - 239684,"05/11/2010, 17:48:48",8, - 239684,"05/11/2010, 17:48:48",9, - 239684,"05/11/2010, 18:25:35",12, - 239684,"05/11/2010, 18:25:35",2, - 239684,"05/11/2010, 18:25:35",8,0.9341503977775574 - 239684,"05/11/2010, 18:25:35",9, - 239684,"05/11/2010, 18:57:18",12, - 239684,"05/11/2010, 18:57:18",2, - 239684,"05/11/2010, 18:57:18",8,0.6264293789863586 - 239684,"05/11/2010, 18:57:18",9, - 239684,"05/11/2010, 19:27:19",12, - 239684,"05/11/2010, 19:27:19",2, - 239684,"05/11/2010, 19:27:19",3, - 1195293,,5, - 1195293,,7, - 1195293,"06/20/1978, 00:00:00",10, - 1195293,"06/20/1978, 00:00:00",4, - 1195293,"06/20/2010, 19:23:52",12, - 1195293,"06/20/2010, 19:23:52",2,nan - 1195293,"06/20/2010, 19:23:52",1, - 1195293,"06/20/2010, 19:23:52",8,-0.7583094239234924 - 1195293,"06/20/2010, 19:23:52",9,-0.0889078751206398 - 1195293,"06/20/2010, 19:25:32",12, - 1195293,"06/20/2010, 19:25:32",2,nan - 1195293,"06/20/2010, 19:25:32",8,1.2034040689468384 - 1195293,"06/20/2010, 19:25:32",9,-0.0889078751206398 - 1195293,"06/20/2010, 19:45:19",12, - 1195293,"06/20/2010, 19:45:19",2,nan - 1195293,"06/20/2010, 19:45:19",8, - 1195293,"06/20/2010, 19:45:19",9,-0.6222330927848816 - 1195293,"06/20/2010, 20:12:31",12, - 1195293,"06/20/2010, 20:12:31",2,nan - 1195293,"06/20/2010, 20:12:31",8,0.5879650115966797 - 1195293,"06/20/2010, 20:12:31",9,-1.1555582284927368 - 1195293,"06/20/2010, 20:24:44",12 - 1195293,"06/20/2010, 20:24:44",2,nan - 1195293,"06/20/2010, 20:24:44",8,-1.2583553791046143 - 1195293,"06/20/2010, 20:24:44",9,-0.0889078751206398 - 1195293,"06/20/2010, 20:41:33",12, - 1195293,"06/20/2010, 20:41:33",2,nan - 1195293,"06/20/2010, 20:41:33",8,-1.3352841138839722 - 1195293,"06/20/2010, 20:41:33",9,2.04443359375 - 1195293,"06/20/2010, 20:50:04",12, - 1195293,"06/20/2010, 20:50:04",2,nan - 1195293,"06/20/2010, 20:50:04",3, +"normalization/train/0": |-2 + {subject_id_field},time,code,numeric_value + 239684,,6, + 239684,,7, + 239684,"12/28/1980, 00:00:00",10, + 239684,"12/28/1980, 00:00:00",4, + 239684,"05/11/2010, 17:41:51",11, + 239684,"05/11/2010, 17:41:51",2, + 239684,"05/11/2010, 17:41:51",1, + 239684,"05/11/2010, 17:41:51",8, + 239684,"05/11/2010, 17:41:51",9, + 239684,"05/11/2010, 17:48:48",11, + 239684,"05/11/2010, 17:48:48",2, + 239684,"05/11/2010, 17:48:48",8, + 239684,"05/11/2010, 17:48:48",9, + 239684,"05/11/2010, 18:25:35",12, + 239684,"05/11/2010, 18:25:35",2, + 239684,"05/11/2010, 18:25:35",8,0.9341503977775574 + 239684,"05/11/2010, 18:25:35",9, + 239684,"05/11/2010, 18:57:18",12, + 239684,"05/11/2010, 18:57:18",2, + 239684,"05/11/2010, 18:57:18",8,0.6264293789863586 + 239684,"05/11/2010, 18:57:18",9, + 239684,"05/11/2010, 19:27:19",12, + 239684,"05/11/2010, 19:27:19",2, + 239684,"05/11/2010, 19:27:19",3, + 1195293,,5, + 1195293,,7, + 1195293,"06/20/1978, 00:00:00",10, + 1195293,"06/20/1978, 00:00:00",4, + 1195293,"06/20/2010, 19:23:52",12, + 1195293,"06/20/2010, 19:23:52",2,nan + 1195293,"06/20/2010, 19:23:52",1, + 1195293,"06/20/2010, 19:23:52",8,-0.7583094239234924 + 1195293,"06/20/2010, 19:23:52",9,-0.0889078751206398 + 1195293,"06/20/2010, 
19:25:32",12, + 1195293,"06/20/2010, 19:25:32",2,nan + 1195293,"06/20/2010, 19:25:32",8,1.2034040689468384 + 1195293,"06/20/2010, 19:25:32",9,-0.0889078751206398 + 1195293,"06/20/2010, 19:45:19",12, + 1195293,"06/20/2010, 19:45:19",2,nan + 1195293,"06/20/2010, 19:45:19",8, + 1195293,"06/20/2010, 19:45:19",9,-0.6222330927848816 + 1195293,"06/20/2010, 20:12:31",12, + 1195293,"06/20/2010, 20:12:31",2,nan + 1195293,"06/20/2010, 20:12:31",8,0.5879650115966797 + 1195293,"06/20/2010, 20:12:31",9,-1.1555582284927368 + 1195293,"06/20/2010, 20:24:44",12 + 1195293,"06/20/2010, 20:24:44",2,nan + 1195293,"06/20/2010, 20:24:44",8,-1.2583553791046143 + 1195293,"06/20/2010, 20:24:44",9,-0.0889078751206398 + 1195293,"06/20/2010, 20:41:33",12, + 1195293,"06/20/2010, 20:41:33",2,nan + 1195293,"06/20/2010, 20:41:33",8,-1.3352841138839722 + 1195293,"06/20/2010, 20:41:33",9,2.04443359375 + 1195293,"06/20/2010, 20:50:04",12, + 1195293,"06/20/2010, 20:50:04",2,nan + 1195293,"06/20/2010, 20:50:04",3, - "normalization/train/1": |-2 - {subject_id_field},time,code,numeric_value +"normalization/train/1": |-2 + {subject_id_field},time,code,numeric_value - "normalization/tuning/0": |-2 - {subject_id_field},time,code,numeric_value +"normalization/tuning/0": |-2 + {subject_id_field},time,code,numeric_value - "normalization/held_out/0": |-2 - {subject_id_field},time,code,numeric_value - 1500733,,6, - 1500733,,7, - 1500733,"07/20/1986, 00:00:00",10, - 1500733,"07/20/1986, 00:00:00",4, - 1500733,"06/03/2010, 14:54:38",11, - 1500733,"06/03/2010, 14:54:38",2, - 1500733,"06/03/2010, 14:54:38",8, - 1500733,"06/03/2010, 14:54:38",9,-0.0889078751206398 - 1500733,"06/03/2010, 15:39:49",11, - 1500733,"06/03/2010, 15:39:49",2, - 1500733,"06/03/2010, 15:39:49",8, - 1500733,"06/03/2010, 15:39:49",9,1.5111083984375 - 1500733,"06/03/2010, 16:20:49",11, - 1500733,"06/03/2010, 16:20:49",2, - 1500733,"06/03/2010, 16:20:49",8, - 1500733,"06/03/2010, 16:20:49",9,0.4444173276424408 - 1500733,"06/03/2010, 16:44:26",11, - 1500733,"06/03/2010, 16:44:26",2, - 1500733,"06/03/2010, 16:44:26",3, - """, +"normalization/held_out/0": |-2 + {subject_id_field},time,code,numeric_value + 1500733,,6, + 1500733,,7, + 1500733,"07/20/1986, 00:00:00",10, + 1500733,"07/20/1986, 00:00:00",4, + 1500733,"06/03/2010, 14:54:38",11, + 1500733,"06/03/2010, 14:54:38",2, + 1500733,"06/03/2010, 14:54:38",8, + 1500733,"06/03/2010, 14:54:38",9,-0.0889078751206398 + 1500733,"06/03/2010, 15:39:49",11, + 1500733,"06/03/2010, 15:39:49",2, + 1500733,"06/03/2010, 15:39:49",8, + 1500733,"06/03/2010, 15:39:49",9,1.5111083984375 + 1500733,"06/03/2010, 16:20:49",11, + 1500733,"06/03/2010, 16:20:49",2, + 1500733,"06/03/2010, 16:20:49",8, + 1500733,"06/03/2010, 16:20:49",9,0.4444173276424408 + 1500733,"06/03/2010, 16:44:26",11, + 1500733,"06/03/2010, 16:44:26",2, + 1500733,"06/03/2010, 16:44:26",3, + """, code=pl.UInt8, ) diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 6a1d4ab8..60aa68d9 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -17,10 +17,7 @@ import polars as pl from meds import subject_id_field -from tests.utils import FILE_T, multi_stage_tester, parse_meds_csvs, parse_shards_yaml, single_stage_tester - -# So it can be imported from here -parse_shards_yaml = parse_shards_yaml +from tests.utils import FILE_T, multi_stage_tester, parse_shards_yaml, single_stage_tester # Test MEDS data (inputs) @@ -37,94 +34,84 @@ "held_out": [1500733], } 
-MEDS_TRAIN_0 = """ -subject_id,time,code,numeric_value -239684,,EYE_COLOR//BROWN, -239684,,HEIGHT,175.271115221764 -239684,"12/28/1980, 00:00:00",DOB, -239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, -239684,"05/11/2010, 17:41:51",HR,102.6 -239684,"05/11/2010, 17:41:51",TEMP,96.0 -239684,"05/11/2010, 17:48:48",HR,105.1 -239684,"05/11/2010, 17:48:48",TEMP,96.2 -239684,"05/11/2010, 18:25:35",HR,113.4 -239684,"05/11/2010, 18:25:35",TEMP,95.8 -239684,"05/11/2010, 18:57:18",HR,112.6 -239684,"05/11/2010, 18:57:18",TEMP,95.5 -239684,"05/11/2010, 19:27:19",DISCHARGE, -1195293,,EYE_COLOR//BLUE, -1195293,,HEIGHT,164.6868838269085 -1195293,"06/20/1978, 00:00:00",DOB, -1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, -1195293,"06/20/2010, 19:23:52",HR,109.0 -1195293,"06/20/2010, 19:23:52",TEMP,100.0 -1195293,"06/20/2010, 19:25:32",HR,114.1 -1195293,"06/20/2010, 19:25:32",TEMP,100.0 -1195293,"06/20/2010, 19:45:19",HR,119.8 -1195293,"06/20/2010, 19:45:19",TEMP,99.9 -1195293,"06/20/2010, 20:12:31",HR,112.5 -1195293,"06/20/2010, 20:12:31",TEMP,99.8 -1195293,"06/20/2010, 20:24:44",HR,107.7 -1195293,"06/20/2010, 20:24:44",TEMP,100.0 -1195293,"06/20/2010, 20:41:33",HR,107.5 -1195293,"06/20/2010, 20:41:33",TEMP,100.4 -1195293,"06/20/2010, 20:50:04",DISCHARGE, -""" - -MEDS_TRAIN_1 = """ -subject_id,time,code,numeric_value -68729,,EYE_COLOR//HAZEL, -68729,,HEIGHT,160.3953106166676 -68729,"03/09/1978, 00:00:00",DOB, -68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, -68729,"05/26/2010, 02:30:56",HR,86.0 -68729,"05/26/2010, 02:30:56",TEMP,97.8 -68729,"05/26/2010, 04:51:52",DISCHARGE, -814703,,EYE_COLOR//HAZEL, -814703,,HEIGHT,156.48559093209357 -814703,"03/28/1976, 00:00:00",DOB, -814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, -814703,"02/05/2010, 05:55:39",HR,170.2 -814703,"02/05/2010, 05:55:39",TEMP,100.1 -814703,"02/05/2010, 07:02:30",DISCHARGE, -""" - -MEDS_TUNING_0 = """ -subject_id,time,code,numeric_value -754281,,EYE_COLOR//BROWN, -754281,,HEIGHT,166.22261567137025 -754281,"12/19/1988, 00:00:00",DOB, -754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, -754281,"01/03/2010, 06:27:59",HR,142.0 -754281,"01/03/2010, 06:27:59",TEMP,99.8 -754281,"01/03/2010, 08:22:13",DISCHARGE, -""" - -MEDS_HELD_OUT_0 = """ -subject_id,time,code,numeric_value -1500733,,EYE_COLOR//BROWN, -1500733,,HEIGHT,158.60131573580904 -1500733,"07/20/1986, 00:00:00",DOB, -1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, -1500733,"06/03/2010, 14:54:38",HR,91.4 -1500733,"06/03/2010, 14:54:38",TEMP,100.0 -1500733,"06/03/2010, 15:39:49",HR,84.4 -1500733,"06/03/2010, 15:39:49",TEMP,100.3 -1500733,"06/03/2010, 16:20:49",HR,90.1 -1500733,"06/03/2010, 16:20:49",TEMP,100.1 -1500733,"06/03/2010, 16:44:26",DISCHARGE, -""" - -MEDS_SHARDS = parse_meds_csvs( - { - "train/0": MEDS_TRAIN_0, - "train/1": MEDS_TRAIN_1, - "tuning/0": MEDS_TUNING_0, - "held_out/0": MEDS_HELD_OUT_0, - } +MEDS_SHARDS = parse_shards_yaml( + """ +train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 
1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + +train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + +tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + +held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ ) - MEDS_CODE_METADATA_CSV = """ code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes ,44,4,28,3198.8389005974336,382968.28937288234,, diff --git a/tests/utils.py b/tests/utils.py index 450a86b4..400dd2eb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -57,7 +57,7 @@ def reader(csv_str: str) -> pl.DataFrame: def parse_shards_yaml(yaml_str: str, **schema_updates) -> pl.DataFrame: schema = {**MEDS_PL_SCHEMA, **schema_updates} - return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema) + return parse_meds_csvs(load_yaml(yaml_str.strip(), Loader=Loader), schema=schema) def dict_to_hydra_kwargs(d: dict[str, str]) -> str: @@ -271,6 +271,10 @@ def check_NRT_output( FILE_T = pl.DataFrame | dict[str, Any] | str +def add_params(templ_str: str, **kwargs): + return templ_str.format(**kwargs) + + @contextmanager def input_dataset(input_files: dict[str, FILE_T] | None = None): with tempfile.TemporaryDirectory() as d: @@ -294,6 +298,12 @@ def input_dataset(input_files: dict[str, FILE_T] | None = None): fp.write_text(json.dumps(data)) case str(): fp.write_text(data.strip()) + case _ if callable(data): + data_str = data( + input_dir=str(input_dir.resolve()), + cohort_dir=str(cohort_dir.resolve()), + ) + fp.write_text(data_str) case _: raise ValueError(f"Unknown data type {type(data)} for file {fp.relative_to(input_dir)}") 
@@ -356,7 +366,7 @@ def check_outputs( def single_stage_tester( script: str | Path, - stage_name: str, + stage_name: str | None, stage_kwargs: dict[str, str] | None, do_pass_stage_name: bool = False, do_use_config_yaml: bool = False, @@ -366,8 +376,13 @@ def single_stage_tester( config_name: str = "preprocess", input_files: dict[str, FILE_T] | None = None, df_check_kwargs: dict | None = None, + test_name: str | None = None, + do_include_dirs: bool = True, **pipeline_kwargs, ): + if test_name is None: + test_name = f"Single stage transform: {stage_name}" + if df_check_kwargs is None: df_check_kwargs = {} @@ -377,20 +392,23 @@ def single_stage_tester( pipeline_kwargs[k] = v.format(input_dir=str(input_dir.resolve())) pipeline_config_kwargs = { - "input_dir": str(input_dir.resolve()), - "cohort_dir": str(cohort_dir.resolve()), - "stages": [stage_name], "hydra.verbose": True, **pipeline_kwargs, } + if do_include_dirs: + pipeline_config_kwargs["input_dir"] = str(input_dir.resolve()) + pipeline_config_kwargs["cohort_dir"] = str(cohort_dir.resolve()) + + if stage_name is not None: + pipeline_config_kwargs["stages"] = [stage_name] if stage_kwargs: pipeline_config_kwargs["stage_configs"] = {stage_name: stage_kwargs} run_command_kwargs = { "script": script, "hydra_kwargs": pipeline_config_kwargs, - "test_name": f"Single stage transform: {stage_name}", + "test_name": test_name, "should_error": should_error, "config_name": config_name, "do_use_config_yaml": do_use_config_yaml, From 13af9fbb24c9953b7a079ebcc274542ab276bf3f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 22:21:12 -0400 Subject: [PATCH 31/62] Fixed issues with hydra recursion issues and have things almost working; caught a possible issue with split file name. --- pyproject.toml | 3 + src/MEDS_transforms/__init__.py | 8 +- .../configs/{extract.yaml => _extract.yaml} | 2 +- .../configs/{pipeline.yaml => _pipeline.yaml} | 0 .../{preprocess.yaml => _preprocess.yaml} | 2 +- .../configs/{runner.yaml => _runner.yaml} | 14 +- src/MEDS_transforms/mapreduce/mapper.py | 4 +- src/MEDS_transforms/runner.py | 69 ++++++-- .../MEDS_Transforms/transform_tester_base.py | 26 +-- tests/__init__.py | 14 ++ tests/test_with_runner.py | 160 ++++++++++++++++++ tests/utils.py | 3 + 12 files changed, 260 insertions(+), 45 deletions(-) rename src/MEDS_transforms/configs/{extract.yaml => _extract.yaml} (99%) rename src/MEDS_transforms/configs/{pipeline.yaml => _pipeline.yaml} (100%) rename src/MEDS_transforms/configs/{preprocess.yaml => _preprocess.yaml} (98%) rename src/MEDS_transforms/configs/{runner.yaml => _runner.yaml} (65%) create mode 100644 tests/test_with_runner.py diff --git a/pyproject.toml b/pyproject.toml index c9f49069..ba6d30f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,9 @@ MEDS_transform-occlude_outliers = "MEDS_transforms.transforms.occlude_outliers:m MEDS_transform-tensorization = "MEDS_transforms.transforms.tensorization:main" MEDS_transform-tokenization = "MEDS_transforms.transforms.tokenization:main" +# Runner +MEDS_transform-runner = "MEDS_transforms.runner:main" + [project.urls] Homepage = "https://github.com/mmcdermott/MEDS_transforms" Issues = "https://github.com/mmcdermott/MEDS_transforms/issues" diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index 6609376b..2d62ae87 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -10,12 +10,14 @@ except PackageNotFoundError: # pragma: no cover __version__ = "unknown" -PREPROCESS_CONFIG_YAML = 
files(__package_name__).joinpath("configs/preprocess.yaml") -EXTRACT_CONFIG_YAML = files(__package_name__).joinpath("configs/extract.yaml") -RUNNER_CONFIG_YAML = files(__package_name__).joinpath("configs/runner.yaml") +PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/_preprocess.yaml") +EXTRACT_CONFIG_YAML = files(__package_name__).joinpath("configs/_extract.yaml") +RUNNER_CONFIG_YAML = files(__package_name__).joinpath("configs/_runner.yaml") MANDATORY_COLUMNS = [subject_id_field, time_field, code_field, "numeric_value"] +RESERVED_CONFIG_NAMES = {c.stem for c in (PREPROCESS_CONFIG_YAML, EXTRACT_CONFIG_YAML, RUNNER_CONFIG_YAML)} + MANDATORY_TYPES = { subject_id_field: pl.Int64, time_field: pl.Datetime("us"), diff --git a/src/MEDS_transforms/configs/extract.yaml b/src/MEDS_transforms/configs/_extract.yaml similarity index 99% rename from src/MEDS_transforms/configs/extract.yaml rename to src/MEDS_transforms/configs/_extract.yaml index 3abd498e..dec483d2 100644 --- a/src/MEDS_transforms/configs/extract.yaml +++ b/src/MEDS_transforms/configs/_extract.yaml @@ -1,5 +1,5 @@ defaults: - - pipeline + - _pipeline - stage_configs: - shard_events - split_and_shard_subjects diff --git a/src/MEDS_transforms/configs/pipeline.yaml b/src/MEDS_transforms/configs/_pipeline.yaml similarity index 100% rename from src/MEDS_transforms/configs/pipeline.yaml rename to src/MEDS_transforms/configs/_pipeline.yaml diff --git a/src/MEDS_transforms/configs/preprocess.yaml b/src/MEDS_transforms/configs/_preprocess.yaml similarity index 98% rename from src/MEDS_transforms/configs/preprocess.yaml rename to src/MEDS_transforms/configs/_preprocess.yaml index dab87a9a..6ebafdc3 100644 --- a/src/MEDS_transforms/configs/preprocess.yaml +++ b/src/MEDS_transforms/configs/_preprocess.yaml @@ -1,5 +1,5 @@ defaults: - - pipeline + - _pipeline - stage_configs: - reshard_to_split - filter_subjects diff --git a/src/MEDS_transforms/configs/runner.yaml b/src/MEDS_transforms/configs/_runner.yaml similarity index 65% rename from src/MEDS_transforms/configs/runner.yaml rename to src/MEDS_transforms/configs/_runner.yaml index 3cf8fa66..9c1ec89b 100644 --- a/src/MEDS_transforms/configs/runner.yaml +++ b/src/MEDS_transforms/configs/_runner.yaml @@ -2,20 +2,16 @@ pipeline_config_fp: ??? 
stage_runner_fp: null -_pipeline_config: ${oc.create:${load_yaml_file:${oc.select:pipeline_config_fp,null}}} -_etl_metadata: ${oc.select:_pipeline_config.etl_metadata,${oc.create:{}}} - -_pipeline_name: ${oc.select:_etl_metadata.pipeline_name, "MEDS-transforms Pipeline"} -_pipeline_description: ${oc.select:_pipeline_config.description, "No description provided."} +_local_pipeline_config: ${oc.create:${load_yaml_file:${pipeline_config_fp}}} +_stage_runners: ${oc.create:${load_yaml_file:${stage_runner_fp}}} -log_dir: "${_pipeline_config.cohort_dir}/.logs" +log_dir: "${_local_pipeline_config.cohort_dir}/.logs" -_stage_runners: ${oc.create:${load_yaml_file:${stage_runner_fp}}} -stages: ${_pipeline_config.stages} +_pipeline_name: ${oc.select:_local_pipeline_config.pipeline_name, "MEDS-transforms Pipeline"} +_pipeline_description: ${oc.select:_pipeline_config.description, "No description provided."} do_profile: False -# Hydra #${oc.select:_pipeline_help_block,""} hydra: job: name: "${fix_str_for_path:${_pipeline_name}}_runner_${now:%Y-%m-%d_%H-%M-%S}" diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 6cc44e85..6247be60 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -621,11 +621,11 @@ def map_over( start = datetime.now() train_only = cfg.stage_cfg.get("train_only", False) - split_fp = Path(cfg.input_dir) / "metadata" / "subject_split.parquet" shards, includes_only_train = shard_iterator_fntr(cfg) if train_only: + split_fp = Path(cfg.input_dir) / "metadata" / "subject_splits.parquet" if includes_only_train: logger.info( f"Processing train split only via shard prefix. Not filtering with {str(split_fp.resolve())}." @@ -636,7 +636,7 @@ def map_over( pl.scan_parquet(split_fp) .filter(pl.col("split") == "train") .select(subject_id_field) - .collect() + .collect()[subject_id_field] .to_list() ) read_fn = read_and_filter_fntr(train_subjects, read_fn) diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index 7d614c1d..afd68902 100644 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -8,13 +8,20 @@ """ import importlib +import subprocess from pathlib import Path -from typing import Any import hydra +import yaml +from loguru import logger from omegaconf import DictConfig, OmegaConf -from MEDS_transforms import RUNNER_CONFIG_YAML +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader + +from MEDS_transforms import RESERVED_CONFIG_NAMES, RUNNER_CONFIG_YAML from MEDS_transforms.utils import hydra_loguru_init @@ -49,7 +56,7 @@ def get_parallelization_args( ) -> list[str]: """Gets the parallelization args.""" - if parallelization_cfg is None: + if parallelization_cfg is None or len(parallelization_cfg) == 0: return [] if "n_workers" in parallelization_cfg: @@ -61,7 +68,7 @@ def get_parallelization_args( parallelization_args = [ "--multirun", - f"worker=range(0,{n_workers})", + f'worker="range(0,{n_workers})"', ] if "launcher" in parallelization_cfg: @@ -105,7 +112,7 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic do_profile = cfg.get("do_profile", False) pipeline_config_fp = Path(cfg.pipeline_config_fp) - stage_config = pipeline_config_fp.stage_configs.get("stage", {}) + stage_config = cfg._local_pipeline_config.stage_configs.get(stage_name, {}) stage_runner_config = cfg._stage_runners.get(stage_name, {}) script = None @@ -120,7 +127,7 @@ def run_stage(cfg: DictConfig, stage_name: str, 
default_parallelization_cfg: dic command_parts = [ script, - f"--config-path={str(pipeline_config_fp.parent.resolve())}", + f"--config-dir={str(pipeline_config_fp.parent.resolve())}", f"--config-name={pipeline_config_fp.stem}", "'hydra.searchpath=[pkg://MEDS_transforms.configs]'", f"stage={stage_name}", @@ -142,9 +149,13 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic stderr = command_out.stderr.decode() stdout = command_out.stdout.decode() + logger.info(f"Command output:\n{stdout}") + logger.info(f"Command error:\n{stderr}") if command_out.returncode != 0: - raise ValueError(f"Stage {stage_name} failed with return code {command_out.returncode}.\n{stderr}") + raise ValueError( + f"Stage {stage_name} failed via {full_cmd} with return code {command_out.returncode}." + ) @hydra.main( @@ -164,10 +175,20 @@ def main(cfg: DictConfig): raise FileNotFoundError(f"Pipeline configuration file {pipeline_config_fp} does not exist.") if not pipeline_config_fp.suffix == ".yaml": raise ValueError(f"Pipeline configuration file {pipeline_config_fp} must have a .yaml extension.") + if pipeline_config_fp.stem in RESERVED_CONFIG_NAMES: + raise ValueError( + f"Pipeline configuration file {pipeline_config_fp} must not have a name in " + f"{RESERVED_CONFIG_NAMES}." + ) - logs_dir = Path(cfg.logs_dir) + pipeline_config = load_yaml_file(cfg.pipeline_config_fp) + stages = pipeline_config.get("stages", []) + if not stages: + raise ValueError("Pipeline configuration must specify at least one stage.") - if do_profile: + log_dir = Path(cfg.log_dir) + + if cfg.get("do_profile", False): try: pass except ImportError as e: @@ -177,7 +198,7 @@ def main(cfg: DictConfig): "`pip install MEDS-transforms[profiler]`." ) from e - global_done_file = logs_dir / f"_all_stages.done" + global_done_file = log_dir / "_all_stages.done" if global_done_file.exists(): logger.info("All stages are already complete. Exiting.") return @@ -187,8 +208,8 @@ def main(cfg: DictConfig): else: default_parallelization_cfg = None - for stage in cfg.stages: - done_file = logs_dir / f"{stage}.done" + for stage in stages: + done_file = log_dir / f"{stage}.done" if done_file.exists(): logger.info(f"Skipping stage {stage} as it is already complete.") @@ -200,17 +221,29 @@ def main(cfg: DictConfig): global_done_file.touch() -def load_file(path: str) -> Any: - with open(path) as f: - return f.read() +def load_yaml_file(path: str | None) -> dict | DictConfig: + if not path: + return {} + + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File {path} does not exist.") + + try: + return OmegaConf.load(path) + except Exception as e: + logger.warning(f"Failed to load {path} as an OmegaConf: {e}. 
Trying as a plain YAML file.") + yaml_text = path.read_text() + return yaml.load(yaml_text, Loader=Loader) -def load_yaml_file(path: str | None) -> dict | DictConfig: - return OmegaConf.load(path) if path else {} +def fix_str_for_path(s: str) -> str: + """Replaces all space characters with underscores and all slashes with periods.""" + return s.replace(" ", "_").replace("/", ".") if __name__ == "__main__": OmegaConf.register_new_resolver("load_yaml_file", load_yaml_file, replace=False) - OmegaConf.register_new_resolver("load_file", load_file, replace=False) + OmegaConf.register_new_resolver("fix_str_for_path", fix_str_for_path, replace=False) main() diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 60aa68d9..93e70aa9 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -28,11 +28,12 @@ "held_out/0": [1500733], } -SPLITS = { - "train": [239684, 1195293, 68729, 814703], - "tuning": [754281], - "held_out": [1500733], -} +SPLITS_DF = pl.DataFrame( + { + subject_id_field: [239684, 1195293, 68729, 814703, 754281, 1500733], + "split": ["train", "train", "train", "train", "tuning", "held_out"], + } +) MEDS_SHARDS = parse_shards_yaml( """ @@ -184,14 +185,17 @@ def remap_inputs_for_transform( unified_inputs["metadata/.shards.json"] = input_shards_map if input_splits_map is None: - input_splits_map = SPLITS + input_splits_map = SPLITS_DF - input_splits_as_df = defaultdict(list) - for split_name, subject_ids in input_splits_map.items(): - input_splits_as_df[subject_id_field].extend(subject_ids) - input_splits_as_df["split"].extend([split_name] * len(subject_ids)) + if isinstance(input_splits_map, pl.DataFrame): + input_splits_df = input_splits_map + else: + input_splits_as_df = defaultdict(list) + for split_name, subject_ids in input_splits_map.items(): + input_splits_as_df[subject_id_field].extend(subject_ids) + input_splits_as_df["split"].extend([split_name] * len(subject_ids)) - input_splits_df = pl.DataFrame(input_splits_as_df) + input_splits_df = pl.DataFrame(input_splits_as_df) unified_inputs["metadata/subject_splits.parquet"] = input_splits_df diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..89520379 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,14 @@ +import os + +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +code_root = root / "src" / "MEDS_transforms" + +USE_LOCAL_SCRIPTS = os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1" + +if USE_LOCAL_SCRIPTS: + RUNNER_SCRIPT = code_root / "runner.py" +else: + RUNNER_SCRIPT = "MEDS_transform-runner" diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py new file mode 100644 index 00000000..42eb7ff2 --- /dev/null +++ b/tests/test_with_runner.py @@ -0,0 +1,160 @@ +"""Tests a multi-stage pre-processing pipeline via the Runner utility. Only checks final outputs. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+ +In this test, the following stages are run: + - filter_subjects + - add_time_derived_measurements + - fit_outlier_detection + - occlude_outliers + - fit_normalization + - fit_vocabulary_indices + - normalization + - tokenization + - tensorization + +The stage configuration arguments will be as given in the yaml block below: +""" + + +from functools import partial + +from tests import RUNNER_SCRIPT, USE_LOCAL_SCRIPTS +from tests.MEDS_Transforms import ( + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, + AGGREGATE_CODE_METADATA_SCRIPT, + FILTER_SUBJECTS_SCRIPT, + FIT_VOCABULARY_INDICES_SCRIPT, + NORMALIZATION_SCRIPT, + OCCLUDE_OUTLIERS_SCRIPT, + TOKENIZATION_SCRIPT, +) +from tests.MEDS_Transforms.test_multi_stage_preprocess_pipeline import ( + MEDS_CODE_METADATA, + WANT_FILTER, + WANT_FIT_NORMALIZATION, + WANT_FIT_OUTLIERS, + WANT_FIT_VOCABULARY_INDICES, + WANT_NORMALIZATION, + WANT_OCCLUDE_OUTLIERS, + WANT_TIME_DERIVED, + WANT_TOKENIZATION_EVENT_SEQS, + WANT_TOKENIZATION_SCHEMAS, + WANT_NRTs, +) +from tests.MEDS_Transforms.transform_tester_base import MEDS_SHARDS, SPLITS_DF +from tests.utils import add_params, single_stage_tester + + +def scriptify(s: str) -> str: + return f"python {s}" if USE_LOCAL_SCRIPTS else s + + +STAGE_RUNNER_YAML = f""" +filter_subjects: + script: {scriptify(FILTER_SUBJECTS_SCRIPT)} + +add_time_derived_measurements: + script: {scriptify(ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT)} + +fit_outlier_detection: + script: {scriptify(AGGREGATE_CODE_METADATA_SCRIPT)} + +occlude_outliers: + script: {scriptify(OCCLUDE_OUTLIERS_SCRIPT)} + +fit_normalization: + script: {scriptify(AGGREGATE_CODE_METADATA_SCRIPT)} + +fit_vocabulary_indices: + script: {scriptify(FIT_VOCABULARY_INDICES_SCRIPT)} + +normalization: + script: {scriptify(NORMALIZATION_SCRIPT)} + +tokenization: + script: {scriptify(TOKENIZATION_SCRIPT)} +""" + +PIPELINE_YAML = """ +defaults: + - _preprocess + - _self_ + +input_dir: {input_dir} +cohort_dir: {cohort_dir} + +stages: + - filter_subjects + - add_time_derived_measurements + - fit_outlier_detection + - occlude_outliers + - fit_normalization + - fit_vocabulary_indices + - normalization + - tokenization + - tensorization + +stage_configs: + filter_subjects: + min_events_per_subject: 5 + add_time_derived_measurements: + age: + DOB_code: "DOB" # This is the MEDS official code for BIRTH + age_code: "AGE" + age_unit: "years" + time_of_day: + time_of_day_code: "TIME_OF_DAY" + endpoints: [6, 12, 18, 24] + fit_outlier_detection: + aggregations: + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" + occlude_outliers: + stddev_cutoff: 1 + fit_normalization: + aggregations: + - "code/n_occurrences" + - "code/n_subjects" + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" +""" + + +def test_pipeline(): + single_stage_tester( + script=RUNNER_SCRIPT, + config_name="runner", + stage_name=None, + stage_kwargs=None, + do_pass_stage_name=False, + do_use_config_yaml=False, + input_files={ + **MEDS_SHARDS, + "metadata/codes.parquet": MEDS_CODE_METADATA, + "metadata/subject_splits.parquet": SPLITS_DF, + "pipeline.yaml": partial(add_params, PIPELINE_YAML), + "stage_runner.yaml": STAGE_RUNNER_YAML, + }, + want_outputs={ + **WANT_FIT_NORMALIZATION, + **WANT_FIT_OUTLIERS, + **WANT_FIT_VOCABULARY_INDICES, + **WANT_FILTER, + **WANT_TIME_DERIVED, + **WANT_OCCLUDE_OUTLIERS, + **WANT_NORMALIZATION, + **WANT_TOKENIZATION_SCHEMAS, + **WANT_TOKENIZATION_EVENT_SEQS, + **WANT_NRTs, + }, + assert_no_other_outputs=False, + should_error=False, + 
pipeline_config_fp="{input_dir}/pipeline.yaml", + stage_runner_fp="{input_dir}/stage_runner.yaml", + test_name="Runner Test", + do_include_dirs=False, + ) diff --git a/tests/utils.py b/tests/utils.py index 400dd2eb..d6fa6378 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -137,6 +137,9 @@ def run_command( err_cmd_lines = [] + if config_name is not None and not config_name.startswith("_"): + config_name = f"_{config_name}" + if do_use_config_yaml: if config_name is None: raise ValueError("config_name must be provided if do_use_config_yaml is True.") From c302c63574fc71559d87c26a9cfdea35443a612d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 22:29:42 -0400 Subject: [PATCH 32/62] Basic runner script working and tests passing. --- src/MEDS_transforms/mapreduce/utils.py | 6 ++++++ tests/test_with_runner.py | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index a653d203..f941eaa4 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -474,6 +474,12 @@ def shard_iterator( shard_name = shard_name[: -len(in_suffix)] shards.append(shard_name) + if not shards: + raise FileNotFoundError( + f"No shards found in {input_dir} with suffix {in_suffix}. Directory contents: " + f"{', '.join(str(p.relative_to(input_dir)) for p in input_dir.glob('**/*'))}" + ) + # We initialize this to False and overwrite it if we find dedicated train shards. includes_only_train = False diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py index 42eb7ff2..abc7be69 100644 --- a/tests/test_with_runner.py +++ b/tests/test_with_runner.py @@ -133,7 +133,7 @@ def test_pipeline(): do_pass_stage_name=False, do_use_config_yaml=False, input_files={ - **MEDS_SHARDS, + **{f"data/{k}": v for k, v in MEDS_SHARDS.items()}, "metadata/codes.parquet": MEDS_CODE_METADATA, "metadata/subject_splits.parquet": SPLITS_DF, "pipeline.yaml": partial(add_params, PIPELINE_YAML), @@ -157,4 +157,5 @@ def test_pipeline(): stage_runner_fp="{input_dir}/stage_runner.yaml", test_name="Runner Test", do_include_dirs=False, + df_check_kwargs={"check_column_order": False}, ) From 980382236ac742628ff98190a44340106afdff3f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 22:35:58 -0400 Subject: [PATCH 33/62] Tried setting up a testing method that also tests automatic discovery of script name. --- tests/test_with_runner.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py index abc7be69..8dcca99e 100644 --- a/tests/test_with_runner.py +++ b/tests/test_with_runner.py @@ -46,36 +46,42 @@ from tests.MEDS_Transforms.transform_tester_base import MEDS_SHARDS, SPLITS_DF from tests.utils import add_params, single_stage_tester - -def scriptify(s: str) -> str: - return f"python {s}" if USE_LOCAL_SCRIPTS else s - - -STAGE_RUNNER_YAML = f""" +# Normally, you wouldn't need to specify all of these scripts, but in testing with local scripts we need to +# specify them all as they need to point to their python paths. 
+if USE_LOCAL_SCRIPTS: + STAGE_RUNNER_YAML = f""" filter_subjects: - script: {scriptify(FILTER_SUBJECTS_SCRIPT)} + script: "python {FILTER_SUBJECTS_SCRIPT}" add_time_derived_measurements: - script: {scriptify(ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT)} + script: "python {ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT}" fit_outlier_detection: - script: {scriptify(AGGREGATE_CODE_METADATA_SCRIPT)} + script: "python {AGGREGATE_CODE_METADATA_SCRIPT}" occlude_outliers: - script: {scriptify(OCCLUDE_OUTLIERS_SCRIPT)} + script: "python {OCCLUDE_OUTLIERS_SCRIPT}" fit_normalization: - script: {scriptify(AGGREGATE_CODE_METADATA_SCRIPT)} + script: "python {AGGREGATE_CODE_METADATA_SCRIPT}" fit_vocabulary_indices: - script: {scriptify(FIT_VOCABULARY_INDICES_SCRIPT)} + script: "python {FIT_VOCABULARY_INDICES_SCRIPT}" normalization: - script: {scriptify(NORMALIZATION_SCRIPT)} + script: "python {NORMALIZATION_SCRIPT}" tokenization: - script: {scriptify(TOKENIZATION_SCRIPT)} -""" + script: "python {TOKENIZATION_SCRIPT}" + """ +else: + STAGE_RUNNER_YAML = f""" +fit_outlier_detection: + script: {AGGREGATE_CODE_METADATA_SCRIPT} + +fit_normalization: + script: {AGGREGATE_CODE_METADATA_SCRIPT} + """ PIPELINE_YAML = """ defaults: From 24b8eeb254067d9c7b3a10329c5cd62b9a09f097 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 22:43:56 -0400 Subject: [PATCH 34/62] Testing script specification in the pipeline config. --- tests/test_with_runner.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py index 8dcca99e..97a30bdb 100644 --- a/tests/test_with_runner.py +++ b/tests/test_with_runner.py @@ -56,9 +56,6 @@ add_time_derived_measurements: script: "python {ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT}" -fit_outlier_detection: - script: "python {AGGREGATE_CODE_METADATA_SCRIPT}" - occlude_outliers: script: "python {OCCLUDE_OUTLIERS_SCRIPT}" @@ -76,20 +73,17 @@ """ else: STAGE_RUNNER_YAML = f""" -fit_outlier_detection: - script: {AGGREGATE_CODE_METADATA_SCRIPT} - fit_normalization: script: {AGGREGATE_CODE_METADATA_SCRIPT} """ -PIPELINE_YAML = """ +PIPELINE_YAML = f""" defaults: - _preprocess - _self_ -input_dir: {input_dir} -cohort_dir: {cohort_dir} +input_dir: {{input_dir}} +cohort_dir: {{cohort_dir}} stages: - filter_subjects @@ -114,6 +108,7 @@ time_of_day_code: "TIME_OF_DAY" endpoints: [6, 12, 18, 24] fit_outlier_detection: + _script: {("python " if USE_LOCAL_SCRIPTS else "") + str(AGGREGATE_CODE_METADATA_SCRIPT)} aggregations: - "values/n_occurrences" - "values/sum" From 5f381543d9cf139cde3f08b93e9a32b79428d884 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 22:51:51 -0400 Subject: [PATCH 35/62] Fixed tests. --- src/MEDS_transforms/mapreduce/utils.py | 7 ++++--- src/MEDS_transforms/runner.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index f941eaa4..716ddc09 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -453,10 +453,11 @@ def shard_iterator( >>> includes_only_train False - If it can't find any files, it will return an empty list: + If it can't find any files, it will error: >>> fps, includes_only_train = shard_iterator(cfg) - >>> fps - [] + Traceback (most recent call last): + ... + FileNotFoundError: No shards found in ... with suffix .parquet. Directory contents:... 
""" input_dir = Path(cfg.stage_cfg.data_input_dir) diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index afd68902..51de2389 100644 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -242,8 +242,8 @@ def fix_str_for_path(s: str) -> str: return s.replace(" ", "_").replace("/", ".") -if __name__ == "__main__": - OmegaConf.register_new_resolver("load_yaml_file", load_yaml_file, replace=False) - OmegaConf.register_new_resolver("fix_str_for_path", fix_str_for_path, replace=False) +OmegaConf.register_new_resolver("load_yaml_file", load_yaml_file, replace=False) +OmegaConf.register_new_resolver("fix_str_for_path", fix_str_for_path, replace=False) +if __name__ == "__main__": main() From b9266a6cc9cf5e2054b83e4357f65e0487d2a41a Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 23:27:53 -0400 Subject: [PATCH 36/62] Added help string checkers. --- src/MEDS_transforms/configs/_runner.yaml | 6 +-- src/MEDS_transforms/runner.py | 4 +- tests/test_with_runner.py | 69 +++++++++++++++++++++++- tests/utils.py | 15 +++++- 4 files changed, 87 insertions(+), 7 deletions(-) mode change 100644 => 100755 src/MEDS_transforms/runner.py diff --git a/src/MEDS_transforms/configs/_runner.yaml b/src/MEDS_transforms/configs/_runner.yaml index 9c1ec89b..1abec8aa 100644 --- a/src/MEDS_transforms/configs/_runner.yaml +++ b/src/MEDS_transforms/configs/_runner.yaml @@ -2,13 +2,13 @@ pipeline_config_fp: ??? stage_runner_fp: null -_local_pipeline_config: ${oc.create:${load_yaml_file:${pipeline_config_fp}}} +_local_pipeline_config: ${oc.create:${load_yaml_file:${oc.select:pipeline_config_fp,null}}} _stage_runners: ${oc.create:${load_yaml_file:${stage_runner_fp}}} log_dir: "${_local_pipeline_config.cohort_dir}/.logs" -_pipeline_name: ${oc.select:_local_pipeline_config.pipeline_name, "MEDS-transforms Pipeline"} -_pipeline_description: ${oc.select:_pipeline_config.description, "No description provided."} +_pipeline_name: ${oc.select:_local_pipeline_config.etl_metadata.pipeline_name,"MEDS-transforms Pipeline"} +_pipeline_description: ${oc.select:_local_pipeline_config.description,"No description provided."} do_profile: False diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py old mode 100644 new mode 100755 index 51de2389..af7edb01 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -242,8 +242,8 @@ def fix_str_for_path(s: str) -> str: return s.replace(" ", "_").replace("/", ".") -OmegaConf.register_new_resolver("load_yaml_file", load_yaml_file, replace=False) -OmegaConf.register_new_resolver("fix_str_for_path", fix_str_for_path, replace=False) +OmegaConf.register_new_resolver("load_yaml_file", load_yaml_file, replace=True) +OmegaConf.register_new_resolver("fix_str_for_path", fix_str_for_path, replace=True) if __name__ == "__main__": main() diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py index 97a30bdb..9ce308b0 100644 --- a/tests/test_with_runner.py +++ b/tests/test_with_runner.py @@ -44,7 +44,7 @@ WANT_NRTs, ) from tests.MEDS_Transforms.transform_tester_base import MEDS_SHARDS, SPLITS_DF -from tests.utils import add_params, single_stage_tester +from tests.utils import add_params, exact_str_regex, single_stage_tester # Normally, you wouldn't need to specify all of these scripts, but in testing with local scripts we need to # specify them all as they need to point to their python paths. 
@@ -85,6 +85,8 @@ input_dir: {{input_dir}} cohort_dir: {{cohort_dir}} +description: "A test pipeline for the MEDS-transforms pipeline runner." + stages: - filter_subjects - add_time_derived_measurements @@ -124,8 +126,73 @@ - "values/sum_sqd" """ +NO_ARGS_HELP_STR = """ +== MEDS-Transforms Pipeline Runner == +MEDS-Transforms Pipeline Runner is a command line tool for running entire MEDS-transform pipelines in a single +command. + +Runs the entire pipeline, end-to-end, based on the configuration provided. + +This script will launch many subsidiary commands via `subprocess`, one for each stage of the specified +pipeline. + +**MEDS-transforms Pipeline description:** + +No description provided. +""" + +WITH_CONFIG_HELP_STR = """ +== MEDS-Transforms Pipeline Runner == +MEDS-Transforms Pipeline Runner is a command line tool for running entire MEDS-transform pipelines in a single +command. + +Runs the entire pipeline, end-to-end, based on the configuration provided. + +This script will launch many subsidiary commands via `subprocess`, one for each stage of the specified +pipeline. + +**preprocess Pipeline description:** + +A test pipeline for the MEDS-transforms pipeline runner. +""" + def test_pipeline(): + single_stage_tester( + script=str(RUNNER_SCRIPT) + " -h", + config_name="runner", + stage_name=None, + stage_kwargs=None, + do_pass_stage_name=False, + do_use_config_yaml=False, + input_files={}, + want_outputs={}, + assert_no_other_outputs=True, + should_error=False, + test_name="Runner Help Test", + do_include_dirs=False, + hydra_verbose=False, + stdout_regex=exact_str_regex(NO_ARGS_HELP_STR.strip()), + ) + + single_stage_tester( + script=str(RUNNER_SCRIPT) + " -h", + config_name="runner", + stage_name=None, + stage_kwargs=None, + do_pass_stage_name=False, + do_use_config_yaml=False, + input_files={"pipeline.yaml": partial(add_params, PIPELINE_YAML)}, + want_outputs={}, + assert_no_other_outputs=True, + should_error=False, + pipeline_config_fp="{input_dir}/pipeline.yaml", + test_name="Runner Help Test", + do_include_dirs=False, + hydra_verbose=False, + stdout_regex=exact_str_regex(WITH_CONFIG_HELP_STR.strip()), + ) + single_stage_tester( script=RUNNER_SCRIPT, config_name="runner", diff --git a/tests/utils.py b/tests/utils.py index d6fa6378..a5018b4a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ import json +import re import subprocess import tempfile from contextlib import contextmanager @@ -31,6 +32,10 @@ } +def exact_str_regex(s: str) -> str: + return f"^{re.escape(s)}$" + + def parse_meds_csvs( csvs: str | dict[str, str], schema: dict[str, pl.DataType] = MEDS_PL_SCHEMA ) -> pl.DataFrame | dict[str, pl.DataFrame]: @@ -381,6 +386,8 @@ def single_stage_tester( df_check_kwargs: dict | None = None, test_name: str | None = None, do_include_dirs: bool = True, + hydra_verbose: bool = True, + stdout_regex: str | None = None, **pipeline_kwargs, ): if test_name is None: @@ -395,7 +402,7 @@ def single_stage_tester( pipeline_kwargs[k] = v.format(input_dir=str(input_dir.resolve())) pipeline_config_kwargs = { - "hydra.verbose": True, + "hydra.verbose": hydra_verbose, **pipeline_kwargs, } @@ -426,6 +433,12 @@ def single_stage_tester( if should_error: return + if stdout_regex is not None: + regex = re.compile(stdout_regex) + assert regex.search(stdout) is not None, ( + f"Expected stdout to match regex:\n{stdout_regex}\n" f"Got:\n{stdout}" + ) + try: check_outputs( cohort_dir, From d642f309224ed69d7c1a024963245a1e22ba8688 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 
Aug 2024 23:33:58 -0400 Subject: [PATCH 37/62] Working help string testers. Had to remove some stuff from the config leveraging because the loaded yaml doesn't resolve defaults yet. --- src/MEDS_transforms/configs/_runner.yaml | 5 ++--- src/MEDS_transforms/runner.py | 6 ------ tests/test_with_runner.py | 2 +- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/MEDS_transforms/configs/_runner.yaml b/src/MEDS_transforms/configs/_runner.yaml index 1abec8aa..f8266788 100644 --- a/src/MEDS_transforms/configs/_runner.yaml +++ b/src/MEDS_transforms/configs/_runner.yaml @@ -7,14 +7,13 @@ _stage_runners: ${oc.create:${load_yaml_file:${stage_runner_fp}}} log_dir: "${_local_pipeline_config.cohort_dir}/.logs" -_pipeline_name: ${oc.select:_local_pipeline_config.etl_metadata.pipeline_name,"MEDS-transforms Pipeline"} _pipeline_description: ${oc.select:_local_pipeline_config.description,"No description provided."} do_profile: False hydra: job: - name: "${fix_str_for_path:${_pipeline_name}}_runner_${now:%Y-%m-%d_%H-%M-%S}" + name: "MEDS-transforms_runner_${now:%Y-%m-%d_%H-%M-%S}" run: dir: "${log_dir}" help: @@ -27,6 +26,6 @@ hydra: ${get_script_docstring:runner} - **${_pipeline_name} description:** + **MEDS-transforms Pipeline description:** ${_pipeline_description} diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index af7edb01..91ffe7e0 100755 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -237,13 +237,7 @@ def load_yaml_file(path: str | None) -> dict | DictConfig: return yaml.load(yaml_text, Loader=Loader) -def fix_str_for_path(s: str) -> str: - """Replaces all space characters with underscores and all slashes with periods.""" - return s.replace(" ", "_").replace("/", ".") - - OmegaConf.register_new_resolver("load_yaml_file", load_yaml_file, replace=True) -OmegaConf.register_new_resolver("fix_str_for_path", fix_str_for_path, replace=True) if __name__ == "__main__": main() diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py index 9ce308b0..0d2cf4f3 100644 --- a/tests/test_with_runner.py +++ b/tests/test_with_runner.py @@ -151,7 +151,7 @@ This script will launch many subsidiary commands via `subprocess`, one for each stage of the specified pipeline. -**preprocess Pipeline description:** +**MEDS-transforms Pipeline description:** A test pipeline for the MEDS-transforms pipeline runner. """ From 9f0e7f35b4ea95f01112449b9d310270fbc72a4d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 23:48:41 -0400 Subject: [PATCH 38/62] Added test with parallelization. 
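The parallel variant of the stage runner config layers a top-level
`parallelize` block (two workers via the joblib launcher) over the
per-stage script overrides. For reference, a minimal sketch of pointing
the runner at such a file from the command line; the paths below are
placeholders, not files in this repo:

```bash
# Placeholder paths; pipeline.yaml and stage_runner.yaml stand in for
# whatever configs the user supplies, as in the new test fixtures.
MEDS_transform-runner \
    pipeline_config_fp=/path/to/pipeline.yaml \
    stage_runner_fp=/path/to/stage_runner.yaml
```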
--- pyproject.toml | 2 +- src/MEDS_transforms/runner.py | 16 ++++++------- tests/test_with_runner.py | 44 +++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ba6d30f9..756ac8ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ [project.optional-dependencies] dev = ["pre-commit"] -tests = ["pytest", "pytest-cov", "rootutils"] +tests = ["pytest", "pytest-cov", "rootutils", "hydra-joblib-launcher"] local_parallelism = ["hydra-joblib-launcher"] slurm_parallelism = ["hydra-submitit-launcher"] docs = [ diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index 91ffe7e0..210a636f 100755 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -18,7 +18,7 @@ try: from yaml import CLoader as Loader -except ImportError: +except ImportError: # pragma: no cover from yaml import Loader from MEDS_transforms import RESERVED_CONFIG_NAMES, RUNNER_CONFIG_YAML @@ -48,7 +48,7 @@ def get_script_from_name(stage_name: str) -> str | None: except ImportError: pass - return None + raise ValueError(f"Could not find a script for stage {stage_name}.") def get_parallelization_args( @@ -120,10 +120,8 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic script = stage_runner_config.script elif "_script" in stage_config: script = stage_config._script - elif get_script_from_name(stage_name): - script = get_script_from_name(stage_name) else: - raise ValueError(f"Cannot determine script for {stage_name}") + script = get_script_from_name(stage_name) command_parts = [ script, @@ -190,7 +188,7 @@ def main(cfg: DictConfig): if cfg.get("do_profile", False): try: - pass + import hydra_profiler # noqa: F401 except ImportError as e: raise ValueError( "You can't run in profiling mode without installing hydra-profiler. Try installing " @@ -203,7 +201,9 @@ def main(cfg: DictConfig): logger.info("All stages are already complete. 
Exiting.") return - if "parallelize" in cfg: + if "parallelize" in cfg._stage_runners: + default_parallelization_cfg = cfg._stage_runners.parallelize + elif "parallelize" in cfg: default_parallelization_cfg = cfg.parallelize else: default_parallelization_cfg = None @@ -239,5 +239,5 @@ def load_yaml_file(path: str | None) -> dict | DictConfig: OmegaConf.register_new_resolver("load_yaml_file", load_yaml_file, replace=True) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py index 0d2cf4f3..5e2a4a18 100644 --- a/tests/test_with_runner.py +++ b/tests/test_with_runner.py @@ -77,6 +77,15 @@ script: {AGGREGATE_CODE_METADATA_SCRIPT} """ +PARALLEL_STAGE_RUNNER_YAML = f""" +parallelize: + n_workers: 2 + launcher: "joblib" + +{STAGE_RUNNER_YAML} +""" + + PIPELINE_YAML = f""" defaults: - _preprocess @@ -227,3 +236,38 @@ def test_pipeline(): do_include_dirs=False, df_check_kwargs={"check_column_order": False}, ) + + single_stage_tester( + script=RUNNER_SCRIPT, + config_name="runner", + stage_name=None, + stage_kwargs=None, + do_pass_stage_name=False, + do_use_config_yaml=False, + input_files={ + **{f"data/{k}": v for k, v in MEDS_SHARDS.items()}, + "metadata/codes.parquet": MEDS_CODE_METADATA, + "metadata/subject_splits.parquet": SPLITS_DF, + "pipeline.yaml": partial(add_params, PIPELINE_YAML), + "stage_runner.yaml": PARALLEL_STAGE_RUNNER_YAML, + }, + want_outputs={ + **WANT_FIT_NORMALIZATION, + **WANT_FIT_OUTLIERS, + **WANT_FIT_VOCABULARY_INDICES, + **WANT_FILTER, + **WANT_TIME_DERIVED, + **WANT_OCCLUDE_OUTLIERS, + **WANT_NORMALIZATION, + **WANT_TOKENIZATION_SCHEMAS, + **WANT_TOKENIZATION_EVENT_SEQS, + **WANT_NRTs, + }, + assert_no_other_outputs=False, + should_error=False, + pipeline_config_fp="{input_dir}/pipeline.yaml", + stage_runner_fp="{input_dir}/stage_runner.yaml", + test_name="Runner Test with parallelism", + do_include_dirs=False, + df_check_kwargs={"check_column_order": False}, + ) From 6fa8781e0f9ae958163bfcb7af6f0008cf86d53b Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 00:04:41 -0400 Subject: [PATCH 39/62] Fixed the parallel test. 
--- src/MEDS_transforms/runner.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index 210a636f..8f82109c 100755 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -56,7 +56,10 @@ def get_parallelization_args( ) -> list[str]: """Gets the parallelization args.""" - if parallelization_cfg is None or len(parallelization_cfg) == 0: + if parallelization_cfg is None: + return [] + + if len(parallelization_cfg) == 0 and len(default_parallelization_cfg) == 0: return [] if "n_workers" in parallelization_cfg: @@ -131,10 +134,14 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic f"stage={stage_name}", ] - command_parts.extend( - get_parallelization_args(stage_runner_config.get("parallelize", {}), default_parallelization_cfg) + parallelization_args = get_parallelization_args( + stage_runner_config.get("parallelize", {}), default_parallelization_cfg ) + if parallelization_args: + multirun = parallelization_args.pop(0) + command_parts = command_parts[:3] + [multirun] + command_parts[3:] + parallelization_args + if do_profile: command_parts.append("++hydra.callbacks.profiler._target_=hydra_profiler.profiler.ProfilerCallback") From dc1f92ba704c7f226234e35cf13c31683fa67c75 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 08:40:17 -0400 Subject: [PATCH 40/62] Use up-to-date MEDS package and filepaths/names. --- pyproject.toml | 2 +- .../extract/finalize_MEDS_metadata.py | 12 +++++++++--- src/MEDS_transforms/mapreduce/mapper.py | 4 ++-- src/MEDS_transforms/reshard_to_split.py | 11 ++++++----- tests/MEDS_Extract/test_extract.py | 4 ++-- tests/MEDS_Extract/test_extract_no_metadata.py | 3 ++- tests/MEDS_Transforms/transform_tester_base.py | 4 ++-- tests/test_with_runner.py | 10 ++++++---- 8 files changed, 30 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 756ac8ea..a9b32925 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.2", + "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.3", ] [tool.setuptools_scm] diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index 65549309..ce10a036 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -12,11 +12,14 @@ from loguru import logger from meds import __version__ as MEDS_VERSION from meds import ( + code_metadata_filepath, code_metadata_schema, + dataset_metadata_filepath, dataset_metadata_schema, held_out_split, subject_id_field, subject_split_schema, + subject_splits_filepath, train_split, tuning_split, ) @@ -150,9 +153,12 @@ def main(cfg: DictConfig): _, _, input_metadata_dir = stage_init(cfg) output_metadata_dir = Path(cfg.stage_cfg.reducer_output_dir) - output_code_metadata_fp = output_metadata_dir / "codes.parquet" - dataset_metadata_fp = output_metadata_dir / "dataset.json" - subject_splits_fp = output_metadata_dir / "subject_splits.parquet" + if output_metadata_dir.parts[-1] != Path(code_metadata_filepath).parts[0]: + raise ValueError(f"Output metadata directory must end in 'metadata'. 
Got {output_metadata_dir}") + + output_code_metadata_fp = output_metadata_dir.parent / code_metadata_filepath + dataset_metadata_fp = output_metadata_dir.parent / dataset_metadata_filepath + subject_splits_fp = output_metadata_dir.parent / subject_splits_filepath for out_fp in [output_code_metadata_fp, dataset_metadata_fp, subject_splits_fp]: out_fp.parent.mkdir(parents=True, exist_ok=True) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 6247be60..ade9910a 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -11,7 +11,7 @@ import hydra import polars as pl from loguru import logger -from meds import subject_id_field +from meds import subject_id_field, subject_splits_filepath from omegaconf import DictConfig, ListConfig from ..parser import is_matcher, matcher_to_expr @@ -625,7 +625,7 @@ def map_over( shards, includes_only_train = shard_iterator_fntr(cfg) if train_only: - split_fp = Path(cfg.input_dir) / "metadata" / "subject_splits.parquet" + split_fp = Path(cfg.input_dir) / subject_splits_filepath if includes_only_train: logger.info( f"Processing train split only via shard prefix. Not filtering with {str(split_fp.resolve())}." diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 0fc06fef..deccc49f 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -10,6 +10,7 @@ import hydra import polars as pl from loguru import logger +from meds import subject_id_field, subject_splits_filepath, time_field from omegaconf import DictConfig from MEDS_transforms import PREPROCESS_CONFIG_YAML @@ -60,7 +61,7 @@ def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) splits_map[sp].append(pt_id) return shard_subjects( - subjects=df["subject_id"].to_numpy(), + subjects=df[subject_id_field].to_numpy(), n_subjects_per_shard=stage_cfg.n_subjects_per_shard, external_splits=splits_map, split_fracs_dict=None, @@ -96,7 +97,7 @@ def main(cfg: DictConfig): output_dir = Path(cfg.stage_cfg.output_dir) - splits_file = Path(cfg.input_dir) / "metadata" / "subject_splits.parquet" + splits_file = Path(cfg.input_dir) / subject_splits_filepath shards_fp = output_dir / ".shards.json" rwlock_wrap( @@ -139,15 +140,15 @@ def read_fn(input_dir: Path) -> pl.LazyFrame: logger.info(f"Reading shards for {subshard_name} (file names are in the input sharding scheme):") for in_fp, _ in orig_shards_iter: logger.info(f" - {str(in_fp.relative_to(input_dir).resolve())}") - new_df = pl.scan_parquet(in_fp, glob=False).filter(pl.col("subject_id").is_in(subjects)) + new_df = pl.scan_parquet(in_fp, glob=False).filter(pl.col(subject_id_field).is_in(subjects)) if df is None: df = new_df else: - df = df.merge_sorted(new_df, key="subject_id") + df = df.merge_sorted(new_df, key=subject_id_field) return df def compute_fn(df: list[pl.DataFrame]) -> pl.LazyFrame: - return df.sort(by=["subject_id", "time"], maintain_order=True, multithreaded=False) + return df.sort(by=[subject_id_field, time_field], maintain_order=True, multithreaded=False) def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: write_lazyframe(df, out_fp) diff --git a/tests/MEDS_Extract/test_extract.py b/tests/MEDS_Extract/test_extract.py index 954ca21e..17e74f75 100644 --- a/tests/MEDS_Extract/test_extract.py +++ b/tests/MEDS_Extract/test_extract.py @@ -11,6 +11,7 @@ import polars as pl from meds import __version__ as MEDS_VERSION +from meds import 
subject_splits_filepath from tests.MEDS_Extract import ( CONVERT_TO_SHARDED_EVENTS_SCRIPT, @@ -628,8 +629,7 @@ def test_extraction(): got_json.pop("etl_version") # We don't test this as it changes with the commits. assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" - # Check the splits parquet - output_file = MEDS_cohort_dir / "metadata" / "subject_splits.parquet" + output_file = MEDS_cohort_dir / subject_splits_filepath assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) diff --git a/tests/MEDS_Extract/test_extract_no_metadata.py b/tests/MEDS_Extract/test_extract_no_metadata.py index 0fa8eec8..fdea870e 100644 --- a/tests/MEDS_Extract/test_extract_no_metadata.py +++ b/tests/MEDS_Extract/test_extract_no_metadata.py @@ -11,6 +11,7 @@ import polars as pl from meds import __version__ as MEDS_VERSION +from meds import subject_splits_filepath from tests.MEDS_Extract import ( CONVERT_TO_SHARDED_EVENTS_SCRIPT, @@ -601,7 +602,7 @@ def test_extraction(): assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" # Check the splits parquet - output_file = MEDS_cohort_dir / "metadata" / "subject_splits.parquet" + output_file = MEDS_cohort_dir / subject_splits_filepath assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 93e70aa9..ac8a195d 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -15,7 +15,7 @@ from pathlib import Path import polars as pl -from meds import subject_id_field +from meds import subject_id_field, subject_splits_filepath from tests.utils import FILE_T, multi_stage_tester, parse_shards_yaml, single_stage_tester @@ -197,7 +197,7 @@ def remap_inputs_for_transform( input_splits_df = pl.DataFrame(input_splits_as_df) - unified_inputs["metadata/subject_splits.parquet"] = input_splits_df + unified_inputs[subject_splits_filepath] = input_splits_df return unified_inputs diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py index 5e2a4a18..7913cb13 100644 --- a/tests/test_with_runner.py +++ b/tests/test_with_runner.py @@ -20,6 +20,8 @@ from functools import partial +from meds import code_metadata_filepath, subject_splits_filepath + from tests import RUNNER_SCRIPT, USE_LOCAL_SCRIPTS from tests.MEDS_Transforms import ( ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, @@ -211,8 +213,8 @@ def test_pipeline(): do_use_config_yaml=False, input_files={ **{f"data/{k}": v for k, v in MEDS_SHARDS.items()}, - "metadata/codes.parquet": MEDS_CODE_METADATA, - "metadata/subject_splits.parquet": SPLITS_DF, + code_metadata_filepath: MEDS_CODE_METADATA, + subject_splits_filepath: SPLITS_DF, "pipeline.yaml": partial(add_params, PIPELINE_YAML), "stage_runner.yaml": STAGE_RUNNER_YAML, }, @@ -246,8 +248,8 @@ def test_pipeline(): do_use_config_yaml=False, input_files={ **{f"data/{k}": v for k, v in MEDS_SHARDS.items()}, - "metadata/codes.parquet": MEDS_CODE_METADATA, - "metadata/subject_splits.parquet": SPLITS_DF, + code_metadata_filepath: MEDS_CODE_METADATA, + subject_splits_filepath: SPLITS_DF, "pipeline.yaml": partial(add_params, PIPELINE_YAML), "stage_runner.yaml": PARALLEL_STAGE_RUNNER_YAML, }, From 
e22d9cc56475968f3877faee5c39eedb9b7008e6 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 10:02:27 -0400 Subject: [PATCH 41/62] Some further updates to better support meds v0.3.3 --- .../extract/finalize_MEDS_metadata.py | 2 ++ tests/MEDS_Extract/test_extract.py | 15 ++++++-- .../MEDS_Extract/test_extract_no_metadata.py | 15 ++++++-- .../test_finalize_MEDS_metadata.py | 35 ++++++++++++++----- tests/utils.py | 23 ++++++++---- 5 files changed, 68 insertions(+), 22 deletions(-) diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index ce10a036..a0201803 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -2,6 +2,7 @@ """Utilities for finalizing the metadata files for extracted MEDS datasets.""" import json +from datetime import datetime from pathlib import Path import hydra @@ -193,6 +194,7 @@ def main(cfg: DictConfig): "etl_name": cfg.etl_metadata.package_name, "etl_version": str(cfg.etl_metadata.package_version), "meds_version": MEDS_VERSION, + "created_at": datetime.now().isoformat(), } jsonschema.validate(instance=dataset_metadata, schema=dataset_metadata_schema) diff --git a/tests/MEDS_Extract/test_extract.py b/tests/MEDS_Extract/test_extract.py index 17e74f75..96aeba4d 100644 --- a/tests/MEDS_Extract/test_extract.py +++ b/tests/MEDS_Extract/test_extract.py @@ -6,12 +6,13 @@ import json import tempfile +from datetime import datetime from io import StringIO from pathlib import Path import polars as pl from meds import __version__ as MEDS_VERSION -from meds import subject_splits_filepath +from meds import code_metadata_filepath, dataset_metadata_filepath, subject_splits_filepath from tests.MEDS_Extract import ( CONVERT_TO_SHARDED_EVENTS_SCRIPT, @@ -598,7 +599,7 @@ def test_extraction(): full_stdout = "\n".join(all_stdouts) # Check code metadata - output_file = MEDS_cohort_dir / "metadata" / "codes.parquet" + output_file = MEDS_cohort_dir / code_metadata_filepath assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) @@ -621,12 +622,20 @@ def test_extraction(): ) # Check dataset metadata - output_file = MEDS_cohort_dir / "metadata" / "dataset.json" + output_file = MEDS_cohort_dir / dataset_metadata_filepath assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_json = json.loads(output_file.read_text()) assert "etl_version" in got_json, "Expected 'etl_version' to be in the dataset metadata." got_json.pop("etl_version") # We don't test this as it changes with the commits. + + assert "created_at" in got_json, "Expected 'created_at' to be in the dataset metadata." + created_at_obs = got_json.pop("created_at") + as_dt = datetime.fromisoformat(created_at_obs) + assert as_dt < datetime.now(), f"Expected 'created_at' to be before now, got {created_at_obs}." + created_ago = datetime.now() - as_dt + assert created_ago.total_seconds() < 5 * 60, "Expected 'created_at' to be within 5 minutes of now." 
+ assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" output_file = MEDS_cohort_dir / subject_splits_filepath diff --git a/tests/MEDS_Extract/test_extract_no_metadata.py b/tests/MEDS_Extract/test_extract_no_metadata.py index fdea870e..a8d83a01 100644 --- a/tests/MEDS_Extract/test_extract_no_metadata.py +++ b/tests/MEDS_Extract/test_extract_no_metadata.py @@ -6,12 +6,13 @@ import json import tempfile +from datetime import datetime from io import StringIO from pathlib import Path import polars as pl from meds import __version__ as MEDS_VERSION -from meds import subject_splits_filepath +from meds import code_metadata_filepath, dataset_metadata_filepath, subject_splits_filepath from tests.MEDS_Extract import ( CONVERT_TO_SHARDED_EVENTS_SCRIPT, @@ -570,7 +571,7 @@ def test_extraction(): full_stdout = "\n".join(all_stdouts) # Check code metadata - output_file = MEDS_cohort_dir / "metadata" / "codes.parquet" + output_file = MEDS_cohort_dir / code_metadata_filepath assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) @@ -593,12 +594,20 @@ def test_extraction(): ) # Check dataset metadata - output_file = MEDS_cohort_dir / "metadata" / "dataset.json" + output_file = MEDS_cohort_dir / dataset_metadata_filepath assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_json = json.loads(output_file.read_text()) assert "etl_version" in got_json, "Expected 'etl_version' to be in the dataset metadata." got_json.pop("etl_version") # We don't test this as it changes with the commits. + + assert "created_at" in got_json, "Expected 'created_at' to be in the dataset metadata." + created_at_obs = got_json.pop("created_at") + as_dt = datetime.fromisoformat(created_at_obs) + assert as_dt < datetime.now(), f"Expected 'created_at' to be before now, got {created_at_obs}." + created_ago = datetime.now() - as_dt + assert created_ago.total_seconds() < 5 * 60, "Expected 'created_at' to be within 5 minutes of now." + assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" # Check the splits parquet diff --git a/tests/MEDS_Extract/test_finalize_MEDS_metadata.py b/tests/MEDS_Extract/test_finalize_MEDS_metadata.py index 274997f4..fd926055 100644 --- a/tests/MEDS_Extract/test_finalize_MEDS_metadata.py +++ b/tests/MEDS_Extract/test_finalize_MEDS_metadata.py @@ -5,8 +5,11 @@ """ +from datetime import datetime + import polars as pl from meds import __version__ as MEDS_VERSION +from meds import code_metadata_filepath, dataset_metadata_filepath, subject_splits_filepath from MEDS_transforms.utils import get_package_version as get_meds_transform_version from tests.MEDS_Extract import FINALIZE_METADATA_SCRIPT @@ -49,27 +52,41 @@ } ) + +def want_dataset_metadata(got: dict): + want_known = { + "dataset_name": "TEST", + "dataset_version": "1.0", + "etl_name": "MEDS_transforms", + "etl_version": get_meds_transform_version(), + "meds_version": MEDS_VERSION, + } + + assert "created_at" in got, "Expected 'created_at' to be in the dataset metadata." + created_at_obs = got.pop("created_at") + as_dt = datetime.fromisoformat(created_at_obs) + assert as_dt < datetime.now(), f"Expected 'created_at' to be before now, got {created_at_obs}." + created_ago = datetime.now() - as_dt + assert created_ago.total_seconds() < 5 * 60, "Expected 'created_at' to be within 5 minutes of now." 
+ + assert got == want_known, f"Expected dataset metadata (less created at) to be {want_known}, got {got}." + + WANT_OUTPUTS = { - "metadata/codes": ( + code_metadata_filepath: ( METADATA_DF.with_columns( pl.col("code").cast(pl.String), pl.col("description").cast(pl.String), pl.col("parent_codes").cast(pl.List(pl.String)), ).select(["code", "description", "parent_codes"]) ), - "metadata/subject_splits": pl.DataFrame( + subject_splits_filepath: pl.DataFrame( { "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], "split": ["train", "train", "train", "train", "tuning", "held_out"], } ), - "metadata/dataset.json": { - "dataset_name": "TEST", - "dataset_version": "1.0", - "etl_name": "MEDS_transforms", - "etl_version": get_meds_transform_version(), - "meds_version": MEDS_VERSION, - }, + dataset_metadata_filepath: want_dataset_metadata, } diff --git a/tests/utils.py b/tests/utils.py index a5018b4a..0e9ae943 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,7 @@ import re import subprocess import tempfile +from collections.abc import Callable from contextlib import contextmanager from io import StringIO from pathlib import Path @@ -228,6 +229,19 @@ def assert_df_equal(want: pl.DataFrame, got: pl.DataFrame, msg: str = None, **kw raise AssertionError(f"{msg}:\nWant:\n{want}\nGot:\n{got}\n{e}") from e +def check_json(want: dict | Callable, got: dict, msg: str): + try: + match want: + case dict(): + assert got == want, f"Want:\n{want}\nGot:\n{got}" + case _ if callable(want): + want(got) + case _: + raise ValueError(f"Unknown want type: {type(want)}") + except AssertionError as e: + raise AssertionError(f"{msg}: {e}") from e + + def check_NRT_output( output_fp: Path, want_nrt: JointNestedRaggedTensorDict, @@ -352,13 +366,8 @@ def check_outputs( case ".nrt": check_NRT_output(output_fp, want, msg=msg) case ".json": - with open(output_fp) as f: - got = json.load(f) - assert got == want, ( - f"Expected JSON at {output_fp.relative_to(cohort_dir)} to be equal to the target.\n" - f"Wanted:\n{want}\n" - f"Got:\n{got}" - ) + got = json.loads(output_fp.read_text()) + check_json(want, got, msg=msg) case _: raise ValueError(f"Unknown file suffix: {file_suffix}") From 9823d4c0a722ec19ca70f416ef31af6ee4532207 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 11:21:38 -0400 Subject: [PATCH 42/62] Setting up one-script runnable for MIMIC-IV. --- MIMIC-IV_Example/configs/extract_MIMIC.yaml | 33 +++++ MIMIC-IV_Example/configs/pre_MEDS.yaml | 12 +- .../local_parallelism_runner.yaml | 3 + MIMIC-IV_Example/pre_MEDS.py | 28 +++-- MIMIC-IV_Example/run.sh | 117 ++++++++++++++++++ MIMIC-IV_Example/slurm_runner.yaml | 60 +++++++++ src/MEDS_transforms/runner.py | 2 +- 7 files changed, 240 insertions(+), 15 deletions(-) create mode 100644 MIMIC-IV_Example/configs/extract_MIMIC.yaml create mode 100644 MIMIC-IV_Example/local_parallelism_runner.yaml create mode 100755 MIMIC-IV_Example/run.sh create mode 100644 MIMIC-IV_Example/slurm_runner.yaml diff --git a/MIMIC-IV_Example/configs/extract_MIMIC.yaml b/MIMIC-IV_Example/configs/extract_MIMIC.yaml new file mode 100644 index 00000000..75e89740 --- /dev/null +++ b/MIMIC-IV_Example/configs/extract_MIMIC.yaml @@ -0,0 +1,33 @@ +defaults: + - _extract + - _self_ + +description: |- + This pipeline extracts the MIMIC-IV dataet in longitudinal, sparse form from an input dataset meeting select + criteria and converts them to the flattened, MEDS format. 
You can control the key arguments to this pipeline + by setting environment variables: + ```bash + $EVENT_CONVERSION_CONFIG_FP=# Path to your event conversion config + $MIMICIV_PRE_MEDS_DIR=# Path to the output dir of the pre-MEDS step + $MIMICIV_MEDS_COHORT_DIR=# Path to where you want the dataset to live + ``` + + +# The event conversion configuration file is used throughout the pipeline to define the events to extract. +event_conversion_config_fp: ${oc.env:EVENT_CONVERSION_CONFIG_FP} + +input_dir: ${oc.env:MIMICIV_PRE_MEDS_DIR} +cohort_dir: ${oc.env:MIMICIV_MEDS_COHORT_DIR} + +stage_configs.shard_events.infer_schema_length: 999999999 +etl_metadata.dataset_name: MIMIC-IV +etl_metadata.dataset_version: 2.2 + +stages: + - shard_events + - split_and_shard_subjects + - convert_to_sharded_events + - merge_to_MEDS_cohort + - extract_code_metadata + - finalize_MEDS_metadata + - finalize_MEDS_data diff --git a/MIMIC-IV_Example/configs/pre_MEDS.yaml b/MIMIC-IV_Example/configs/pre_MEDS.yaml index b5cfa4cb..325903e0 100644 --- a/MIMIC-IV_Example/configs/pre_MEDS.yaml +++ b/MIMIC-IV_Example/configs/pre_MEDS.yaml @@ -1,11 +1,15 @@ -raw_cohort_dir: ??? -output_dir: ??? +input_dir: ${oc.env:MIMICIV_RAW_DIR} +cohort_dir: ${oc.env:MIMICIV_PRE_MEDS_DIR} + +do_overwrite: false + +log_dir: ${cohort_dir}/.logs # Hydra hydra: job: name: pre_MEDS_${now:%Y-%m-%d_%H-%M-%S} run: - dir: ${output_dir}/.logs/${hydra.job.name} + dir: ${log_dir} sweep: - dir: ${output_dir}/.logs/${hydra.job.name} + dir: ${log_dir} diff --git a/MIMIC-IV_Example/local_parallelism_runner.yaml b/MIMIC-IV_Example/local_parallelism_runner.yaml new file mode 100644 index 00000000..a1d9a6c1 --- /dev/null +++ b/MIMIC-IV_Example/local_parallelism_runner.yaml @@ -0,0 +1,3 @@ +parallelize: + n_workers: ${oc.env:N_WORKERS} + launcher: "joblib" diff --git a/MIMIC-IV_Example/pre_MEDS.py b/MIMIC-IV_Example/pre_MEDS.py index b40bb925..745c21ba 100755 --- a/MIMIC-IV_Example/pre_MEDS.py +++ b/MIMIC-IV_Example/pre_MEDS.py @@ -56,26 +56,34 @@ def fix_static_data(raw_static_df: pl.LazyFrame, death_times_df: pl.LazyFrame) - def main(cfg: DictConfig): """Performs pre-MEDS data wrangling for MIMIC-IV. - Inputs are the raw MIMIC files, read from the `raw_cohort_dir` config parameter. Output files are either + Inputs are the raw MIMIC files, read from the `input_dir` config parameter. Output files are either symlinked (if they are not modified) or written in processed form to the `MEDS_input_dir` config parameter. Hydra is used to manage configuration parameters and logging. """ hydra_loguru_init() - raw_cohort_dir = Path(cfg.raw_cohort_dir) - MEDS_input_dir = Path(cfg.output_dir) + input_dir = Path(cfg.input_dir) + MEDS_input_dir = Path(cfg.cohort_dir) - all_fps = list(raw_cohort_dir.glob("**/*.*")) + done_fp = MEDS_input_dir / ".done" + if done_fp.is_file() and not cfg.do_overwrite: + logger.info( + f"Pre-MEDS transformation already complete as {done_fp} exists and " + f"do_overwrite={cfg.do_overwrite}. Returning." 
+ ) + exit(0) + + all_fps = list(input_dir.glob("**/*.*")) dfs_to_load = {} seen_fps = {} for in_fp in all_fps: - pfx = get_shard_prefix(raw_cohort_dir, in_fp) + pfx = get_shard_prefix(input_dir, in_fp) try: - fp, read_fn = get_supported_fp(raw_cohort_dir, pfx) + fp, read_fn = get_supported_fp(input_dir, pfx) except FileNotFoundError: logger.info(f"Skipping {pfx} @ {str(in_fp.resolve())} as no compatible dataframe file was found.") continue @@ -88,7 +96,7 @@ def main(cfg: DictConfig): else: seen_fps[str(fp.resolve())] = read_fn - out_fp = MEDS_input_dir / fp.relative_to(raw_cohort_dir) + out_fp = MEDS_input_dir / fp.relative_to(input_dir) if out_fp.is_file(): print(f"Done with {pfx}. Continuing") @@ -130,7 +138,7 @@ def main(cfg: DictConfig): fps = fps_and_cols["fps"] cols = list(fps_and_cols["cols"]) - df_to_load_fp, df_to_load_read_fn = get_supported_fp(raw_cohort_dir, df_to_load_pfx) + df_to_load_fp, df_to_load_read_fn = get_supported_fp(input_dir, df_to_load_pfx) st = datetime.now() @@ -142,7 +150,7 @@ def main(cfg: DictConfig): logger.info(f" Loaded in {datetime.now() - st}") for fp in fps: - pfx = get_shard_prefix(raw_cohort_dir, fp) + pfx = get_shard_prefix(input_dir, fp) out_fp = MEDS_input_dir / f"{pfx}.parquet" logger.info(f" Processing dependent df @ {pfx}...") @@ -157,7 +165,7 @@ def main(cfg: DictConfig): logger.info(f" Processed and wrote to {str(out_fp.resolve())} in {datetime.now() - fp_st}") logger.info(f"Done! All dataframes processed and written to {str(MEDS_input_dir.resolve())}") - + done_fp.write_text(f"Finished at {datetime.now()}") if __name__ == "__main__": main() diff --git a/MIMIC-IV_Example/run.sh b/MIMIC-IV_Example/run.sh new file mode 100755 index 00000000..21486365 --- /dev/null +++ b/MIMIC-IV_Example/run.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +# This makes the script fail if any internal script fails +set -e + +# Function to display help message +function display_help() { + echo "Usage: $0 " + echo + echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." + echo + echo "Arguments:" + echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." + echo " MIMICIV_PREMEDS_DIR Output directory for pre-MEDS data." + echo " MIMICIV_MEDS_DIR Output directory for processed MEDS data." + echo " (OPTIONAL) STAGE_RUNNER_CONFIG_FP Where the stage runner config lives, if desired." + echo " (OPTIONAL) do_unzip=true OR do_unzip=false Optional flag to unzip csv files before processing." + echo + echo "Options:" + echo " -h, --help Display this help message and exit." + exit 1 +} + +echo "Unsetting SLURM_CPU_BIND in case you're running this on a slurm interactive node with slurm parallelism" +unset SLURM_CPU_BIND + +# Check if the first parameter is '-h' or '--help' +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + display_help +fi + +# Check for mandatory parameters +if [ "$#" -lt 3 ]; then + echo "Error: Incorrect number of arguments provided." + display_help +fi + +if [ "$#" -gt 5 ]; then + echo "Error: Incorrect number of arguments provided." + display_help +fi + +export MIMICIV_RAW_DIR=$1 +export MIMICIV_PRE_MEDS_DIR=$2 +export MIMICIV_MEDS_COHORT_DIR=$3 +shift 3 + +# Defaults +STAGE_RUNNER_ARG="" +_DO_UNZIP_ARG_STR="" + +if [ $# -ge 1 ]; then + case "$1" in + do_unzip=*) + if [ $# -ge 2 ]; then + echo "Error: Stage runner filepath must come before do_unzip if both are specified!" 
+ display_help + else + _DO_UNZIP_ARG_STR="$1" + shift 1 + fi + ;; + *) + STAGE_RUNNER_ARG="stage_runner_fp=$1" + if [ $# -ge 2 ]; then + _DO_UNZIP_ARG_STR="$2" + shift 2 + else + shift 1 + fi + ;; + esac +fi + +DO_UNZIP="false" + +if [ -z "$_DO_UNZIP_ARG_STR" ]; then + case "$_DO_UNZIP_ARG_STR" in + do_unzip=true) + DO_UNZIP="true" + ;; + do_unzip=false) + DO_UNZIP="false" + ;; + *) + echo "Error: Invalid do_unzip value. Use 'do_unzip=true' or 'do_unzip=false'." + exit 1 + ;; + esac + echo "Setting DO_UNZIP=$DO_UNZIP" +fi + +# TODO: Add wget blocks once testing is validated. + +export EVENT_CONVERSION_CONFIG_FP="$(pwd)/configs/event_configs.yaml" +export PIPELINE_CONFIG_FP="$(pwd)/configs/extract_MIMIC.yaml" +export PRE_MEDS_PY_FP="$(pwd)/pre_MEDS.py" + +if [ "$DO_UNZIP" == "true" ]; then + echo "Unzipping csv.gz files in ${MIMICIV_RAW_DIR}." + for file in "${MIMICIV_RAW_DIR}"/*/*.csv.gz; do gzip -d --force "$file"; done +else + echo "Skipping unzipping." +fi + +echo "Running pre-MEDS conversion." +python "$PRE_MEDS_PY_FP" input_dir="$MIMICIV_RAW_DIR" cohort_dir="$MIMICIV_PRE_MEDS_DIR" + +if [ -z "$N_WORKERS" ]; then + echo "Setting N_WORKERS to 1 to avoid issues with the runners." + export N_WORKERS="1" +fi + +echo "Running extraction pipeline." +MEDS_transform-runner "pipeline_config_fp=$PIPELINE_CONFIG_FP" "$STAGE_RUNNER_ARG" + diff --git a/MIMIC-IV_Example/slurm_runner.yaml b/MIMIC-IV_Example/slurm_runner.yaml new file mode 100644 index 00000000..01a59e58 --- /dev/null +++ b/MIMIC-IV_Example/slurm_runner.yaml @@ -0,0 +1,60 @@ +parallelize: + n_workers: ${oc.env:N_WORKERS} + launcher: "submitit_slurm" + +shard_events: + parallelize: + launcher_params: + timeout_min: 60 + cpus_per_task: 10 + mem_gb: 10 + partition: "short" + +split_and_shard_subjects: + parallelize: + launcher_params: + timeout_min: 60 + cpus_per_task: 10 + mem_gb: 10 + partition: "short" + +convert_to_sharded_events: + parallelize: + launcher_params: + timeout_min: 60 + cpus_per_task: 10 + mem_gb: 10 + partition: "short" + +merge_to_MEDS_cohort: + parallelize: + launcher_params: + timeout_min: 60 + cpus_per_task: 10 + mem_gb: 10 + partition: "short" + +extract_code_metadata: + parallelize: + launcher_params: + timeout_min: 60 + cpus_per_task: 10 + mem_gb: 10 + partition: "short" + +finalize_MEDS_metadata: + parallelize: + launcher_params: + timeout_min: 60 + cpus_per_task: 10 + mem_gb: 10 + partition: "short" + +finalize_MEDS_data: + parallelize: + n_workers: 1 + launcher_params: + timeout_min: 15 + cpus_per_task: 5 + mem_gb: 10 + partition: "short" diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index 8f82109c..e99e014a 100755 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -115,7 +115,7 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic do_profile = cfg.get("do_profile", False) pipeline_config_fp = Path(cfg.pipeline_config_fp) - stage_config = cfg._local_pipeline_config.stage_configs.get(stage_name, {}) + stage_config = cfg._local_pipeline_config.get("stage_configs", {}).get(stage_name, {}) stage_runner_config = cfg._stage_runners.get(stage_name, {}) script = None From ed2c7b7249f5353cc77347b93fd3054bae6369a0 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 11:30:39 -0400 Subject: [PATCH 43/62] Fixed some small typos and issues. 
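Regarding the `src/MEDS_transforms/runner.py` hunk just above: it swaps direct attribute access for a chained `.get` lookup so that a pipeline config which defines no `stage_configs` block resolves to an empty per-stage config instead of failing. A minimal sketch of that lookup pattern (plain Python dicts here, purely for illustration; the real object is an OmegaConf `DictConfig`):

```python
# Hypothetical pipeline config with no "stage_configs" block at all.
pipeline_cfg = {"stages": ["shard_events", "split_and_shard_subjects"]}

# Chained .get with dict defaults never raises; it just falls through to {}.
stage_config = pipeline_cfg.get("stage_configs", {}).get("shard_events", {})
assert stage_config == {}
```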
--- MIMIC-IV_Example/run.sh | 38 +++++++----------------------- MIMIC-IV_Example/slurm_runner.yaml | 2 +- 2 files changed, 9 insertions(+), 31 deletions(-) diff --git a/MIMIC-IV_Example/run.sh b/MIMIC-IV_Example/run.sh index 21486365..3a028ec6 100755 --- a/MIMIC-IV_Example/run.sh +++ b/MIMIC-IV_Example/run.sh @@ -11,11 +11,10 @@ function display_help() { echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo echo "Arguments:" - echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." - echo " MIMICIV_PREMEDS_DIR Output directory for pre-MEDS data." - echo " MIMICIV_MEDS_DIR Output directory for processed MEDS data." - echo " (OPTIONAL) STAGE_RUNNER_CONFIG_FP Where the stage runner config lives, if desired." - echo " (OPTIONAL) do_unzip=true OR do_unzip=false Optional flag to unzip csv files before processing." + echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." + echo " MIMICIV_PREMEDS_DIR Output directory for pre-MEDS data." + echo " MIMICIV_MEDS_DIR Output directory for processed MEDS data." + echo " (OPTIONAL) do_unzip=true OR do_unzip=false Optional flag to unzip files before processing." echo echo "Options:" echo " -h, --help Display this help message and exit." @@ -36,46 +35,26 @@ if [ "$#" -lt 3 ]; then display_help fi -if [ "$#" -gt 5 ]; then - echo "Error: Incorrect number of arguments provided." - display_help -fi - export MIMICIV_RAW_DIR=$1 export MIMICIV_PRE_MEDS_DIR=$2 export MIMICIV_MEDS_COHORT_DIR=$3 shift 3 # Defaults -STAGE_RUNNER_ARG="" _DO_UNZIP_ARG_STR="" if [ $# -ge 1 ]; then case "$1" in do_unzip=*) - if [ $# -ge 2 ]; then - echo "Error: Stage runner filepath must come before do_unzip if both are specified!" - display_help - else - _DO_UNZIP_ARG_STR="$1" - shift 1 - fi - ;; - *) - STAGE_RUNNER_ARG="stage_runner_fp=$1" - if [ $# -ge 2 ]; then - _DO_UNZIP_ARG_STR="$2" - shift 2 - else - shift 1 - fi + _DO_UNZIP_ARG_STR="$1" + shift 1 ;; esac fi DO_UNZIP="false" -if [ -z "$_DO_UNZIP_ARG_STR" ]; then +if [ ! -z "$_DO_UNZIP_ARG_STR" ]; then case "$_DO_UNZIP_ARG_STR" in do_unzip=true) DO_UNZIP="true" @@ -113,5 +92,4 @@ if [ -z "$N_WORKERS" ]; then fi echo "Running extraction pipeline." -MEDS_transform-runner "pipeline_config_fp=$PIPELINE_CONFIG_FP" "$STAGE_RUNNER_ARG" - +MEDS_transform-runner "pipeline_config_fp=$PIPELINE_CONFIG_FP" "$@" diff --git a/MIMIC-IV_Example/slurm_runner.yaml b/MIMIC-IV_Example/slurm_runner.yaml index 01a59e58..8d20e59e 100644 --- a/MIMIC-IV_Example/slurm_runner.yaml +++ b/MIMIC-IV_Example/slurm_runner.yaml @@ -7,7 +7,7 @@ shard_events: launcher_params: timeout_min: 60 cpus_per_task: 10 - mem_gb: 10 + mem_gb: 25 partition: "short" split_and_shard_subjects: From 46ef7524e4c86803982e6019370d0be138fc2e08 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 11:50:01 -0400 Subject: [PATCH 44/62] Some more improvements. 
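The config change in the diff below replaces dotted top-level keys (e.g. `etl_metadata.dataset_name`) with properly nested mappings. Dotted keys work as Hydra command-line overrides, but inside a YAML file they are read as literal key names rather than nested paths. A small, hypothetical illustration of the difference (PyYAML is used here only for the demo; it is not part of this patch):

```python
import yaml

# A dotted key in a YAML *file* is just one literal key...
flat = yaml.safe_load("etl_metadata.dataset_name: MIMIC-IV")
assert flat == {"etl_metadata.dataset_name": "MIMIC-IV"}

# ...whereas indentation produces the nested structure the pipeline expects.
nested = yaml.safe_load("etl_metadata:\n  dataset_name: MIMIC-IV")
assert nested == {"etl_metadata": {"dataset_name": "MIMIC-IV"}}
```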
--- MIMIC-IV_Example/configs/extract_MIMIC.yaml | 10 +++++++--- MIMIC-IV_Example/run.sh | 9 +++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/MIMIC-IV_Example/configs/extract_MIMIC.yaml b/MIMIC-IV_Example/configs/extract_MIMIC.yaml index 75e89740..cb2ebac7 100644 --- a/MIMIC-IV_Example/configs/extract_MIMIC.yaml +++ b/MIMIC-IV_Example/configs/extract_MIMIC.yaml @@ -19,9 +19,13 @@ event_conversion_config_fp: ${oc.env:EVENT_CONVERSION_CONFIG_FP} input_dir: ${oc.env:MIMICIV_PRE_MEDS_DIR} cohort_dir: ${oc.env:MIMICIV_MEDS_COHORT_DIR} -stage_configs.shard_events.infer_schema_length: 999999999 -etl_metadata.dataset_name: MIMIC-IV -etl_metadata.dataset_version: 2.2 +etl_metadata: + dataset_name: MIMIC-IV + dataset_version: 2.2 + +stage_configs: + shard_events: + infer_schema_length: 999999999 stages: - shard_events diff --git a/MIMIC-IV_Example/run.sh b/MIMIC-IV_Example/run.sh index 3a028ec6..85aefe31 100755 --- a/MIMIC-IV_Example/run.sh +++ b/MIMIC-IV_Example/run.sh @@ -77,8 +77,13 @@ export PIPELINE_CONFIG_FP="$(pwd)/configs/extract_MIMIC.yaml" export PRE_MEDS_PY_FP="$(pwd)/pre_MEDS.py" if [ "$DO_UNZIP" == "true" ]; then - echo "Unzipping csv.gz files in ${MIMICIV_RAW_DIR}." - for file in "${MIMICIV_RAW_DIR}"/*/*.csv.gz; do gzip -d --force "$file"; done + GZ_FILES="${MIMICIV_RAW_DIR}/*/*.csv.gz" + if compgen -G $GZ_FILES > /dev/null; then + echo "Unzipping csv.gz files matching $GZ_FILES." + for file in $GZ_FILES; do gzip -d --force "$file"; done + else + echo "No csz.gz files to unzip at $GZ_FILES." + fi else echo "Skipping unzipping." fi From 6a905cd0e53fe38d6f4f1120ff5bdf985d68cc72 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 12:06:15 -0400 Subject: [PATCH 45/62] Fixed lint errors. --- MIMIC-IV_Example/configs/extract_MIMIC.yaml | 13 ++++++------- MIMIC-IV_Example/pre_MEDS.py | 1 + MIMIC-IV_Example/run.sh | 15 ++++++++++----- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/MIMIC-IV_Example/configs/extract_MIMIC.yaml b/MIMIC-IV_Example/configs/extract_MIMIC.yaml index cb2ebac7..eb9b32ee 100644 --- a/MIMIC-IV_Example/configs/extract_MIMIC.yaml +++ b/MIMIC-IV_Example/configs/extract_MIMIC.yaml @@ -3,16 +3,15 @@ defaults: - _self_ description: |- - This pipeline extracts the MIMIC-IV dataet in longitudinal, sparse form from an input dataset meeting select - criteria and converts them to the flattened, MEDS format. You can control the key arguments to this pipeline - by setting environment variables: + This pipeline extracts the MIMIC-IV dataset in longitudinal, sparse form from an input dataset meeting + select criteria and converts them to the flattened, MEDS format. You can control the key arguments to this + pipeline by setting environment variables: ```bash - $EVENT_CONVERSION_CONFIG_FP=# Path to your event conversion config - $MIMICIV_PRE_MEDS_DIR=# Path to the output dir of the pre-MEDS step - $MIMICIV_MEDS_COHORT_DIR=# Path to where you want the dataset to live + export EVENT_CONVERSION_CONFIG_FP=# Path to your event conversion config + export MIMICIV_PRE_MEDS_DIR=# Path to the output dir of the pre-MEDS step + export MIMICIV_MEDS_COHORT_DIR=# Path to where you want the dataset to live ``` - # The event conversion configuration file is used throughout the pipeline to define the events to extract. 
event_conversion_config_fp: ${oc.env:EVENT_CONVERSION_CONFIG_FP} diff --git a/MIMIC-IV_Example/pre_MEDS.py b/MIMIC-IV_Example/pre_MEDS.py index 745c21ba..6007c64a 100755 --- a/MIMIC-IV_Example/pre_MEDS.py +++ b/MIMIC-IV_Example/pre_MEDS.py @@ -167,5 +167,6 @@ def main(cfg: DictConfig): logger.info(f"Done! All dataframes processed and written to {str(MEDS_input_dir.resolve())}") done_fp.write_text(f"Finished at {datetime.now()}") + if __name__ == "__main__": main() diff --git a/MIMIC-IV_Example/run.sh b/MIMIC-IV_Example/run.sh index 85aefe31..9c06c7e9 100755 --- a/MIMIC-IV_Example/run.sh +++ b/MIMIC-IV_Example/run.sh @@ -54,7 +54,7 @@ fi DO_UNZIP="false" -if [ ! -z "$_DO_UNZIP_ARG_STR" ]; then +if [ -n "$_DO_UNZIP_ARG_STR" ]; then case "$_DO_UNZIP_ARG_STR" in do_unzip=true) DO_UNZIP="true" @@ -72,13 +72,18 @@ fi # TODO: Add wget blocks once testing is validated. -export EVENT_CONVERSION_CONFIG_FP="$(pwd)/configs/event_configs.yaml" -export PIPELINE_CONFIG_FP="$(pwd)/configs/extract_MIMIC.yaml" -export PRE_MEDS_PY_FP="$(pwd)/pre_MEDS.py" +EVENT_CONVERSION_CONFIG_FP="$(pwd)/configs/event_configs.yaml" +PIPELINE_CONFIG_FP="$(pwd)/configs/extract_MIMIC.yaml" +PRE_MEDS_PY_FP="$(pwd)/pre_MEDS.py" + +# We export these variables separately from their assignment so that any errors during assignment are caught. +export EVENT_CONVERSION_CONFIG_FP +export PIPELINE_CONFIG_FP +export PRE_MEDS_PY_FP if [ "$DO_UNZIP" == "true" ]; then GZ_FILES="${MIMICIV_RAW_DIR}/*/*.csv.gz" - if compgen -G $GZ_FILES > /dev/null; then + if compgen -G "$GZ_FILES" > /dev/null; then echo "Unzipping csv.gz files matching $GZ_FILES." for file in $GZ_FILES; do gzip -d --force "$file"; done else From 4cea12cf4717e9f3ba7e03a9f3a8ed56eb2594e7 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 12:10:39 -0400 Subject: [PATCH 46/62] Removed outdated files. --- MIMIC-IV_Example/joint_script.sh | 166 ------------------------- MIMIC-IV_Example/joint_script_slurm.sh | 141 --------------------- 2 files changed, 307 deletions(-) delete mode 100755 MIMIC-IV_Example/joint_script.sh delete mode 100755 MIMIC-IV_Example/joint_script_slurm.sh diff --git a/MIMIC-IV_Example/joint_script.sh b/MIMIC-IV_Example/joint_script.sh deleted file mode 100755 index dd1459c4..00000000 --- a/MIMIC-IV_Example/joint_script.sh +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env bash - -# This makes the script fail if any internal script fails -set -e - -# Function to display help message -function display_help() { - echo "Usage: $0 " - echo - echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," - echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." - echo - echo "Arguments:" - echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." - echo " MIMICIV_PREMEDS_DIR Output directory for pre-MEDS data." - echo " MIMICIV_MEDS_DIR Output directory for processed MEDS data." - echo " N_PARALLEL_WORKERS Number of parallel workers for processing." - echo " (OPTIONAL) do_unzip=true OR do_unzip=false Optional flag to unzip csv files before processing." - echo - echo "Options:" - echo " -h, --help Display this help message and exit." - exit 1 -} - -# Check if the first parameter is '-h' or '--help' -if [[ "$1" == "-h" || "$1" == "--help" ]]; then - display_help -fi - -# Check for mandatory parameters -if [ "$#" -lt 4 ]; then - echo "Error: Incorrect number of arguments provided." 
- display_help -fi - -MIMICIV_RAW_DIR="$1" -MIMICIV_PREMEDS_DIR="$2" -MIMICIV_MEDS_DIR="$3" -N_PARALLEL_WORKERS="$4" - -# Default do_unzip value -DO_UNZIP="false" - -# Check if the 5th argument is either do_unzip=true or do_unzip=false -if [ $# -ge 5 ]; then - case "$5" in - do_unzip=true) - DO_UNZIP="true" - shift 5 - ;; - do_unzip=false) - DO_UNZIP="false" - shift 5 - ;; - do_unzip=*) - echo "Error: Invalid do_unzip value. Use 'do_unzip=true' or 'do_unzip=false'." - exit 1 - ;; - *) - # If the 5th argument is not related to do_unzip, leave it for other_args - shift 4 - ;; - esac -else - shift 4 -fi - -if [ "$DO_UNZIP" == "true" ]; then - echo "Unzipping csv files." - for file in "${MIMICIV_RAW_DIR}"/*/*.csv.gz; do gzip -d --force "$file"; done -else - echo "Skipping unzipping." -fi - -echo "Running pre-MEDS conversion." -./MIMIC-IV_Example/pre_MEDS.py raw_cohort_dir="$MIMICIV_RAW_DIR" output_dir="$MIMICIV_PREMEDS_DIR" - -echo "Running shard_events.py with $N_PARALLEL_WORKERS workers in parallel" -MEDS_extract-shard_events \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=joblib \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="shard_events" \ - stage_configs.shard_events.infer_schema_length=999999999 \ - etl_metadata.dataset_name="MIMIC-IV" \ - etl_metadata.dataset_version="2.2" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Splitting subjects in serial" -MEDS_extract-split_and_shard_subjects \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="split_and_shard_subjects" \ - etl_metadata.dataset_name="MIMIC-IV" \ - etl_metadata.dataset_version="2.2" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" -MEDS_extract-convert_to_sharded_events \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=joblib \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="convert_to_sharded_events" \ - etl_metadata.dataset_name="MIMIC-IV" \ - etl_metadata.dataset_version="2.2" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" -MEDS_extract-merge_to_MEDS_cohort \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=joblib \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="merge_to_MEDS_cohort" \ - etl_metadata.dataset_name="MIMIC-IV" \ - etl_metadata.dataset_version="2.2" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Aggregating initial code stats with $N_PARALLEL_WORKERS workers in parallel" -MEDS_transform-aggregate_code_metadata \ - --config-name="extract" \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=joblib \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="aggregate_code_metadata" \ - etl_metadata.dataset_name="MIMIC-IV" \ - etl_metadata.dataset_version="2.2" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -# TODO -- make this the pre-meds dir and have the pre-meds script symlink -echo "Collecting code metadata in serial." 
-MEDS_extract-extract_code_metadata \ - input_dir="$MIMICIV_RAW_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="extract_code_metadata" \ - etl_metadata.dataset_name="MIMIC-IV" \ - etl_metadata.dataset_version="2.2" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Finalizing MEDS data with $N_PARALLEL_WORKERS workers in parallel" -MEDS_extract-finalize_MEDS_data \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=joblib \ - input_dir="$MIMICIV_RAW_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="finalize_MEDS_data" \ - etl_metadata.dataset_name="MIMIC-IV" \ - etl_metadata.dataset_version="2.2" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Finalizing MEDS metadata in serial." -MEDS_extract-finalize_MEDS_metadata \ - input_dir="$MIMICIV_RAW_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="finalize_MEDS_metadata" \ - etl_metadata.dataset_name="MIMIC-IV" \ - etl_metadata.dataset_version="2.2" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" diff --git a/MIMIC-IV_Example/joint_script_slurm.sh b/MIMIC-IV_Example/joint_script_slurm.sh deleted file mode 100755 index e13fb7e9..00000000 --- a/MIMIC-IV_Example/joint_script_slurm.sh +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env bash - -# This makes the script fail if any internal script fails -set -e - -# Function to display help message -function display_help() { - echo "Usage: $0 " - echo - echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," - echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." - echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." - echo - echo "Arguments:" - echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." - echo " MIMICIV_PREMEDS_DIR Output directory for pre-MEDS data." - echo " MIMICIV_MEDS_DIR Output directory for processed MEDS data." - echo " N_PARALLEL_WORKERS Number of parallel workers for processing." - echo - echo "Options:" - echo " -h, --help Display this help message and exit." - exit 1 -} - -# Check if the first parameter is '-h' or '--help' -if [[ "$1" == "-h" || "$1" == "--help" ]]; then - display_help -fi - -# Check for mandatory parameters -if [ "$#" -ne 4 ]; then - echo "Error: Incorrect number of arguments provided." - display_help -fi - -export MIMICIV_RAW_DIR="$1" -export MIMICIV_PREMEDS_DIR="$2" -export MIMICIV_MEDS_DIR="$3" -export N_PARALLEL_WORKERS="$4" - -shift 4 - -# Note we use `--multirun` throughout here due to ensure the submitit launcher is used throughout, so that -# this doesn't fall back on running anything locally in a setting where only slurm worker nodes have -# sufficient computational resources to run the actual jobs. - -echo "Running pre-MEDS conversion on one worker." -./MIMIC-IV_Example/pre_MEDS.py \ - --multirun \ - +worker="range(0,1)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.partition="short" \ - raw_cohort_dir="$MIMICIV_RAW_DIR" \ - output_dir="$MIMICIV_PREMEDS_DIR" - -echo "Trying submitit launching with $N_PARALLEL_WORKERS jobs." 
- -MEDS_extract-shard_events \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.partition="short" \ - "hydra.job.env_copy=[PATH]" \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml \ - stage=shard_events - -echo "Splitting subjects on one worker" -MEDS_extract-split_and_shard_subjects \ - --multirun \ - worker="range(0,1)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.partition="short" \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" -MEDS_extract-convert_to_sharded_events \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.partition="short" \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" -MEDS_extract-merge_to_MEDS_cohort \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.partition="short" \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -echo "Aggregating initial code stats with $N_PARALLEL_WORKERS workers in parallel" -MEDS_transform-aggregate_code_metadata \ - --config-name="extract" \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.partition="short" \ - input_dir="$MIMICIV_PREMEDS_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="aggregate_code_metadata" - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" - -# TODO -- make this the pre-meds dir and have the pre-meds script symlink -echo "Collecting code metadata with $N_PARALLEL_WORKERS workers in parallel" -MEDS_extract-extract_code_metadata \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.partition="short" \ - input_dir="$MIMICIV_RAW_DIR" \ - cohort_dir="$MIMICIV_MEDS_DIR" \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" From 110bd81a7f49bd3fad079eecac84361cc1cbf110 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 12:25:43 -0400 Subject: [PATCH 47/62] Updated the README to reflect the new usage. 
---
 MIMIC-IV_Example/README.md | 155 +++++++++++++------------------
 1 file changed, 53 insertions(+), 102 deletions(-)

diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md
index 6bf348d0..dbfebf9e 100644
--- a/MIMIC-IV_Example/README.md
+++ b/MIMIC-IV_Example/README.md
@@ -6,33 +6,34 @@ up from this one).

 ## Step 0: Installation

-Download this repository and install the requirements:
-If you want to install via pypi, (note that for now, you still need to copy some files locally even with a
-pypi installation, which is covered below, so make sure you are in a suitable directory) use:
-
 ```bash
 conda create -n MEDS python=3.12
 conda activate MEDS
-pip install "MEDS_transforms[local_parallelism]"
-mkdir MIMIC-IV_Example
-cd MIMIC-IV_Example
-wget https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/main/MIMIC-IV_Example/joint_script.sh
-wget https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/main/MIMIC-IV_Example/joint_script_slurm.sh
-wget https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/main/MIMIC-IV_Example/pre_MEDS.py
-chmod +x joint_script.sh
-chmod +x joint_script_slurm.sh
-chmod +x pre_MEDS.py
-cd ..
+pip install "MEDS_transforms[local_parallelism,slurm_parallelism]"
 ```

-If you want to install locally, use:
+If you want to profile the time and memory costs of your ETL, also install: `pip install hydra-profiler`.

+## Step 0.5: Set-up
+Set some environment variables and download the necessary files:
 ```bash
-git clone git@github.com:mmcdermott/MEDS_transforms.git
-cd MEDS_transforms
-conda create -n MEDS python=3.12
-conda activate MEDS
-pip install .[local_parallelism]
+export MIMICIV_RAW_DIR=??? # set to the directory in which you want to store the raw MIMIC-IV data
+export MIMICIV_PRE_MEDS_DIR=??? # set to the directory in which you want to store the pre-MEDS intermediate data
+export MIMICIV_MEDS_COHORT_DIR=??? # set to the directory in which you want to store the final MEDS cohort
+
+export VERSION=0.0.6 # or whatever version you want
+export URL="https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/$VERSION/MIMIC-IV_Example"
+
+wget $URL/run.sh
+wget $URL/pre_MEDS.py
+wget $URL/local_parallelism_runner.yaml
+wget $URL/slurm_runner.yaml
+mkdir configs
+cd configs
+wget $URL/configs/extract_MIMIC.yaml
+cd ..
+chmod +x run.sh +chmod +x pre_MEDS.py ``` ## Step 1: Download MIMIC-IV @@ -46,101 +47,51 @@ the root directory of where the resulting _core data files_ are stored -- e.g., ```bash cd $MIMIC_RAW_DIR -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/d_labitems_to_loinc.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/inputevents_to_rxnorm.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/lab_itemid_to_loinc.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/meas_chartevents_main.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/meas_chartevents_value.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/numerics-summary.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/outputevents_to_loinc.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/proc_datetimeevents.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/proc_itemid.csv -wget https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map/waveforms-summary.csv +export MIMIC_URL=https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map +wget $MIMIC_URL/d_labitems_to_loinc.csv +wget $MIMIC_URL/inputevents_to_rxnorm.csv +wget $MIMIC_URL/lab_itemid_to_loinc.csv +wget $MIMIC_URL/meas_chartevents_main.csv +wget $MIMIC_URL/meas_chartevents_value.csv +wget $MIMIC_URL/numerics-summary.csv +wget $MIMIC_URL/outputevents_to_loinc.csv +wget $MIMIC_URL/proc_datetimeevents.csv +wget $MIMIC_URL/proc_itemid.csv +wget $MIMIC_URL/waveforms-summary.csv ``` -## Step 2: Run the basic MEDS ETL - -This step contains several sub-steps; luckily, all these substeps can be run via a single script, with the -`joint_script.sh` script which uses the Hydra `joblib` launcher to run things with local parallelism (make -sure you enable this feature by including the `[local_parallelism]` option during installation) or via -`joint_script_slurm.sh` which uses the Hydra `submitit` launcher to run things through slurm (make sure you -enable this feature by including the `[slurm_parallelism]` option during installation). This script entails -several steps: - -### Step 2.1: Get the data ready for base MEDS extraction - -This is a step in a few parts: - -1. Join a few tables by `hadm_id` to get the right times in the right rows for processing. In - particular, we need to join: - - the `hosp/diagnoses_icd` table with the `hosp/admissions` table to get the `dischtime` for each - `hadm_id`. - - the `hosp/drgcodes` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`. -2. Convert the subject's static data to a more parseable form. This entails: - - Get the subject's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and - `anchor_offset` fields. - - Merge the subject's `dod` with the `deathtime` from the `admissions` table. - -After these steps, modified files or symlinks to the original files will be written in a new directory which -will be used as the input to the actual MEDS extraction ETL. We'll use `$MIMICIV_PREMEDS_DIR` to denote this -directory. 
+## Step 2: Run the MEDS ETL

-This step is run in the `joint_script.sh` script or the `joint_script_slurm.sh` script, but in either case the
-base command that is run is as follows (assumed to be run **not** from this directory but from the
-root directory of this repository):
+To run the MEDS ETL, run the following command:

 ```bash
-./MIMIC-IV_Example/pre_MEDS.py raw_cohort_dir=$MIMICIV_RAW_DIR output_dir=$MIMICIV_PREMEDS_DIR
+./run.sh $MIMICIV_RAW_DIR $MIMICIV_PRE_MEDS_DIR $MIMICIV_MEDS_COHORT_DIR do_unzip=true
 ```

-In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total.
+To not unzip the `.csv.gz` files, set `do_unzip=false` instead of `do_unzip=true`.

-### Step 2.2: Run the MEDS extraction ETL
+To use a specific stage runner file (e.g., to set different parallelism options), you can specify it as an
+additional argument:

-We will assume you want to output the final MEDS dataset into a directory we'll denote as `$MIMICIV_MEDS_DIR`.
-Note this is a different directory than the pre-MEDS directory (though, of course, they can both be
-subdirectories of the same root directory).
-
-This is a step in 4 parts:
-
-1. Sub-shard the raw files. Run this command as many times simultaneously as you would like to have workers
-   performing this sub-sharding step. See below for how to automate this parallelism using hydra launchers.
-
-   This step uses the `./scripts/extraction/shard_events.py` script. See `joint_script*.sh` for the expected
-   format of the command.
-
-2. Extract and form the subject splits and sub-shards. The `./scripts/extraction/split_and_shard_subjects.py`
-   script is used for this step. See `joint_script*.sh` for the expected format of the command.
-
-3. Extract subject sub-shards and convert to MEDS events. The
-   `./scripts/extraction/convert_to_sharded_events.py` script is used for this step. See `joint_script*.sh` for
-   the expected format of the command.
-
-4. Merge the MEDS events into a single file per subject sub-shard. The
-   `./scripts/extraction/merge_to_MEDS_cohort.py` script is used for this step. See `joint_script*.sh` for the
-   expected format of the command.
-
-5. (Optional) Generate preliminary code statistics and merge to external metadata. This is not performed
-   currently in the `joint_script*.sh` scripts.
-
-## Limitations / TO-DOs:
-
-Currently, some tables are ignored, including:
+```bash
+export N_WORKERS=5
+./run.sh $MIMICIV_RAW_DIR $MIMICIV_PRE_MEDS_DIR $MIMICIV_MEDS_COHORT_DIR do_unzip=true \
+    stage_runner_fp=slurm_runner.yaml
+```

-1. `hosp/emar_detail`
-2. `hosp/microbiologyevents`
-3. `hosp/services`
-4. `icu/datetimeevents`
-5. `icu/ingredientevents`
+The `N_WORKERS` environment variable set before the command controls how many parallel workers should be used
+at maximum.

-Lots of questions remain about how to appropriately handle times of the data -- e.g., things like HCPCS
-events are stored at the level of the _date_, not the _datetime_. How should those be slotted into the
-timeline which is otherwise stored at the _datetime_ resolution?
+The `slurm_runner.yaml` file (downloaded above) runs each stage across several workers on separate slurm
+worker nodes using the `submitit` launcher. _**You will need to customize this file to your own slurm system
+so that the partition names are correct before use.**_ The memory and time costs are viable in the current
+configuration, but if your nodes are sufficiently different you may need to adjust those as well.
-Other questions: +The `local_parallelism_runner.yaml` file (downloaded above) runs each stage via separate processes on the +launching machine. There are no additional arguments needed for this stage beyond the `N_WORKERS` environment +variable and there is nothing to customize in this file. -1. How to handle merging the deathtimes between the hosp table and the subjects table? -2. How to handle the dob nonsense MIMIC has? +To profile the time and memory costs of your ETL, add the `do_profile=true` flag at the end. ## Notes From 2f09440d0f7dfba55d4b8d73a707b7a2a879a0d9 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 12:33:12 -0400 Subject: [PATCH 48/62] Fixed typo in event config --- MIMIC-IV_Example/configs/event_configs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml index 2986a958..b7bc7b49 100644 --- a/MIMIC-IV_Example/configs/event_configs.yaml +++ b/MIMIC-IV_Example/configs/event_configs.yaml @@ -303,7 +303,7 @@ icu/inputevents: - KG time: col(starttime) time_format: "%Y-%m-%d %H:%M:%S" - numeric_value: subjectweight + numeric_value: patientweight icu/outputevents: output: From 942613ac5dcb448858f7eb953ed908ae71f43caf Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 18:29:49 -0400 Subject: [PATCH 49/62] Params changes. --- MIMIC-IV_Example/configs/event_configs.yaml | 2 +- MIMIC-IV_Example/slurm_runner.yaml | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml index b7bc7b49..666bdd32 100644 --- a/MIMIC-IV_Example/configs/event_configs.yaml +++ b/MIMIC-IV_Example/configs/event_configs.yaml @@ -109,7 +109,7 @@ hosp/omr: time: col(chartdate) time_format: "%Y-%m-%d" -hosp/subjects: +hosp/patients: gender: code: - GENDER diff --git a/MIMIC-IV_Example/slurm_runner.yaml b/MIMIC-IV_Example/slurm_runner.yaml index 8d20e59e..4b1e716b 100644 --- a/MIMIC-IV_Example/slurm_runner.yaml +++ b/MIMIC-IV_Example/slurm_runner.yaml @@ -7,15 +7,16 @@ shard_events: launcher_params: timeout_min: 60 cpus_per_task: 10 - mem_gb: 25 + mem_gb: 40 partition: "short" split_and_shard_subjects: parallelize: + n_workers: 1 launcher_params: - timeout_min: 60 + timeout_min: 10 cpus_per_task: 10 - mem_gb: 10 + mem_gb: 7 partition: "short" convert_to_sharded_events: @@ -23,7 +24,7 @@ convert_to_sharded_events: launcher_params: timeout_min: 60 cpus_per_task: 10 - mem_gb: 10 + mem_gb: 25 partition: "short" merge_to_MEDS_cohort: @@ -31,7 +32,7 @@ merge_to_MEDS_cohort: launcher_params: timeout_min: 60 cpus_per_task: 10 - mem_gb: 10 + mem_gb: 25 partition: "short" extract_code_metadata: From 6fa26dc3bd055bc05659e3b26021e92a25592cec Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 20:38:29 -0400 Subject: [PATCH 50/62] Added a fix for #182 --- MIMIC-IV_Example/slurm_runner.yaml | 12 ++--- src/MEDS_transforms/configs/_extract.yaml | 1 + .../convert_to_sharded_events.yaml | 2 + .../extract/convert_to_sharded_events.py | 49 ++++++++++++++----- 4 files changed, 45 insertions(+), 19 deletions(-) create mode 100644 src/MEDS_transforms/configs/stage_configs/convert_to_sharded_events.yaml diff --git a/MIMIC-IV_Example/slurm_runner.yaml b/MIMIC-IV_Example/slurm_runner.yaml index 4b1e716b..a8e5979b 100644 --- a/MIMIC-IV_Example/slurm_runner.yaml +++ b/MIMIC-IV_Example/slurm_runner.yaml @@ -5,7 +5,7 @@ parallelize: shard_events: parallelize: 
launcher_params: - timeout_min: 60 + timeout_min: 50 cpus_per_task: 10 mem_gb: 40 partition: "short" @@ -22,7 +22,7 @@ split_and_shard_subjects: convert_to_sharded_events: parallelize: launcher_params: - timeout_min: 60 + timeout_min: 10 cpus_per_task: 10 mem_gb: 25 partition: "short" @@ -30,17 +30,17 @@ convert_to_sharded_events: merge_to_MEDS_cohort: parallelize: launcher_params: - timeout_min: 60 + timeout_min: 15 cpus_per_task: 10 - mem_gb: 25 + mem_gb: 85 partition: "short" extract_code_metadata: parallelize: launcher_params: - timeout_min: 60 + timeout_min: 10 cpus_per_task: 10 - mem_gb: 10 + mem_gb: 25 partition: "short" finalize_MEDS_metadata: diff --git a/src/MEDS_transforms/configs/_extract.yaml b/src/MEDS_transforms/configs/_extract.yaml index dec483d2..2ee757cd 100644 --- a/src/MEDS_transforms/configs/_extract.yaml +++ b/src/MEDS_transforms/configs/_extract.yaml @@ -3,6 +3,7 @@ defaults: - stage_configs: - shard_events - split_and_shard_subjects + - convert_to_sharded_events - merge_to_MEDS_cohort - extract_code_metadata - finalize_MEDS_metadata diff --git a/src/MEDS_transforms/configs/stage_configs/convert_to_sharded_events.yaml b/src/MEDS_transforms/configs/stage_configs/convert_to_sharded_events.yaml new file mode 100644 index 00000000..7ab5c1b8 --- /dev/null +++ b/src/MEDS_transforms/configs/stage_configs/convert_to_sharded_events.yaml @@ -0,0 +1,2 @@ +convert_to_sharded_events: + do_dedup_text_and_numeric: True diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index 8ac66ac0..fb8794b9 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -93,7 +93,9 @@ def get_code_expr(code_field: str | list | ListConfig) -> tuple[pl.Expr, pl.Expr return code_expr, code_null_filter_expr, needed_cols -def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.LazyFrame: +def extract_event( + df: pl.LazyFrame, event_cfg: dict[str, str | None], do_dedup_text_and_numeric: bool = False, +) -> pl.LazyFrame: """Extracts a single event dataframe from the raw data. Args: @@ -123,6 +125,8 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy possible, these additional columns should conform to the conventions of the MEDS data schema --- e.g., primary numeric values associated with the event should be named `"numeric_value"` in the output MEDS data (and thus have the key `"numeric_value"` in the `event_cfg` dictionary). + do_dedup_text_and_numeric: If true, the result will ensure that the `text_value` column is dropped if + it is simply a string version of the `numeric_value` column. Returns: A DataFrame containing the event data extracted from the raw data, containing only unique rows across @@ -150,25 +154,27 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... "code_modifier": ["1", "2", "3", "4"], ... "time": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04"], ... "numeric_value": [1, 2, 3, 4], + ... "woo_text": ["1", "2", "3/10", "4.24"], ... }) >>> event_cfg = { ... "code": ["FOO", "col(code)", "col(code_modifier)"], ... "time": "col(time)", ... "time_format": "%Y-%m-%d", ... "numeric_value": "numeric_value", + ... "text_value": "woo_text", ... 
} - >>> extract_event(raw_data, event_cfg) + >>> extract_event(raw_data, event_cfg, do_dedup_text_and_numeric=True) shape: (4, 4) - ┌────────────┬───────────┬─────────────────────┬───────────────┐ - │ subject_id ┆ code ┆ time ┆ numeric_value │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ str ┆ datetime[μs] ┆ i64 │ - ╞════════════╪═══════════╪═════════════════════╪═══════════════╡ - │ 1 ┆ FOO//A//1 ┆ 2021-01-01 00:00:00 ┆ 1 │ - │ 1 ┆ FOO//B//2 ┆ 2021-01-02 00:00:00 ┆ 2 │ - │ 2 ┆ FOO//C//3 ┆ 2021-01-03 00:00:00 ┆ 3 │ - │ 2 ┆ FOO//D//4 ┆ 2021-01-04 00:00:00 ┆ 4 │ - └────────────┴───────────┴─────────────────────┴───────────────┘ + ┌────────────┬───────────┬─────────────────────┬───────────────┬────────────┐ + │ subject_id ┆ code ┆ time ┆ numeric_value ┆ text_value │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ datetime[μs] ┆ i64 ┆ str │ + ╞════════════╪═══════════╪═════════════════════╪═══════════════╪════════════╡ + │ 1 ┆ FOO//A//1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ null │ + │ 1 ┆ FOO//B//2 ┆ 2021-01-02 00:00:00 ┆ 2 ┆ null │ + │ 2 ┆ FOO//C//3 ┆ 2021-01-03 00:00:00 ┆ 3 ┆ 3/10 │ + │ 2 ┆ FOO//D//4 ┆ 2021-01-04 00:00:00 ┆ 4 ┆ 4.24 │ + └────────────┴───────────┴─────────────────────┴───────────────┴────────────┘ >>> data_with_nulls = pl.DataFrame({ ... "subject_id": [1, 1, 2, 2], ... "code": ["A", None, "C", "D"], @@ -484,6 +490,20 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy event_exprs[k] = col + has_numeric = "numeric_value" in event_exprs + has_text = "text_value" in event_exprs + + if do_dedup_text_and_numeric and has_numeric and has_text: + text_expr = event_exprs["text_value"] + num_expr = event_exprs["numeric_value"] + event_exprs["text_value"] = ( + pl.when(text_expr.cast(pl.Float32, strict=False) == num_expr.cast(pl.Float32)) + .then(pl.lit(None, pl.String)) + .otherwise(text_expr) + ) + + if "numeric_value" in event_exprs and "text_value" in event_exprs: + if code_null_filter_expr is not None: logger.info(f"Filtering out rows with null codes via {code_null_filter_expr}") df = df.filter(code_null_filter_expr) @@ -656,7 +676,9 @@ def convert_to_events( for event_name, event_cfg in event_cfgs.items(): try: logger.info(f"Building computational graph for extracting {event_name}") - event_dfs.append(extract_event(df, event_cfg)) + event_dfs.append(extract_event( + df, event_cfg, do_dedup_text_and_numeric=do_dedup_text_and_numeric, + )) except Exception as e: raise ValueError(f"Error extracting event {event_name}: {e}") from e @@ -731,6 +753,7 @@ def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: return convert_to_events( df.filter(pl.col("subject_id").is_in(typed_subjects)), event_cfgs=copy.deepcopy(event_cfgs), + do_dedup_text_and_numeric=cfg.stage_cfg.get("do_dedup_text_and_numeric", False), ) except Exception as e: raise ValueError( From dd167335c8094510c7af217c0f2de903aedf64a8 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 20:40:03 -0400 Subject: [PATCH 51/62] Fixed lint errors. 
--- .../extract/convert_to_sharded_events.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index fb8794b9..f4f8cc08 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -94,7 +94,9 @@ def get_code_expr(code_field: str | list | ListConfig) -> tuple[pl.Expr, pl.Expr def extract_event( - df: pl.LazyFrame, event_cfg: dict[str, str | None], do_dedup_text_and_numeric: bool = False, + df: pl.LazyFrame, + event_cfg: dict[str, str | None], + do_dedup_text_and_numeric: bool = False, ) -> pl.LazyFrame: """Extracts a single event dataframe from the raw data. @@ -502,8 +504,6 @@ def extract_event( .otherwise(text_expr) ) - if "numeric_value" in event_exprs and "text_value" in event_exprs: - if code_null_filter_expr is not None: logger.info(f"Filtering out rows with null codes via {code_null_filter_expr}") df = df.filter(code_null_filter_expr) @@ -517,7 +517,9 @@ def extract_event( def convert_to_events( - df: pl.LazyFrame, event_cfgs: dict[str, dict[str, str | None | Sequence[str]]] + df: pl.LazyFrame, + event_cfgs: dict[str, dict[str, str | None | Sequence[str]]], + do_dedup_text_and_numeric: bool = False, ) -> pl.LazyFrame: """Converts a DataFrame of raw data into a DataFrame of events. @@ -676,9 +678,13 @@ def convert_to_events( for event_name, event_cfg in event_cfgs.items(): try: logger.info(f"Building computational graph for extracting {event_name}") - event_dfs.append(extract_event( - df, event_cfg, do_dedup_text_and_numeric=do_dedup_text_and_numeric, - )) + event_dfs.append( + extract_event( + df, + event_cfg, + do_dedup_text_and_numeric=do_dedup_text_and_numeric, + ) + ) except Exception as e: raise ValueError(f"Error extracting event {event_name}: {e}") from e From 0c072bfe628973161118de292b8d194423d69401 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 20:48:47 -0400 Subject: [PATCH 52/62] Fixed a typ on the convert tests. --- src/MEDS_transforms/extract/convert_to_sharded_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index f4f8cc08..39aea54f 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -166,7 +166,7 @@ def extract_event( ... "text_value": "woo_text", ... 
} >>> extract_event(raw_data, event_cfg, do_dedup_text_and_numeric=True) - shape: (4, 4) + shape: (4, 5) ┌────────────┬───────────┬─────────────────────┬───────────────┬────────────┐ │ subject_id ┆ code ┆ time ┆ numeric_value ┆ text_value │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ From 4ce9bb9e419f5e6b7a91a05b07d9d448326d1337 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 20:49:18 -0400 Subject: [PATCH 53/62] Updated params for MIMIC --- MIMIC-IV_Example/slurm_runner.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/MIMIC-IV_Example/slurm_runner.yaml b/MIMIC-IV_Example/slurm_runner.yaml index a8e5979b..47d09e77 100644 --- a/MIMIC-IV_Example/slurm_runner.yaml +++ b/MIMIC-IV_Example/slurm_runner.yaml @@ -45,17 +45,17 @@ extract_code_metadata: finalize_MEDS_metadata: parallelize: + n_workers: 1 launcher_params: - timeout_min: 60 - cpus_per_task: 10 + timeout_min: 10 + cpus_per_task: 5 mem_gb: 10 partition: "short" finalize_MEDS_data: parallelize: - n_workers: 1 launcher_params: - timeout_min: 15 - cpus_per_task: 5 - mem_gb: 10 + timeout_min: 10 + cpus_per_task: 10 + mem_gb: 25 partition: "short" From 616ecb81e66a774fcab174ac685020b6fb8a63d4 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 21:02:56 -0400 Subject: [PATCH 54/62] Updated parameters. MIMIC pipeline runs end to end with these parameters. Need to re-test after modifications to convert_to_sharded_events are finalized still. --- MIMIC-IV_Example/slurm_runner.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MIMIC-IV_Example/slurm_runner.yaml b/MIMIC-IV_Example/slurm_runner.yaml index 47d09e77..4dbed261 100644 --- a/MIMIC-IV_Example/slurm_runner.yaml +++ b/MIMIC-IV_Example/slurm_runner.yaml @@ -57,5 +57,5 @@ finalize_MEDS_data: launcher_params: timeout_min: 10 cpus_per_task: 10 - mem_gb: 25 + mem_gb: 70 partition: "short" From a88b4c9a3e8d7e31a61ea22f590d42aaf06e8608 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 21:03:25 -0400 Subject: [PATCH 55/62] Added integration test for sharded events with deduplication. 
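The integration test in the diff below exercises the text/numeric deduplication added as the fix for #182: when `text_value` is nothing more than a string rendering of `numeric_value`, it is dropped; otherwise it is kept. A self-contained sketch of that rule on toy data (column values are made up for illustration; polars is assumed to be installed):

```python
import polars as pl

# Toy stand-in for extracted measurements; values are illustrative only.
df = pl.DataFrame(
    {
        "numeric_value": [96.0, 3.0, None],
        "text_value": ["96.0", "3/10", "positive"],
    }
)

text = pl.col("text_value")
num = pl.col("numeric_value")

# Null out text_value only when it parses to the same number as numeric_value.
deduped = df.with_columns(
    pl.when(text.cast(pl.Float32, strict=False) == num.cast(pl.Float32))
    .then(pl.lit(None, pl.String))
    .otherwise(text)
    .alias("text_value")
)

# Expected: "96.0" -> null; "3/10" and "positive" are kept.
print(deduped)
```

This mirrors the `when/then/otherwise` expression introduced in the earlier `convert_to_sharded_events.py` change; the non-numeric strings (`"3/10"`, `"positive"`) survive because the non-strict cast yields null, so the equality condition does not hold.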
--- .../test_convert_to_sharded_events.py | 243 +++++++++++++----- 1 file changed, 185 insertions(+), 58 deletions(-) diff --git a/tests/MEDS_Extract/test_convert_to_sharded_events.py b/tests/MEDS_Extract/test_convert_to_sharded_events.py index 6d8311d2..653ce737 100644 --- a/tests/MEDS_Extract/test_convert_to_sharded_events.py +++ b/tests/MEDS_Extract/test_convert_to_sharded_events.py @@ -89,6 +89,7 @@ time: col(vitals_date) time_format: "%m/%d/%Y, %H:%M:%S" numeric_value: temp + text_value: temp _metadata: input_metadata: description: {"title": {"lab_code": "temp"}} @@ -102,12 +103,12 @@ "held_out/0": [1500733], } -WANT_OUTPUTS = parse_shards_yaml( +WANT_OUTPUTS_NO_DEDUP = parse_shards_yaml( """ data/train/0/subjects/[0-6).parquet: |-2 subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, - 239684,,HEIGHT,175.271115221764 + 239684,,HEIGHT,175.271115221765 239684,"12/28/1980, 00:00:00",DOB, 1195293,,EYE_COLOR//BLUE, 1195293,,HEIGHT,164.6868838269085 @@ -135,76 +136,182 @@ 1500733,"07/20/1986, 00:00:00",DOB, data/train/0/admit_vitals/[0-10).parquet: |-2 - subject_id,time,code,numeric_value - 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, - 239684,"05/11/2010, 17:41:51",HR,102.6 - 239684,"05/11/2010, 17:41:51",TEMP,96.0 - 239684,"05/11/2010, 17:48:48",HR,105.1 - 239684,"05/11/2010, 17:48:48",TEMP,96.2 - 239684,"05/11/2010, 18:25:35",HR,113.4 - 239684,"05/11/2010, 18:25:35",TEMP,95.8 - 239684,"05/11/2010, 18:57:18",HR,112.6 - 239684,"05/11/2010, 18:57:18",TEMP,95.5 - 239684,"05/11/2010, 19:27:19",DISCHARGE, - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:25:32",HR,114.1 - 1195293,"06/20/2010, 19:25:32",TEMP,100.0 - 1195293,"06/20/2010, 20:12:31",HR,112.5 - 1195293,"06/20/2010, 20:12:31",TEMP,99.8 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, + subject_id,time,code,numeric_value,text_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,, + 239684,"05/11/2010, 17:41:51",HR,102.6, + 239684,"05/11/2010, 17:41:51",TEMP,96.0,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1, + 239684,"05/11/2010, 17:48:48",TEMP,96.2,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4, + 239684,"05/11/2010, 18:25:35",TEMP,95.8,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6, + 239684,"05/11/2010, 18:57:18",TEMP,95.5,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE,, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,, + 1195293,"06/20/2010, 19:25:32",HR,114.1, + 1195293,"06/20/2010, 19:25:32",TEMP,100.0,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5, + 1195293,"06/20/2010, 20:12:31",TEMP,99.8,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE,, data/train/0/admit_vitals/[10-16).parquet: |-2 - subject_id,time,code,numeric_value - 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, - 1195293,"06/20/2010, 19:23:52",HR,109.0 - 1195293,"06/20/2010, 19:23:52",TEMP,100.0 - 1195293,"06/20/2010, 19:45:19",HR,119.8 - 1195293,"06/20/2010, 19:45:19",TEMP,99.9 - 1195293,"06/20/2010, 20:24:44",HR,107.7 - 1195293,"06/20/2010, 20:24:44",TEMP,100.0 - 1195293,"06/20/2010, 20:41:33",HR,107.5 - 1195293,"06/20/2010, 20:41:33",TEMP,100.4 - 1195293,"06/20/2010, 20:50:04",DISCHARGE, + subject_id,time,code,numeric_value,text_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,, + 1195293,"06/20/2010, 19:23:52",HR,109.0, + 1195293,"06/20/2010, 19:23:52",TEMP,100.0,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8, + 1195293,"06/20/2010, 19:45:19",TEMP,99.9,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7, + 1195293,"06/20/2010, 20:24:44",TEMP,100.0,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5, + 
1195293,"06/20/2010, 20:41:33",TEMP,100.4,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE,, data/train/1/admit_vitals/[0-10).parquet: |-2 - subject_id,time,code,numeric_value - 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, - 68729,"05/26/2010, 02:30:56",HR,86.0 - 68729,"05/26/2010, 02:30:56",TEMP,97.8 - 68729,"05/26/2010, 04:51:52",DISCHARGE, - 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, - 814703,"02/05/2010, 05:55:39",HR,170.2 - 814703,"02/05/2010, 05:55:39",TEMP,100.1 - 814703,"02/05/2010, 07:02:30",DISCHARGE, + subject_id,time,code,numeric_value,text_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY,, + 68729,"05/26/2010, 02:30:56",HR,86.0, + 68729,"05/26/2010, 02:30:56",TEMP,97.8,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE,, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC,, + 814703,"02/05/2010, 05:55:39",HR,170.2, + 814703,"02/05/2010, 05:55:39",TEMP,100.1,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE,, data/train/1/admit_vitals/[10-16).parquet: |-2 - subject_id,time,code,numeric_value + subject_id,time,code,numeric_value,text_value data/tuning/0/admit_vitals/[0-10).parquet: |-2 - subject_id,time,code,numeric_value - 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, - 754281,"01/03/2010, 06:27:59",HR,142.0 - 754281,"01/03/2010, 06:27:59",TEMP,99.8 - 754281,"01/03/2010, 08:22:13",DISCHARGE, + subject_id,time,code,numeric_value,text_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY,, + 754281,"01/03/2010, 06:27:59",HR,142.0, + 754281,"01/03/2010, 06:27:59",TEMP,99.8,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE,, data/tuning/0/admit_vitals/[10-16).parquet: |-2 - subject_id,time,code,numeric_value + subject_id,time,code,numeric_value,text_value data/held_out/0/admit_vitals/[0-10).parquet: |-2 - subject_id,time,code,numeric_value - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 1500733,"06/03/2010, 16:20:49",HR,90.1 - 1500733,"06/03/2010, 16:20:49",TEMP,100.1 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, + subject_id,time,code,numeric_value,text_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,, + 1500733,"06/03/2010, 16:20:49",HR,90.1, + 1500733,"06/03/2010, 16:20:49",TEMP,100.1,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE,, data/held_out/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value,text_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,, + 1500733,"06/03/2010, 14:54:38",HR,91.4, + 1500733,"06/03/2010, 14:54:38",TEMP,100.0,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4, + 1500733,"06/03/2010, 15:39:49",TEMP,100.3,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE,, + """ +) + +WANT_OUTPUTS = parse_shards_yaml( + """ +data/train/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221765 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + +data/train/1/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + +data/tuning/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + +data/held_out/0/subjects/[0-6).parquet: |-2 subject_id,time,code,numeric_value - 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, - 
1500733,"06/03/2010, 14:54:38",HR,91.4 - 1500733,"06/03/2010, 14:54:38",TEMP,100.0 - 1500733,"06/03/2010, 15:39:49",HR,84.4 - 1500733,"06/03/2010, 15:39:49",TEMP,100.3 - 1500733,"06/03/2010, 16:44:26",DISCHARGE, + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + +data/train/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value,text_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,, + 239684,"05/11/2010, 17:41:51",HR,102.6, + 239684,"05/11/2010, 17:41:51",TEMP,96.0, + 239684,"05/11/2010, 17:48:48",HR,105.1, + 239684,"05/11/2010, 17:48:48",TEMP,96.2, + 239684,"05/11/2010, 18:25:35",HR,113.4, + 239684,"05/11/2010, 18:25:35",TEMP,95.8, + 239684,"05/11/2010, 18:57:18",HR,112.6, + 239684,"05/11/2010, 18:57:18",TEMP,95.5, + 239684,"05/11/2010, 19:27:19",DISCHARGE,, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,, + 1195293,"06/20/2010, 19:25:32",HR,114.1, + 1195293,"06/20/2010, 19:25:32",TEMP,100.0, + 1195293,"06/20/2010, 20:12:31",HR,112.5, + 1195293,"06/20/2010, 20:12:31",TEMP,99.8, + 1195293,"06/20/2010, 20:50:04",DISCHARGE,, + +data/train/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value,text_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,, + 1195293,"06/20/2010, 19:23:52",HR,109.0, + 1195293,"06/20/2010, 19:23:52",TEMP,100.0, + 1195293,"06/20/2010, 19:45:19",HR,119.8, + 1195293,"06/20/2010, 19:45:19",TEMP,99.9, + 1195293,"06/20/2010, 20:24:44",HR,107.7, + 1195293,"06/20/2010, 20:24:44",TEMP,100.0, + 1195293,"06/20/2010, 20:41:33",HR,107.5, + 1195293,"06/20/2010, 20:41:33",TEMP,100.4, + 1195293,"06/20/2010, 20:50:04",DISCHARGE,, + +data/train/1/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value,text_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY,, + 68729,"05/26/2010, 02:30:56",HR,86.0, + 68729,"05/26/2010, 02:30:56",TEMP,97.8, + 68729,"05/26/2010, 04:51:52",DISCHARGE,, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC,, + 814703,"02/05/2010, 05:55:39",HR,170.2, + 814703,"02/05/2010, 05:55:39",TEMP,100.1, + 814703,"02/05/2010, 07:02:30",DISCHARGE,, + +data/train/1/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value,text_value + +data/tuning/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value,text_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY,, + 754281,"01/03/2010, 06:27:59",HR,142.0, + 754281,"01/03/2010, 06:27:59",TEMP,99.8, + 754281,"01/03/2010, 08:22:13",DISCHARGE,, + +data/tuning/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value,text_value + +data/held_out/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value,text_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,, + 1500733,"06/03/2010, 16:20:49",HR,90.1, + 1500733,"06/03/2010, 16:20:49",TEMP,100.1, + 1500733,"06/03/2010, 16:44:26",DISCHARGE,, + +data/held_out/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value,text_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,, + 1500733,"06/03/2010, 14:54:38",HR,91.4, + 1500733,"06/03/2010, 14:54:38",TEMP,100.0, + 1500733,"06/03/2010, 15:39:49",HR,84.4, + 1500733,"06/03/2010, 15:39:49",TEMP,100.3, + 1500733,"06/03/2010, 16:44:26",DISCHARGE,, """ ) @@ -213,7 +320,7 @@ def test_convert_to_sharded_events(): single_stage_tester( script=CONVERT_TO_SHARDED_EVENTS_SCRIPT, stage_name="convert_to_sharded_events", - stage_kwargs=None, + stage_kwargs={"do_dedup_text_and_numeric": True}, config_name="extract", input_files={ 
"data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), @@ -225,5 +332,25 @@ def test_convert_to_sharded_events(): event_conversion_config_fp="{input_dir}/event_cfgs.yaml", shards_map_fp="{input_dir}/metadata/.shards.json", want_outputs=WANT_OUTPUTS, + test_name="Stage tester: convert_to_sharded_events ; with dedup", + df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": False}, + ) + + single_stage_tester( + script=CONVERT_TO_SHARDED_EVENTS_SCRIPT, + stage_name="convert_to_sharded_events", + stage_kwargs={"do_dedup_text_and_numeric": False}, + config_name="extract", + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS_NO_DEDUP, + test_name="Stage tester: convert_to_sharded_events ; no dedup", df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": False}, ) From be6db594bcbf6a5a75cf4d048934d5c6106c8245 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 21:52:04 -0400 Subject: [PATCH 56/62] Upgraded polars and set infer schema to false for metadata extraction. --- pyproject.toml | 2 +- src/MEDS_transforms/extract/extract_code_metadata.py | 2 +- src/MEDS_transforms/parser.py | 12 ------------ 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a9b32925..ef352990 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.3", + "polars~=1.6.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.3", ] [tool.setuptools_scm] diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py b/src/MEDS_transforms/extract/extract_code_metadata.py index 31d883c4..3460cfbd 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -386,7 +386,7 @@ def main(cfg: DictConfig): metadata_fp, read_fn = get_supported_fp(raw_input_dir, input_prefix) if metadata_fp.suffix != ".parquet": - read_fn = partial(read_fn, infer_schema_length=999999999) + read_fn = partial(read_fn, infer_schema=False) out_fp = partial_metadata_dir / f"{input_prefix}.parquet" logger.info(f"Extracting metadata from {metadata_fp} and saving to {out_fp}") diff --git a/src/MEDS_transforms/parser.py b/src/MEDS_transforms/parser.py index 948ca003..3b663f7f 100644 --- a/src/MEDS_transforms/parser.py +++ b/src/MEDS_transforms/parser.py @@ -596,18 +596,6 @@ def cfg_to_expr(cfg: str | ListConfig | DictConfig) -> tuple[pl.Expr, set[str]]: ['34.2', 'bar//2', '34.2'] >>> sorted(cols) ['baz'] - - Note that sometimes coalescing can lead to unexpected results. For example, if the first expression is of - a different type than the second, the second expression may have its type coerced to match the first, - potentially in an unexpected manner. This is also related to some polars, bugs, such as - https://github.com/pola-rs/polars/issues/17773 - >>> cfg = [ - ... 
{"matcher": {"baz": 2}, "output": {"str": "bar//{baz}"}}, - ... {"literal": 34.8218}, - ... ] - >>> expr, cols = cfg_to_expr(cfg) - >>> data.select(expr.alias("out"))["out"].to_list() - ['34', 'bar//2', '34'] """ structured_expr = parse_col_expr(cfg) return structured_expr_to_pl(structured_expr) From a069965586e6b0b9a0c9c23050229cdea84de9e2 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 21:55:57 -0400 Subject: [PATCH 57/62] Need to use strings with this change. --- MIMIC-IV_Example/configs/event_configs.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml index 666bdd32..9d23bb64 100644 --- a/MIMIC-IV_Example/configs/event_configs.yaml +++ b/MIMIC-IV_Example/configs/event_configs.yaml @@ -165,8 +165,8 @@ hosp/procedures_icd: hosp/d_icd_procedures: description: "long_title" parent_codes: # List of objects are string labels mapping to filters to be evaluated. - - "ICD{icd_version}Proc/{icd_code}": { icd_version: 9 } - - "ICD{icd_version}PCS/{icd_code}": { icd_version: 10 } + - "ICD{icd_version}Proc/{icd_code}": { icd_version: "9" } + - "ICD{icd_version}PCS/{icd_code}": { icd_version: "10" } hosp/transfers: transfer: From c71b82c24194e49ac1ba54c99671d4509551eb70 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 22:16:16 -0400 Subject: [PATCH 58/62] Attempted change to pre-MEDS to get it to work. --- MIMIC-IV_Example/pre_MEDS.py | 173 ++++++++++++++++++++++++++++++++++- 1 file changed, 172 insertions(+), 1 deletion(-) diff --git a/MIMIC-IV_Example/pre_MEDS.py b/MIMIC-IV_Example/pre_MEDS.py index 6007c64a..1dd6694b 100755 --- a/MIMIC-IV_Example/pre_MEDS.py +++ b/MIMIC-IV_Example/pre_MEDS.py @@ -15,6 +15,156 @@ from MEDS_transforms.utils import get_shard_prefix, hydra_loguru_init, write_lazyframe +def add_dot(code: pl.Expr, position: int) -> pl.Expr: + """Adds a dot to the code expression at the specified position. + + Args: + code: The code expression. + position: The position to add the dot. + + Returns: + The expression which would yield the code string with a dot added at the specified position + + Example: + >>> pl.select(add_dot(pl.lit("12345"), 3)) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ 123.45 │ + └─────────┘ + >>> pl.select(add_dot(pl.lit("12345"), 1)) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ 1.2345 │ + └─────────┘ + >>> pl.select(add_dot(pl.lit("12345"), 6)) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ 12345 │ + └─────────┘ + """ + return ( + pl.when(code.str.len_chars() > position) + .then(code.str.slice(0, position) + "." + code.str.slice(position)) + .otherwise(code) + ) + + +def add_icd_diagnosis_dot(icd_version: pl.Expr, icd_code: pl.Expr) -> pl.Expr: + """Adds the appropriate dot to the ICD diagnosis codebased on the version. + + Args: + icd_version: The ICD version. + icd_code: The ICD code. + + Returns: + The ICD code with appropriate dot syntax based on the version. 
+ + Examples: + >>> pl.select(add_icd_diagnosis_dot(pl.lit("9"), pl.lit("12345"))) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ 123.45 │ + └─────────┘ + >>> pl.select(add_icd_diagnosis_dot(pl.lit("9"), pl.lit("E1234"))) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ E123.4 │ + └─────────┘ + >>> pl.select(add_icd_diagnosis_dot(pl.lit("9"), pl.lit("F1234"))) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ F12.34 │ + └─────────┘ + >>> pl.select(add_icd_diagnosis_dot(pl.lit("10"), pl.lit("12345"))) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ 123.45 │ + └─────────┘ + >>> pl.select(add_icd_diagnosis_dot(pl.lit("10"), pl.lit("E1234"))) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ E12.34 │ + └─────────┘ + """ + + icd9_code = ( + pl.when(icd_code.str.starts_with("E")).then(add_dot(icd_code, 4)).otherwise(add_dot(icd_code, 3)) + ) + + icd10_code = add_dot(icd_code, 3) + + return pl.when(icd_version == "9").then(icd9_code).otherwise(icd10_code) + + +def add_icd_procedure_dot(icd_version: pl.Expr, icd_code: pl.Expr) -> pl.Expr: + """Adds the appropriate dot to the ICD procedure code based on the version. + + Args: + icd_version: The ICD version. + icd_code: The ICD code. + + Returns: + The ICD code with appropriate dot syntax based on the version. + + Examples: + >>> pl.select(add_icd_procedure_dot(pl.lit("9"), pl.lit("12345"))) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ 12.345 │ + └─────────┘ + >>> pl.select(add_icd_procedure_dot(pl.lit("10"), pl.lit("12345"))) + shape: (1, 1) + ┌─────────┐ + │ literal │ + │ --- │ + │ str │ + ╞═════════╡ + │ 12345 │ + └─────────┘ + """ + + icd9_code = add_dot(icd_code, 2) + icd10_code = icd_code + + return pl.when(icd_version == "9").then(icd9_code).otherwise(icd10_code) + + def add_discharge_time_by_hadm_id( df: pl.LazyFrame, discharge_time_df: pl.LazyFrame, out_column_name: str = "hadm_discharge_time" ) -> pl.LazyFrame: @@ -51,6 +201,11 @@ def fix_static_data(raw_static_df: pl.LazyFrame, death_times_df: pl.LazyFrame) - "hosp/patients": (fix_static_data, ("hosp/admissions", ["subject_id", "deathtime"])), } +ICD_DFS_TO_FIX = [ + ("hosp/d_icd_diagnoses", add_icd_diagnosis_dot), + ("hosp/d_icd_procedures", add_icd_procedure_dot), +] + @hydra.main(version_base=None, config_path="configs", config_name="pre_MEDS") def main(cfg: DictConfig): @@ -104,7 +259,7 @@ def main(cfg: DictConfig): out_fp.parent.mkdir(parents=True, exist_ok=True) - if pfx not in FUNCTIONS: + if pfx not in FUNCTIONS and pfx not in [p for p, _ in ICD_DFS_TO_FIX]: logger.info( f"No function needed for {pfx}: " f"Symlinking {str(fp.resolve())} to {str(out_fp.resolve())}" ) @@ -164,6 +319,22 @@ def main(cfg: DictConfig): write_lazyframe(processed_df, out_fp) logger.info(f" Processed and wrote to {str(out_fp.resolve())} in {datetime.now() - fp_st}") + for pfx, fn in ICD_DFS_TO_FIX: + fp, read_fn = get_supported_fp(input_dir, pfx) + out_fp = MEDS_input_dir / f"{pfx}.parquet" + + if out_fp.is_file(): + print(f"Done with {pfx}. Continuing") + continue + + st = datetime.now() + logger.info(f"Processing {pfx}...") + processed_df = read_fn(fp).with_columns( + fn(pl.col("icd_version"), pl.col("icd_code")).alias("icd_code") + ) + write_lazyframe(processed_df, out_fp) + logger.info(f" Processed and wrote to {str(out_fp.resolve())} in {datetime.now() - st}") + logger.info(f"Done! 
All dataframes processed and written to {str(MEDS_input_dir.resolve())}") done_fp.write_text(f"Finished at {datetime.now()}") From 4ab14268e93422161389d0c3e1357457cec013a5 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 22:19:04 -0400 Subject: [PATCH 59/62] Minor correction. --- MIMIC-IV_Example/pre_MEDS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MIMIC-IV_Example/pre_MEDS.py b/MIMIC-IV_Example/pre_MEDS.py index 1dd6694b..ee9ffb5d 100755 --- a/MIMIC-IV_Example/pre_MEDS.py +++ b/MIMIC-IV_Example/pre_MEDS.py @@ -266,7 +266,7 @@ def main(cfg: DictConfig): relative_in_fp = fp.relative_to(out_fp.resolve().parent, walk_up=True) out_fp.symlink_to(relative_in_fp) continue - else: + elif pfx in FUNCTIONS: out_fp = MEDS_input_dir / f"{pfx}.parquet" if out_fp.is_file(): print(f"Done with {pfx}. Continuing") From 91bb00c1a3847347bc1711be0fdd9a2388f156e6 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 22:24:12 -0400 Subject: [PATCH 60/62] Got ICD code normalization running; yet to be validated in practice. --- MIMIC-IV_Example/pre_MEDS.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/MIMIC-IV_Example/pre_MEDS.py b/MIMIC-IV_Example/pre_MEDS.py index ee9ffb5d..7521a18e 100755 --- a/MIMIC-IV_Example/pre_MEDS.py +++ b/MIMIC-IV_Example/pre_MEDS.py @@ -327,12 +327,15 @@ def main(cfg: DictConfig): print(f"Done with {pfx}. Continuing") continue + if fp.suffix != ".parquet": + read_fn = partial(read_fn, infer_schema=False) + st = datetime.now() logger.info(f"Processing {pfx}...") - processed_df = read_fn(fp).with_columns( - fn(pl.col("icd_version"), pl.col("icd_code")).alias("icd_code") + processed_df = read_fn(fp).collect().with_columns( + fn(pl.col("icd_version").cast(pl.String), pl.col("icd_code").cast(pl.String)).alias("icd_code") ) - write_lazyframe(processed_df, out_fp) + processed_df.write_parquet(out_fp, use_pyarrow=True) logger.info(f" Processed and wrote to {str(out_fp.resolve())} in {datetime.now() - st}") logger.info(f"Done! All dataframes processed and written to {str(MEDS_input_dir.resolve())}") From b5cf0cc26edd75d4bcccb9605b8d4c7355c74cf1 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 29 Aug 2024 22:35:14 -0400 Subject: [PATCH 61/62] Added a hacky solution to make sure that both the raw code for joining and the norm code for parent code are present. --- MIMIC-IV_Example/configs/event_configs.yaml | 6 +++--- MIMIC-IV_Example/pre_MEDS.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml index 9d23bb64..0d67a6c6 100644 --- a/MIMIC-IV_Example/configs/event_configs.yaml +++ b/MIMIC-IV_Example/configs/event_configs.yaml @@ -42,7 +42,7 @@ hosp/diagnoses_icd: _metadata: hosp/d_icd_diagnoses: description: "long_title" - parent_codes: "ICD{icd_version}CM/{icd_code}" # Single strings are templates of columns. + parent_codes: "ICD{icd_version}CM/{norm_icd_code}" # Single strings are templates of columns. hosp/drgcodes: drg: @@ -165,8 +165,8 @@ hosp/procedures_icd: hosp/d_icd_procedures: description: "long_title" parent_codes: # List of objects are string labels mapping to filters to be evaluated. 
- - "ICD{icd_version}Proc/{icd_code}": { icd_version: "9" } - - "ICD{icd_version}PCS/{icd_code}": { icd_version: "10" } + - "ICD{icd_version}Proc/{norm_icd_code}": { icd_version: "9" } + - "ICD{icd_version}PCS/{norm_icd_code}": { icd_version: "10" } hosp/transfers: transfer: diff --git a/MIMIC-IV_Example/pre_MEDS.py b/MIMIC-IV_Example/pre_MEDS.py index 7521a18e..846c3a9d 100755 --- a/MIMIC-IV_Example/pre_MEDS.py +++ b/MIMIC-IV_Example/pre_MEDS.py @@ -332,8 +332,14 @@ def main(cfg: DictConfig): st = datetime.now() logger.info(f"Processing {pfx}...") - processed_df = read_fn(fp).collect().with_columns( - fn(pl.col("icd_version").cast(pl.String), pl.col("icd_code").cast(pl.String)).alias("icd_code") + processed_df = ( + read_fn(fp) + .collect() + .with_columns( + fn(pl.col("icd_version").cast(pl.String), pl.col("icd_code").cast(pl.String)).alias( + "norm_icd_code" + ) + ) ) processed_df.write_parquet(out_fp, use_pyarrow=True) logger.info(f" Processed and wrote to {str(out_fp.resolve())} in {datetime.now() - st}") From 351976923ccdc7dce3f6d52fd4be470a26b2655c Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 1 Sep 2024 13:39:50 -0400 Subject: [PATCH 62/62] Removed unnecessary block. --- tests/MEDS_Transforms/transform_tester_base.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index ac8a195d..7a26c855 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -5,11 +5,6 @@ """ -try: - pass -except ImportError: - pass - from collections import defaultdict from io import StringIO from pathlib import Path