Skip to content

Commit

Permalink
Merge pull request #385 from gyorilab/stratify_controllers
Browse files Browse the repository at this point in the history
Explicitly filter excluded controllers when considering cartesian product
  • Loading branch information
bgyori authored Oct 9, 2024
2 parents b435773 + 4eaf6d9 commit bfb0975
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
9 changes: 8 additions & 1 deletion mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ def stratify(
continue

# Check if we will have any controllers in the template
ncontrollers = num_controllers(template)
controllers = template.get_controllers()
stratified_controllers = [c for c in controllers if c.name
not in exclude_concepts]
ncontrollers = len(stratified_controllers)

# If we have controllers, and we want cartesian control then
# we will stratify controllers separately
stratify_controllers = (ncontrollers > 0) and cartesian_control
Expand Down Expand Up @@ -226,6 +230,9 @@ def stratify(
for c_strata_tuple in itt.product(strata, repeat=ncontrollers):
stratified_template = deepcopy(new_template)
stratified_controllers = stratified_template.get_controllers()
# Filter to make sure we skip controllers that are excluded
stratified_controllers = [c for c in stratified_controllers
if c.name not in exclude_concepts]
template_strata = [stratum if param_renaming_uses_strata_names
else stratum_idx]
# We now apply the stratum assigned to each controller in this particular
Expand Down
8 changes: 7 additions & 1 deletion tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,13 @@ def test_stratify_excluded_species():
cartesian_control=True,
concepts_to_stratify=['susceptible_population'])

assert len(tm.templates) == 5, templates
assert len(tm.templates) == 3, templates
assert tm.templates[0].subject.name == 'susceptible_population_vax'
assert tm.templates[0].outcome.name == 'infected_population'
assert tm.templates[0].controller.name == 'infected_population'
assert tm.templates[1].subject.name == 'susceptible_population_unvax'
assert tm.templates[1].outcome.name == 'infected_population'
assert tm.templates[1].controller.name == 'infected_population'


def test_stratify_parameter_consistency():
Expand Down

0 comments on commit bfb0975

Please sign in to comment.