diff --git a/cfmtoolbox/plugins/random_sampling.py b/cfmtoolbox/plugins/random_sampling.py index 4f93e23..3ab2868 100644 --- a/cfmtoolbox/plugins/random_sampling.py +++ b/cfmtoolbox/plugins/random_sampling.py @@ -2,41 +2,33 @@ from collections import defaultdict from cfmtoolbox import app -from cfmtoolbox.models import Cardinality, Feature, FeatureNode +from cfmtoolbox.models import CFM, Cardinality, Feature, FeatureNode @app.command() def random_sampling(amount: int = 1) -> list[FeatureNode] | None: - instance = RandomSamplingPlugin() - return instance.random_sampling(amount) + if app.model is None: + print("No model loaded.") + return None + return [RandomSampler(app.model).random_sampling() for _ in range(amount)] -class RandomSamplingPlugin: - def __init__(self): - self.global_feature_count = defaultdict(int) - def random_sampling(self, amount: int = 1) -> list[FeatureNode] | None: - if app.model is None: - print("No model loaded.") - return None +class RandomSampler: + def __init__(self, model: CFM): + self.global_feature_count: defaultdict[str, int] = defaultdict(int) + self.model = model - result_instances = [] - - global_upper_bound = self.get_global_upper_bound(app.model.features[0]) + def random_sampling(self) -> FeatureNode: + global_upper_bound = self.get_global_upper_bound(self.model.features[0]) self.replace_infinite_upper_bound_with_global_upper_bound( - app.model.features[0], global_upper_bound + self.model.features[0], global_upper_bound ) - for i in range(amount): - random_featurenode = self.generate_random_feature_node( - app.model.features[0] - ) - result_instances.append(random_featurenode) - print("Instance", random_featurenode) - self.global_feature_count = defaultdict(int) - - return result_instances + random_featurenode = self.generate_random_feature_node(self.model.features[0]) + print("Instance", random_featurenode) + return random_featurenode def get_global_upper_bound(self, feature: Feature): global_upper_bound = feature.instance_cardinality.intervals[-1].upper diff --git a/tests/plugins/test_random_sampling.py b/tests/plugins/test_random_sampling.py index 84c2a04..ac4e097 100644 --- a/tests/plugins/test_random_sampling.py +++ b/tests/plugins/test_random_sampling.py @@ -1,15 +1,27 @@ from pathlib import Path +import pytest + import cfmtoolbox.plugins.json_import as json_import_plugin import cfmtoolbox.plugins.random_sampling as random_sampling_plugin from cfmtoolbox import app -from cfmtoolbox.models import Cardinality, Feature, Interval +from cfmtoolbox.models import CFM, Cardinality, Feature, Interval from cfmtoolbox.plugins.random_sampling import ( - RandomSamplingPlugin, + RandomSampler, random_sampling, ) +@pytest.fixture +def model(): + return json_import_plugin.import_json(Path("tests/data/sandwich.json").read_bytes()) + + +@pytest.fixture +def random_sampler(model: CFM): + return RandomSampler(model) + + def test_plugin_can_be_loaded(): assert random_sampling_plugin in app.load_plugins() @@ -31,24 +43,21 @@ def test_random_sampling_with_loaded_model(): assert instance.validate(app.model) -def test_get_random_cardinality(): - instance = RandomSamplingPlugin() +def test_get_random_cardinality(random_sampler: RandomSampler): cardinality = Cardinality([Interval(1, 10), Interval(20, 30), Interval(40, 50)]) assert cardinality.is_valid_cardinality( - instance.get_random_cardinality(cardinality) + random_sampler.get_random_cardinality(cardinality) ) -def test_get_random_cardinality_without_zero(): - instance = RandomSamplingPlugin() +def test_get_random_cardinality_without_zero(random_sampler: RandomSampler): cardinality = Cardinality([Interval(1, 10), Interval(20, 30), Interval(40, 50)]) - random_cardinality = instance.get_random_cardinality_without_zero(cardinality) + random_cardinality = random_sampler.get_random_cardinality_without_zero(cardinality) assert cardinality.is_valid_cardinality(random_cardinality) assert random_cardinality != 0 -def test_get_sorted_sample(): - instance = RandomSamplingPlugin() +def test_get_sorted_sample(random_sampler: RandomSampler): feature_list = [ Feature( "Cheddar", @@ -75,7 +84,7 @@ def test_get_sorted_sample(): [], ), ] - sample = instance.get_sorted_sample(feature_list, 2) + sample = random_sampler.get_sorted_sample(feature_list, 2) assert len(sample) == 2 assert ( (sample[0].name == "Cheddar" and sample[1].name == "Swiss") @@ -84,8 +93,7 @@ def test_get_sorted_sample(): ) -def test_get_required_children(): - instance = RandomSamplingPlugin() +def test_get_required_children(random_sampler: RandomSampler): feature = Feature( "Cheese", Cardinality([]), @@ -103,7 +111,7 @@ def test_get_required_children(): ) ], ) - assert len(instance.get_required_children(feature)) == 1 + assert len(random_sampler.get_required_children(feature)) == 1 feature = Feature( "Cheese", Cardinality([]), @@ -129,11 +137,11 @@ def test_get_required_children(): ), ], ) - assert len(instance.get_required_children(feature)) == 2 + assert len(random_sampler.get_required_children(feature)) == 2 feature = Feature( "Cheese", Cardinality([]), Cardinality([]), Cardinality([]), [], [] ) - assert len(instance.get_required_children(feature)) == 0 + assert len(random_sampler.get_required_children(feature)) == 0 feature = Feature( "Cheese", Cardinality([]), @@ -151,11 +159,10 @@ def test_get_required_children(): ) ], ) - assert len(instance.get_required_children(feature)) == 0 + assert len(random_sampler.get_required_children(feature)) == 0 -def test_get_optional_children(): - instance = RandomSamplingPlugin() +def test_get_optional_children(random_sampler: RandomSampler): feature = Feature( "Cheese", Cardinality([]), @@ -181,7 +188,7 @@ def test_get_optional_children(): ), ], ) - assert len(instance.get_optional_children(feature)) == 1 + assert len(random_sampler.get_optional_children(feature)) == 1 feature = Feature( "Cheese", Cardinality([]), @@ -207,11 +214,11 @@ def test_get_optional_children(): ), ], ) - assert len(instance.get_optional_children(feature)) == 2 + assert len(random_sampler.get_optional_children(feature)) == 2 feature = Feature( "Cheese", Cardinality([]), Cardinality([]), Cardinality([]), [], [] ) - assert len(instance.get_optional_children(feature)) == 0 + assert len(random_sampler.get_optional_children(feature)) == 0 feature = Feature( "Cheese", Cardinality([]), @@ -229,7 +236,7 @@ def test_get_optional_children(): ) ], ) - assert len(instance.get_optional_children(feature)) == 0 + assert len(random_sampler.get_optional_children(feature)) == 0 feature = Feature( "Cheese", Cardinality([]), @@ -255,11 +262,12 @@ def test_get_optional_children(): ), ], ) - assert len(instance.get_optional_children(feature)) == 0 + assert len(random_sampler.get_optional_children(feature)) == 0 -def test_generate_random_children_with_random_cardinality(): - instance = RandomSamplingPlugin() +def test_generate_random_children_with_random_cardinality( + random_sampler: RandomSampler, +): feature = Feature( "Cheese-mix", Cardinality([]), @@ -294,7 +302,7 @@ def test_generate_random_children_with_random_cardinality(): ], ) children, summed_random_instance_cardinality = ( - instance.generate_random_children_with_random_cardinality(feature) + random_sampler.generate_random_children_with_random_cardinality(feature) ) for child, random_instance_cardinality in children: assert child.instance_cardinality.is_valid_cardinality( @@ -302,17 +310,11 @@ def test_generate_random_children_with_random_cardinality(): ) -def test_get_global_upper_bound(): - instance = RandomSamplingPlugin() - path = Path("tests/data/sandwich.json") - cfm = json_import_plugin.import_json(path.read_bytes()) - feature = cfm.features[0] - assert instance.get_global_upper_bound(feature) == 12 +def test_get_global_upper_bound(model: CFM, random_sampler: RandomSampler): + feature = model.features[0] + assert random_sampler.get_global_upper_bound(feature) == 12 -def test_generate_feature_node(): - instance = RandomSamplingPlugin() - path = Path("tests/data/sandwich.json") - cfm = json_import_plugin.import_json(path.read_bytes()) - feature = cfm.features[0] - assert instance.generate_random_feature_node(feature).validate(cfm) +def test_generate_feature_node(model: CFM, random_sampler: RandomSampler): + feature = model.features[0] + assert random_sampler.generate_random_feature_node(feature).validate(model)