diff --git a/cfmtoolbox/plugins/one_wise_sampling.py b/cfmtoolbox/plugins/one_wise_sampling.py new file mode 100644 index 0000000..c1100f6 --- /dev/null +++ b/cfmtoolbox/plugins/one_wise_sampling.py @@ -0,0 +1,161 @@ +import json +import random +from collections import defaultdict +from dataclasses import asdict + +from cfmtoolbox import app +from cfmtoolbox.models import CFM, Cardinality, Feature, FeatureNode + + +@app.command() +def one_wise_sampling( + model: CFM | None, +) -> CFM | None: + if model is None: + print("No model loaded.") + return None + + if model.is_unbound(): + print("Model is unbound. Please apply big-m global bound first.") + return model + + print( + json.dumps( + [asdict(sample) for sample in OneWiseSampler(model).one_wise_sampling()], + indent=2, + ) + ) + + return model + + +# The OneWiseSampler class is responsible for generating one-wise samples under the definitions of Instance-Set, Boundary-Interior Coverage and global constraints +class OneWiseSampler: + def __init__(self, model: CFM): + self.global_feature_count: defaultdict[str, int] = defaultdict(int) + # An assignment describes a feature and the amount of instances it should have + self.assignments: set[tuple[str, int]] = set() + # Covered assignments are all assignments of that appear in a sample and gets filled while generating the sample + self.covered_assignments: set[tuple[str, int]] = set() + # The chosen assignment is the assignment that is currently being used to generate a sample + self.chosen_assignment: tuple[str, int] + self.model = model + + def one_wise_sampling(self) -> list[FeatureNode]: + self.calculate_border_assignments(self.model.features[0]) + + samples = [] + + while self.assignments: + self.chosen_assignment = self.assignments.pop() + samples.append(self.generate_valid_sample()) + self.delete_covered_assignments() + + return samples + + def delete_covered_assignments(self): + for assignment in self.covered_assignments: + self.assignments.discard(assignment) + + def calculate_border_assignments(self, feature: Feature): + for interval in feature.instance_cardinality.intervals: + self.assignments.add((feature.name, interval.lower)) + if interval.upper is not None: + self.assignments.add((feature.name, interval.upper)) + for child in feature.children: + self.calculate_border_assignments(child) + + def generate_valid_sample(self): + while True: + self.global_feature_count = defaultdict(int) + self.covered_assignments = set() + self.covered_assignments.add((self.model.features[0].name, 1)) + random_feature_node = self.generate_random_feature_node_with_assignment( + self.model.features[0] + ) + if ( + random_feature_node.validate(self.model) + and self.chosen_assignment in self.covered_assignments + ): + break + return random_feature_node + + def generate_random_feature_node_with_assignment( + self, + feature: Feature, + ): + feature_node = FeatureNode( + value=f"{feature.name}#{self.global_feature_count[feature.name]}", + children=[], + ) + + self.global_feature_count[feature.name] += 1 + + if not feature.children: + return feature_node + + # Generate until both the group instance and group type cardinalities are valid + while True: + ( + random_children, + summed_random_instance_cardinality, + summed_random_group_type_cardinality, + ) = self.generate_random_children_with_random_cardinality_with_assignment( + feature + ) + if feature.group_instance_cardinality.is_valid_cardinality( + summed_random_instance_cardinality + ) and feature.group_type_cardinality.is_valid_cardinality( + summed_random_group_type_cardinality + ): + break + + for child, random_instance_cardinality in random_children: + # Store already covered assignments while generating for later validation + self.covered_assignments.add((child.name, random_instance_cardinality)) + for _ in range(random_instance_cardinality): + feature_node.children.append( + self.generate_random_feature_node_with_assignment(child) + ) + + return feature_node + + def get_random_cardinality(self, cardinality_list: Cardinality): + random_interval = random.choice(cardinality_list.intervals) + random_cardinality = random.randint( + random_interval.lower, + random_interval.upper + if random_interval.upper is not None + else random_interval.lower + 5, + ) + return random_cardinality + + def generate_random_children_with_random_cardinality_with_assignment( + self, feature: Feature + ): + summed_random_instance_cardinality = 0 + summed_random_group_type_cardinality = 0 + child_with_random_instance_cardinality: list[ + tuple[Feature, int] + ] = [] # List of tuples (child, random_instance_cardinality) + + for child in feature.children: + # Enforces the feature of the chosen assignment to have the chosen amount of instances + if child.name == self.chosen_assignment[0]: + random_instance_cardinality = self.chosen_assignment[1] + else: + random_instance_cardinality = self.get_random_cardinality( + child.instance_cardinality + ) + if random_instance_cardinality != 0: + summed_random_group_type_cardinality += 1 + summed_random_instance_cardinality += random_instance_cardinality + child_with_random_instance_cardinality.append( + (child, random_instance_cardinality) + ) + + return ( + child_with_random_instance_cardinality, + summed_random_instance_cardinality, + summed_random_group_type_cardinality, + ) diff --git a/pyproject.toml b/pyproject.toml index 773c02e..89c4b45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ json-export = "cfmtoolbox.plugins.json_export" featureide-import = "cfmtoolbox.plugins.featureide_import" debugging = "cfmtoolbox.plugins.debugging" big-m = "cfmtoolbox.plugins.big_m" +one-wise-sampling = "cfmtoolbox.plugins.one_wise_sampling" [tool.ruff.lint] extend-select = ["I"] diff --git a/tests/plugins/test_one_wise_sampling.py b/tests/plugins/test_one_wise_sampling.py new file mode 100644 index 0000000..73a7e9d --- /dev/null +++ b/tests/plugins/test_one_wise_sampling.py @@ -0,0 +1,167 @@ +from pathlib import Path + +import pytest + +import cfmtoolbox.plugins.one_wise_sampling as one_wise_sampling_plugin +from cfmtoolbox import app +from cfmtoolbox.models import CFM, Cardinality, Feature, Interval +from cfmtoolbox.plugins.json_import import import_json +from cfmtoolbox.plugins.one_wise_sampling import OneWiseSampler, one_wise_sampling + + +@pytest.fixture +def model(): + return import_json(Path("tests/data/sandwich_bound.json").read_bytes()) + + +@pytest.fixture +def unbound_model(): + return import_json(Path("tests/data/sandwich.json").read_bytes()) + + +@pytest.fixture +def one_wise_sampler(model: CFM): + return OneWiseSampler(model) + + +def test_plugin_can_be_loaded(): + assert one_wise_sampling_plugin in app.load_plugins() + + +def test_one_wise_sampling_without_loaded_model(): + assert one_wise_sampling(None) is None + + +def test_one_wise_sampling_with_unbound_model(unbound_model: CFM, capsys): + one_wise_sampling(unbound_model) is unbound_model + captured = capsys.readouterr() + assert "Model is unbound. Please apply big-m global bound first." in captured.out + + +def test_plugin_passes_though_model(model: CFM): + assert one_wise_sampling(model) is model + + +def test_plugin_outputs_at_least_one_sample(model: CFM, capsys): + one_wise_sampling(model) + captured = capsys.readouterr() + assert captured.out.count("sandwich#0") >= 1 + + +def test_one_wise_sampling_with_loaded_model_every_sample_is_valid(model: CFM): + samples = OneWiseSampler(model).one_wise_sampling() + for feature_node in samples: + assert feature_node.validate(model) + + +def test_delete_covered_assignments(one_wise_sampler: OneWiseSampler): + one_wise_sampler.covered_assignments = {("a", 1), ("b", 2), ("c", 3), ("d", 1)} + one_wise_sampler.assignments = {("a", 1), ("b", 2), ("c", 3), ("d", 4)} + one_wise_sampler.delete_covered_assignments() + assert one_wise_sampler.assignments == {("d", 4)} + + +def test_calculate_border_assignments(one_wise_sampler: OneWiseSampler): + feature = Feature( + "Cheese-mix", + Cardinality([Interval(0, 2), Interval(5, 7), Interval(10, 10)]), + Cardinality([Interval(1, 3)]), + Cardinality([Interval(3, 3)]), + [], + [ + Feature( + "Cheddar", + Cardinality([Interval(0, 1)]), + Cardinality([]), + Cardinality([]), + [], + [], + ), + Feature( + "Swiss", + Cardinality([Interval(0, 2)]), + Cardinality([]), + Cardinality([]), + [], + [], + ), + Feature( + "Gouda", + Cardinality([Interval(0, 3)]), + Cardinality([]), + Cardinality([]), + [], + [], + ), + ], + ) + one_wise_sampler.calculate_border_assignments(feature) + assert one_wise_sampler.assignments == { + ("Cheese-mix", 0), + ("Cheese-mix", 2), + ("Cheese-mix", 5), + ("Cheese-mix", 7), + ("Cheese-mix", 10), + ("Cheddar", 0), + ("Cheddar", 1), + ("Swiss", 0), + ("Swiss", 2), + ("Gouda", 0), + ("Gouda", 3), + } + + +def test_get_random_cardinality(one_wise_sampler: OneWiseSampler): + cardinality = Cardinality([Interval(1, 10), Interval(20, 30), Interval(40, 50)]) + assert cardinality.is_valid_cardinality( + one_wise_sampler.get_random_cardinality(cardinality) + ) + + +def test_generate_random_children_with_random_cardinality_with_assignment( + one_wise_sampler: OneWiseSampler, +): + feature = Feature( + "Cheese-mix", + Cardinality([]), + Cardinality([Interval(1, 3)]), + Cardinality([Interval(3, 3)]), + [], + [ + Feature( + "Cheddar", + Cardinality([Interval(0, 1)]), + Cardinality([]), + Cardinality([]), + [], + [], + ), + Feature( + "Swiss", + Cardinality([Interval(0, 2)]), + Cardinality([]), + Cardinality([]), + [], + [], + ), + Feature( + "Gouda", + Cardinality([Interval(0, 3)]), + Cardinality([]), + Cardinality([]), + [], + [], + ), + ], + ) + one_wise_sampler.chosen_assignment = ("Gouda", 2) + children, _, _ = ( + one_wise_sampler.generate_random_children_with_random_cardinality_with_assignment( + feature + ) + ) + assert children[2][0].name == "Gouda" and children[2][1] == 2 + for child, random_instance_cardinality in children: + assert child.instance_cardinality.is_valid_cardinality( + random_instance_cardinality + ) diff --git a/tests/test_toolbox.py b/tests/test_toolbox.py index 169ea54..4f85869 100644 --- a/tests/test_toolbox.py +++ b/tests/test_toolbox.py @@ -107,4 +107,4 @@ def export_uvl(cfm: CFM): def test_load_plugins(): app = CFMToolbox() plugins = app.load_plugins() - assert len(plugins) == 8 + assert len(plugins) == 9