Skip to content

Commit

Permalink
Refactor RandomSamplingPlugin class to only be responsible for one ra…
Browse files Browse the repository at this point in the history
…ndom instance
  • Loading branch information
DanielZhangDP committed Aug 11, 2024
1 parent 7247408 commit 2303ed4
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 62 deletions.
38 changes: 15 additions & 23 deletions cfmtoolbox/plugins/random_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,33 @@
from collections import defaultdict

from cfmtoolbox import app
from cfmtoolbox.models import Cardinality, Feature, FeatureNode
from cfmtoolbox.models import CFM, Cardinality, Feature, FeatureNode


@app.command()
def random_sampling(amount: int = 1) -> list[FeatureNode] | None:
instance = RandomSamplingPlugin()
return instance.random_sampling(amount)
if app.model is None:
print("No model loaded.")
return None

return [RandomSampler(app.model).random_sampling() for _ in range(amount)]

class RandomSamplingPlugin:
def __init__(self):
self.global_feature_count = defaultdict(int)

def random_sampling(self, amount: int = 1) -> list[FeatureNode] | None:
if app.model is None:
print("No model loaded.")
return None
class RandomSampler:
def __init__(self, model: CFM):
self.global_feature_count: defaultdict[str, int] = defaultdict(int)
self.model = model

result_instances = []

global_upper_bound = self.get_global_upper_bound(app.model.features[0])
def random_sampling(self) -> FeatureNode:
global_upper_bound = self.get_global_upper_bound(self.model.features[0])

self.replace_infinite_upper_bound_with_global_upper_bound(
app.model.features[0], global_upper_bound
self.model.features[0], global_upper_bound
)

for i in range(amount):
random_featurenode = self.generate_random_feature_node(
app.model.features[0]
)
result_instances.append(random_featurenode)
print("Instance", random_featurenode)
self.global_feature_count = defaultdict(int)

return result_instances
random_featurenode = self.generate_random_feature_node(self.model.features[0])
print("Instance", random_featurenode)
return random_featurenode

def get_global_upper_bound(self, feature: Feature):
global_upper_bound = feature.instance_cardinality.intervals[-1].upper
Expand Down
80 changes: 41 additions & 39 deletions tests/plugins/test_random_sampling.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
from pathlib import Path

import pytest

import cfmtoolbox.plugins.json_import as json_import_plugin
import cfmtoolbox.plugins.random_sampling as random_sampling_plugin
from cfmtoolbox import app
from cfmtoolbox.models import Cardinality, Feature, Interval
from cfmtoolbox.models import CFM, Cardinality, Feature, Interval
from cfmtoolbox.plugins.random_sampling import (
RandomSamplingPlugin,
RandomSampler,
random_sampling,
)


@pytest.fixture
def model():
return json_import_plugin.import_json(Path("tests/data/sandwich.json").read_bytes())


@pytest.fixture
def random_sampler(model: CFM):
return RandomSampler(model)


def test_plugin_can_be_loaded():
assert random_sampling_plugin in app.load_plugins()

Expand All @@ -31,24 +43,21 @@ def test_random_sampling_with_loaded_model():
assert instance.validate(app.model)


def test_get_random_cardinality():
instance = RandomSamplingPlugin()
def test_get_random_cardinality(random_sampler: RandomSampler):
cardinality = Cardinality([Interval(1, 10), Interval(20, 30), Interval(40, 50)])
assert cardinality.is_valid_cardinality(
instance.get_random_cardinality(cardinality)
random_sampler.get_random_cardinality(cardinality)
)


def test_get_random_cardinality_without_zero():
instance = RandomSamplingPlugin()
def test_get_random_cardinality_without_zero(random_sampler: RandomSampler):
cardinality = Cardinality([Interval(1, 10), Interval(20, 30), Interval(40, 50)])
random_cardinality = instance.get_random_cardinality_without_zero(cardinality)
random_cardinality = random_sampler.get_random_cardinality_without_zero(cardinality)
assert cardinality.is_valid_cardinality(random_cardinality)
assert random_cardinality != 0


def test_get_sorted_sample():
instance = RandomSamplingPlugin()
def test_get_sorted_sample(random_sampler: RandomSampler):
feature_list = [
Feature(
"Cheddar",
Expand All @@ -75,7 +84,7 @@ def test_get_sorted_sample():
[],
),
]
sample = instance.get_sorted_sample(feature_list, 2)
sample = random_sampler.get_sorted_sample(feature_list, 2)
assert len(sample) == 2
assert (
(sample[0].name == "Cheddar" and sample[1].name == "Swiss")
Expand All @@ -84,8 +93,7 @@ def test_get_sorted_sample():
)


