Skip to content

Commit

Permalink
Added a new, incomplete MEDS test with different data and configs.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 23, 2024
1 parent a21c1ec commit ded40da
Showing 1 changed file with 331 additions and 0 deletions.
331 changes: 331 additions & 0 deletions tests/test_other_meds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
"""Tests the full end-to-end extraction process."""


import rootutils

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

from io import StringIO

import polars as pl
import pyarrow as pa
from loguru import logger
from meds import label_schema
from yaml import load as load_yaml

from .utils import cli_test

try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader

# Turn on polars' global string cache once, at import time, for everything in this module.
pl.enable_string_cache()

# Timestamp format string; has the same value as DEFAULT_CSV_TS_FORMAT below.
# NOTE(review): TS_FORMAT, PRED_CNT_TYPE, EVENT_INDEX_TYPE, ANY_EVENT_COLUMN and LAST_EVENT_INDEX_COLUMN
# are not referenced by any code visible in this file -- confirm whether they are still needed.
TS_FORMAT = "%m/%d/%Y %H:%M"
PRED_CNT_TYPE = pl.Int64
EVENT_INDEX_TYPE = pl.UInt64
ANY_EVENT_COLUMN = "_ANY_EVENT"
LAST_EVENT_INDEX_COLUMN = "_LAST_EVENT_INDEX"

# Format used by parse_meds_csvs to parse the "time" column of the inline CSV fixtures.
DEFAULT_CSV_TS_FORMAT = "%m/%d/%Y %H:%M"

# TODO: Make use of the meds library instead of hand-writing this schema.
# Polars dtypes for MEDS data columns; this is the default ``schema`` of parse_meds_csvs and the base
# schema that parse_shards_yaml extends.
MEDS_PL_SCHEMA = {
"patient_id": pl.UInt32,
"time": pl.Datetime("us"),
"code": pl.Utf8,
"numeric_value": pl.Float32,
"numeric_value/is_inlier": pl.Boolean,
}


# Label columns that must be present; get_and_validate_label_schema raises a ValueError when one is
# missing and casts it to the dtype given here when the dtype differs.
MEDS_LABEL_MANDATORY_TYPES = {
"patient_id": pl.Int64,
}

# Label columns that may be absent; get_and_validate_label_schema adds them as all-null columns of the
# dtype given here when missing and casts them when the dtype differs.
MEDS_LABEL_OPTIONAL_TYPES = {
"boolean_value": pl.Boolean,
"integer_value": pl.Int64,
"float_value": pl.Float64,
"categorical_value": pl.String,
"prediction_time": pl.Datetime("us"),
}


