diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index adec6c75..3a8f2051 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -169,7 +169,11 @@ def stratify( continue # Check if we will have any controllers in the template - ncontrollers = num_controllers(template) + controllers = template.get_controllers() + stratified_controllers = [c for c in controllers if c.name + not in exclude_concepts] + ncontrollers = len(stratified_controllers) + # If we have controllers, and we want cartesian control then # we will stratify controllers separately stratify_controllers = (ncontrollers > 0) and cartesian_control @@ -226,6 +230,9 @@ def stratify( for c_strata_tuple in itt.product(strata, repeat=ncontrollers): stratified_template = deepcopy(new_template) stratified_controllers = stratified_template.get_controllers() + # Filter to make sure we skip controllers that are excluded + stratified_controllers = [c for c in stratified_controllers + if c.name not in exclude_concepts] template_strata = [stratum if param_renaming_uses_strata_names else stratum_idx] # We now apply the stratum assigned to each controller in this particular diff --git a/tests/test_ops.py b/tests/test_ops.py index b46c3ea4..585daab7 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -579,7 +579,13 @@ def test_stratify_excluded_species(): cartesian_control=True, concepts_to_stratify=['susceptible_population']) - assert len(tm.templates) == 5, templates + assert len(tm.templates) == 3, templates + assert tm.templates[0].subject.name == 'susceptible_population_vax' + assert tm.templates[0].outcome.name == 'infected_population' + assert tm.templates[0].controller.name == 'infected_population' + assert tm.templates[1].subject.name == 'susceptible_population_unvax' + assert tm.templates[1].outcome.name == 'infected_population' + assert tm.templates[1].controller.name == 'infected_population' def test_stratify_parameter_consistency():