Skip to content

Commit

Permalink
#94 nested derived predicates (ex: needed when creating different ref…
Browse files Browse the repository at this point in the history
…erence ranges for male/female)
  • Loading branch information
justin13601 committed Oct 26, 2024
1 parent 13eb15d commit f5b0dbc
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions src/aces/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta

final_predicates = {**predicates, **overriding_predicates}
final_demographics = {**patient_demographics, **overriding_demographics}
all_predicates = {**final_predicates, **final_demographics}

logger.info("Parsing windows...")
if windows is None:
Expand All @@ -1288,23 +1289,45 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
logger.info("Parsing trigger event...")
trigger = EventConfig(trigger)

# add window referenced predicates
referenced_predicates = {pred for w in windows.values() for pred in w.referenced_predicates}

# add trigger predicate
referenced_predicates.add(trigger.predicate)

# add label predicate if it exists and not already added
label_reference = [w.label for w in windows.values() if w.label]
if label_reference:
referenced_predicates.update(set(label_reference))
current_predicates = set(referenced_predicates)

special_predicates = {ANY_EVENT_COLUMN, START_OF_RECORD_KEY, END_OF_RECORD_KEY}
for pred in current_predicates - special_predicates:
if pred not in final_predicates:
for pred in set(referenced_predicates) - special_predicates:
if pred not in all_predicates:
raise KeyError(
f"Something referenced predicate {pred} that wasn't defined in the configuration."
)
if "expr" in final_predicates[pred]:
referenced_predicates.update(
DerivedPredicateConfig(**final_predicates[pred]).input_predicates
f"Something referenced predicate '{pred}' that wasn't defined in the configuration."
)

if "expr" in all_predicates[pred]:
stack = list(DerivedPredicateConfig(**all_predicates[pred]).input_predicates)

Check warning on line 1311 in src/aces/config.py

View check run for this annotation

Codecov / codecov/patch

src/aces/config.py#L1311

Added line #L1311 was not covered by tests

while stack:
nested_pred = stack.pop()

Check warning on line 1314 in src/aces/config.py

View check run for this annotation

Codecov / codecov/patch

src/aces/config.py#L1313-L1314

Added lines #L1313 - L1314 were not covered by tests

if nested_pred not in all_predicates:
raise KeyError(

Check warning on line 1317 in src/aces/config.py

View check run for this annotation

Codecov / codecov/patch

src/aces/config.py#L1316-L1317

Added lines #L1316 - L1317 were not covered by tests
f"Predicate '{nested_pred}' referenced in '{pred}' is not defined in the "
"configuration."
)

# if nested_pred is a DerivedPredicateConfig, unpack input_predicates and add to stack
if "expr" in all_predicates[nested_pred]:
derived_config = DerivedPredicateConfig(**all_predicates[nested_pred])
stack.extend(derived_config.input_predicates)
referenced_predicates.add(nested_pred) # also add itself to referenced_predicates

Check warning on line 1326 in src/aces/config.py

View check run for this annotation

Codecov / codecov/patch

src/aces/config.py#L1323-L1326

Added lines #L1323 - L1326 were not covered by tests
else:
# if nested_pred is a PlainPredicateConfig, only add it to referenced_predicates
referenced_predicates.add(nested_pred)

Check warning on line 1329 in src/aces/config.py

View check run for this annotation

Codecov / codecov/patch

src/aces/config.py#L1329

Added line #L1329 was not covered by tests

logger.info("Parsing predicates...")
predicates_to_parse = {k: v for k, v in final_predicates.items() if k in referenced_predicates}
predicate_objs = {}
Expand Down

0 comments on commit f5b0dbc

Please sign in to comment.