def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
    """Coerces a label DataFrame to the MEDS label schema and returns it as an Arrow table.

    The checks are all made against the schema of the *incoming* DataFrame. This function will:
      1. Warn if there is no ``prediction_time`` column.
      2. Cast every mandatory column (see ``MEDS_LABEL_MANDATORY_TYPES``) whose dtype differs from the
         expected one, using a non-strict cast.
      3. Cast every optional column (see ``MEDS_LABEL_OPTIONAL_TYPES``) whose dtype differs, and add any
         optional column that is absent as an all-null column of the expected dtype.
      4. Warn about, and drop, any column that is neither mandatory nor optional.
      5. Put the columns in a fixed order and cast the Arrow table to ``meds.label_schema``.

    Args:
        df: The label DataFrame to validate.

    Returns:
        pa.Table: The label data with columns re-typed, completed and re-ordered as described above.

    Raises:
        ValueError: If a mandatory column is missing from ``df``.
    """

    incoming = df.schema

    if "prediction_time" not in incoming:
        logger.warning(
            "Output DataFrame is missing a 'prediction_time' column. If this is not intentional, add a "
            "'index_timestamp' (yes, it should be different) key to the task configuration identifying "
            "which window's start or end time to use as the prediction time."
        )

    # Mandatory columns cannot be synthesised, so a missing one is collected as an error.
    problems = []
    for name, want in MEDS_LABEL_MANDATORY_TYPES.items():
        if name not in incoming:
            problems.append(f"MEDS Data DataFrame must have a '{name}' column of type {want}.")
            continue
        if incoming[name] != want:
            df = df.with_columns(pl.col(name).cast(want, strict=False))

    if problems:
        raise ValueError("\n".join(problems))

    # Optional columns are either re-typed or filled in with nulls.
    for name, want in MEDS_LABEL_OPTIONAL_TYPES.items():
        if name not in incoming:
            df = df.with_columns(pl.lit(None, dtype=want).alias(name))
        elif incoming[name] != want:
            df = df.with_columns(pl.col(name).cast(want, strict=False))

    unknown = []
    for c in incoming:
        if c in MEDS_LABEL_MANDATORY_TYPES or c in MEDS_LABEL_OPTIONAL_TYPES:
            continue
        unknown.append(c)

    if unknown:
        err_cols_str = "\n".join(f" - {c}" for c in unknown)
        logger.warning(
            "Output contains columns that are not valid MEDS label columns. For now, we are dropping them.\n"
            "If you need these columns, please comment on https://github.com/justin13601/ACES/issues/97\n"
            f"Columns:\n{err_cols_str}"
        )
        df = df.drop(unknown)

    # Fixed output column order.
    ordered = df.select(
        "patient_id",
        "prediction_time",
        "boolean_value",
        "integer_value",
        "float_value",
        "categorical_value",
    )
    arrow_table = ordered.to_arrow()
    return arrow_table.cast(label_schema)


def parse_meds_csvs(
    csvs: str | dict[str, str], schema: dict[str, pl.DataType] = MEDS_PL_SCHEMA
) -> pl.DataFrame | dict[str, pl.DataFrame]:
    """Converts a string or dict of named strings to a MEDS DataFrame by interpreting them as CSVs.

    Each CSV is read with the subset of ``schema`` whose keys appear in the CSV's header line. The
    ``time`` column is always read as a string and then parsed with ``DEFAULT_CSV_TS_FORMAT``.

    Args:
        csvs: Either a single CSV document, or a mapping from a name to a CSV document.
        schema: Mapping from column name to polars dtype. It is copied, never mutated. The dtype given
            for ``"time"`` is the target type of the timestamp parse; if ``schema`` has no ``"time"``
            entry, ``MEDS_PL_SCHEMA["time"]`` is used.

    Returns:
        A single DataFrame if ``csvs`` is a string, otherwise a dict of DataFrames with the same keys as
        ``csvs``.

    TODO: doctests.
    """

    # Read timestamps as raw strings; they are parsed explicitly after the CSV is loaded.
    default_read_schema = {**schema, "time": pl.Utf8}

    # Honour a caller-supplied "time" dtype rather than always using the module default.
    time_dtype = schema.get("time", MEDS_PL_SCHEMA["time"])

    def reader(csv_str: str) -> pl.DataFrame:
        # Only pass dtypes for columns that are actually named in the header line.
        cols = csv_str.strip().split("\n")[0].split(",")
        read_schema = {k: v for k, v in default_read_schema.items() if k in cols}
        return pl.read_csv(StringIO(csv_str), schema=read_schema).with_columns(
            pl.col("time").str.strptime(time_dtype, DEFAULT_CSV_TS_FORMAT)
        )

    if isinstance(csvs, str):
        return reader(csvs)
    return {k: reader(v) for k, v in csvs.items()}


def parse_shards_yaml(yaml_str: str, **schema_updates) -> dict[str, pl.DataFrame]:
    """Parses a YAML document of CSV strings into DataFrames via ``parse_meds_csvs``.

    Args:
        yaml_str: YAML text; the loaded object is handed to ``parse_meds_csvs`` unchanged.
        **schema_updates: Extra or overriding column dtypes layered on top of ``MEDS_PL_SCHEMA``.

    Returns:
        Whatever ``parse_meds_csvs`` returns for the loaded YAML object and the merged schema.
    """
    merged_schema = dict(MEDS_PL_SCHEMA)
    merged_schema.update(schema_updates)
    loaded = load_yaml(yaml_str, Loader=Loader)
    return parse_meds_csvs(loaded, schema=merged_schema)