def test_get_required_children():
instance = RandomSamplingPlugin()
def test_get_required_children(random_sampler: RandomSampler):
feature = Feature(
"Cheese",
Cardinality([]),
Expand All @@ -103,7 +111,7 @@ def test_get_required_children():
)
],
)
assert len(instance.get_required_children(feature)) == 1
assert len(random_sampler.get_required_children(feature)) == 1
feature = Feature(
"Cheese",
Cardinality([]),
Expand All @@ -129,11 +137,11 @@ def test_get_required_children():
),
],
)
assert len(instance.get_required_children(feature)) == 2
assert len(random_sampler.get_required_children(feature)) == 2
feature = Feature(
"Cheese", Cardinality([]), Cardinality([]), Cardinality([]), [], []
)
assert len(instance.get_required_children(feature)) == 0
assert len(random_sampler.get_required_children(feature)) == 0
feature = Feature(
"Cheese",
Cardinality([]),
Expand All @@ -151,11 +159,10 @@ def test_get_required_children():
)
],
)
assert len(instance.get_required_children(feature)) == 0
assert len(random_sampler.get_required_children(feature)) == 0


def test_get_optional_children():
instance = RandomSamplingPlugin()
def test_get_optional_children(random_sampler: RandomSampler):
feature = Feature(
"Cheese",
Cardinality([]),
Expand All @@ -181,7 +188,7 @@ def test_get_optional_children():
),
],
)
assert len(instance.get_optional_children(feature)) == 1
assert len(random_sampler.get_optional_children(feature)) == 1
feature = Feature(
"Cheese",
Cardinality([]),
Expand All @@ -207,11 +214,11 @@ def test_get_optional_children():
),
],
)
assert len(instance.get_optional_children(feature)) == 2
assert len(random_sampler.get_optional_children(feature)) == 2
feature = Feature(
"Cheese", Cardinality([]), Cardinality([]), Cardinality([]), [], []
)
assert len(instance.get_optional_children(feature)) == 0
assert len(random_sampler.get_optional_children(feature)) == 0
feature = Feature(
"Cheese",
Cardinality([]),
Expand All @@ -229,7 +236,7 @@ def test_get_optional_children():
)
],
)
assert len(instance.get_optional_children(feature)) == 0
assert len(random_sampler.get_optional_children(feature)) == 0
feature = Feature(
"Cheese",
Cardinality([]),
Expand All @@ -255,11 +262,12 @@ def test_get_optional_children():
),
],
)
assert len(instance.get_optional_children(feature)) == 0
assert len(random_sampler.get_optional_children(feature)) == 0


def test_generate_random_children_with_random_cardinality():
instance = RandomSamplingPlugin()
def test_generate_random_children_with_random_cardinality(
random_sampler: RandomSampler,
):
feature = Feature(
"Cheese-mix",
Cardinality([]),
Expand Down Expand Up @@ -294,25 +302,19 @@ def test_generate_random_children_with_random_cardinality():
],
)
children, summed_random_instance_cardinality = (
instance.generate_random_children_with_random_cardinality(feature)
random_sampler.generate_random_children_with_random_cardinality(feature)
)
for child, random_instance_cardinality in children:
assert child.instance_cardinality.is_valid_cardinality(
random_instance_cardinality
)


def test_get_global_upper_bound():
instance = RandomSamplingPlugin()
path = Path("tests/data/sandwich.json")
cfm = json_import_plugin.import_json(path.read_bytes())
feature = cfm.features[0]
assert instance.get_global_upper_bound(feature) == 12
def test_get_global_upper_bound(model: CFM, random_sampler: RandomSampler):
feature = model.features[0]
assert random_sampler.get_global_upper_bound(feature) == 12


def test_generate_feature_node():
instance = RandomSamplingPlugin()
path = Path("tests/data/sandwich.json")
cfm = json_import_plugin.import_json(path.read_bytes())
feature = cfm.features[0]
assert instance.generate_random_feature_node(feature).validate(cfm)
def test_generate_feature_node(model: CFM, random_sampler: RandomSampler):
feature = model.features[0]
assert random_sampler.generate_random_feature_node(feature).validate(model)

0 comments on commit 2303ed4

Please sign in to comment.