def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
    """Parses a YAML mapping of CSV strings into schema-validated label DataFrames.

    Each value of the loaded YAML mapping is read as a CSV, its ``prediction_time`` column is parsed
    with ``DEFAULT_CSV_TS_FORMAT`` (the same format ``parse_meds_csvs`` uses), and the result is passed
    through ``get_and_validate_label_schema`` before being converted back to a polars DataFrame.

    Args:
        yaml_str: YAML text that loads to a mapping from a key to a CSV document. Every CSV must have a
            ``prediction_time`` column.

    Returns:
        A dict with the same keys as the YAML mapping and the validated label DataFrames as values.
    """

    def read_labels(csv_str: str) -> pl.DataFrame:
        raw = pl.read_csv(StringIO(csv_str)).with_columns(
            pl.col("prediction_time").str.strptime(pl.Datetime("us"), DEFAULT_CSV_TS_FORMAT)
        )
        return pl.from_arrow(get_and_validate_label_schema(raw))

    return {k: read_labels(v) for k, v in load_yaml(yaml_str, Loader=Loader).items()}


# Data (input)
# Two inline MEDS shards, keyed "0" and "1" in the YAML below, each given as a CSV document. They are
# parsed with parse_shards_yaml, with ``text_value`` added to the schema as a string column, and are
# passed to cli_test as ``input_files`` in test_meds.
# NOTE(review): the time values below use two-digit years (e.g. "12/18/60") while
# DEFAULT_CSV_TS_FORMAT uses "%Y" -- confirm the years are parsed as intended.
# NOTE(review): in shard "0" the "1/20/22 15:18" row precedes the "1/20/22 8:00" rows -- confirm
# whether the out-of-order timestamps are deliberate.
MEDS_SHARDS = parse_shards_yaml(
"""
"0": |-
patient_id,time,code,numeric_value,text_value
1,,GENDER//MALE,,
1,,SNP//rs234567,,
1,12/18/60 11:03,MEDS_BIRTH,,
1,8/2/72 10:00,CLINIC_VISIT,,
1,8/2/72 10:00,ICD9CM//493.90,,
1,8/2/72 10:00,LOINC//8310-5,0.65,
1,8/2/72 10:00,VITALS//BP//SYSTOLIC,108,
1,1/14/20 15:14,ADMISSION//MEDICAL,,
1,1/14/20 15:18,VITALS//BP//SYSTOLIC,132,
1,1/14/20 15:18,VITALS//BP//DIASTOLIC,90,
1,1/14/20 15:18,VITALS//HR//BPM,121,
1,1/14/20 15:18,VITALS//WEIGHT//LBS,233.2,
1,1/15/20 10:04,VITALS//BP//SYSTOLIC,126,
1,1/15/20 10:04,VITALS//BP//DIASTOLIC,91,
1,1/15/20 10:04,VITALS//HR//BPM,85,
1,1/16/20 10:11,VITALS//BP//SYSTOLIC,135,
1,1/16/20 10:11,VITALS//BP//DIASTOLIC,88,
1,1/16/20 10:11,VITALS//HR//BPM,79,
1,1/16/20 13:02,LVEF//ECHO,0.24,
1,1/17/20 10:00,ICD9CM//428.9,,
1,1/17/20 10:00,DISCHARGE//HOME,,
1,1/18/22 14:46,ADMISSION//MEDICAL,,
1,1/20/22 15:18,DISCHARGE//HOME_AMA,,
1,1/20/22 8:00,ICD9CM//428.41,,
1,1/20/22 8:00,ICD9CM//451.1,,
1,1/24/22 8:11,ADMISSION//ED,,
1,1/25/22 10:04,VITALS//BP//SYSTOLIC,168,
1,1/25/22 10:04,VITALS//BP//DIASTOLIC,100,
1,1/25/22 10:04,VITALS//HR//BPM,56,
1,2/27/22 1:13,ICD9CM//428.41,,
1,2/27/22 1:13,ICD9CM//410.1,,
1,2/27/22 1:13,DEATH,,
"1": |-2
patient_id,time,code,numeric_value,text_value
3,,GENDER//FEMALE,,
3,,SNP//rs2345291,,
3,,SNP//rs228192,,
3,2/28/82 0:00,MEDS_BIRTH,,
3,1/14/20 15:14,ADMISSION//MEDICAL,,
3,1/14/20 15:18,VITALS//BP//SYSTOLIC,132,
3,1/14/20 15:18,VITALS//BP//DIASTOLIC,90,
3,1/14/20 15:18,VITALS//HR//BPM,121,
3,1/17/20 10:00,ICD9CM//V30.00,,
3,1/17/20 10:00,DISCHARGE//HOME,,
3,1/18/20 18:18,ADMISSION//MEDICAL,,
3,1/20/22 15:18,DISCHARGE//HOME,,
3,3/18/24 16:54,ICD9CM//428.9,,
3,3/18/24 17:11,ADMISSION//SURGICAL,,
3,3/28/24 10:00,DISCHARGE//HOME,,
3,3/29/24 11:00,ADMISSION//SURGICAL,,
3,4/19/24 13:32,DISCHARGE//HOME,,
3,5/22/24 0:00,ICD9CM//428.9,,
""",
text_value=pl.Utf8,
)

# Tasks (input)
# Maps a task name to that task's configuration, written as a YAML string with ``predicates``,
# ``trigger`` and ``windows`` sections. The strings are not parsed in this module; the dict is passed
# as ``task_configs`` to cli_test in test_meds.
TASKS = {
"inhospital_mortality": """
predicates:
admission:
code: {regex: ADMISSION//.*}
discharge:
code: {regex: DISCHARGE//.*}
death:
code: DEATH
discharge_or_death:
expr: or(discharge, death)
trigger: admission
windows:
input:
start: NULL
end: trigger + 24h
start_inclusive: True
end_inclusive: True
has:
_ANY_EVENT: (5, None)
index_timestamp: end
gap:
start: trigger
end: start + 48h
start_inclusive: False
end_inclusive: True
has:
admission: (None, 0)
discharge_or_death: (None, 0)
target:
start: gap.end
end: start -> discharge_or_death
start_inclusive: False
end_inclusive: True
label: death
""",
"HF_derived_readmission": """
predicates:
admission:
code: {regex: ADMISSION//.*}
discharge:
code: {regex: DISCHARGE//.*}
HF_dx:
code: {regex: ICD9CM//428.*}
trigger: discharge
windows:
admission_is_HF:
start: end <- admission
end: trigger
start_inclusive: True
end_inclusive: True
has:
HF_dx: (1, None)
input:
start: NULL
end: trigger
start_inclusive: True
end_inclusive: True
index_timestamp: end
target:
start: input.end
end: start + 30d
start_inclusive: False
end_inclusive: True
label: admission
censor_protection:
start: target.end
end: null
start_inclusive: False
end_inclusive: True
has:
_ANY_EVENT: (1, None)
""",
}

# Expected outputs: for each task name, a dict from shard key ("0", "1") to the expected label
# DataFrame, built by parse_labels_yaml. Every CSV below currently contains only a header row, so no
# label rows are listed for any task or shard. Passed as ``want_outputs_by_task`` to cli_test in
# test_meds.
WANT_SHARDS = {
"inhospital_mortality": parse_labels_yaml(
"""
"0": |-2
patient_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
"1": |-2
patient_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
"""
),
"HF_derived_readmission": parse_labels_yaml(
"""
"0": |-2
patient_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
"1": |-2
patient_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
"""
),
}


def test_meds():
    """Runs ``cli_test`` on this module's MEDS shards, task configs and expected label outputs."""
    cli_kwargs = {
        "input_files": MEDS_SHARDS,
        "task_configs": TASKS,
        "want_outputs_by_task": WANT_SHARDS,
        "data_standard": "meds",
    }
    cli_test(**cli_kwargs)

0 comments on commit ded40da

Please sign in to comment.