From b024dbe5586763d825da1741603656c3fb3731f8 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Sun, 15 Sep 2024 05:48:51 +0200 Subject: [PATCH] Make CFMs required in the command interface --- cfmtoolbox/plugins/big_m.py | 6 +- cfmtoolbox/plugins/conversion.py | 2 +- cfmtoolbox/plugins/debugging.py | 7 +- cfmtoolbox/plugins/one_wise_sampling.py | 7 +- cfmtoolbox/plugins/random_sampling.py | 6 +- cfmtoolbox/toolbox.py | 84 ++++++++++++---------- tests/plugins/test_big_m.py | 4 -- tests/plugins/test_debugging.py | 4 -- tests/plugins/test_one_wise_sampling.py | 5 -- tests/plugins/test_random_sampling.py | 5 -- tests/test_cli.py | 12 ++++ tests/test_toolbox.py | 92 +++++++++++++++++-------- 12 files changed, 130 insertions(+), 104 deletions(-) create mode 100644 tests/test_cli.py diff --git a/cfmtoolbox/plugins/big_m.py b/cfmtoolbox/plugins/big_m.py index 5fa45a9..f20e7f6 100644 --- a/cfmtoolbox/plugins/big_m.py +++ b/cfmtoolbox/plugins/big_m.py @@ -3,11 +3,7 @@ @app.command() -def apply_big_m(model: CFM | None) -> CFM | None: - if model is None: - print("No model loaded.") - return None - +def apply_big_m(model: CFM) -> CFM: global_upper_bound = get_global_upper_bound(model.features[0]) replace_infinite_upper_bound_with_global_upper_bound( diff --git a/cfmtoolbox/plugins/conversion.py b/cfmtoolbox/plugins/conversion.py index a483348..5a3fcaf 100644 --- a/cfmtoolbox/plugins/conversion.py +++ b/cfmtoolbox/plugins/conversion.py @@ -2,6 +2,6 @@ @app.command() -def convert(cfm: CFM | None) -> CFM | None: +def convert(cfm: CFM) -> CFM: print("Converting CFM...") return cfm diff --git a/cfmtoolbox/plugins/debugging.py b/cfmtoolbox/plugins/debugging.py index a8bdaa7..c69bd29 100644 --- a/cfmtoolbox/plugins/debugging.py +++ b/cfmtoolbox/plugins/debugging.py @@ -6,12 +6,9 @@ def stringify_list(name: str, input: list) -> str: return f"- {name}: {stringified_input}\n" -def stringify_cfm(cfm: CFM | None) -> str: +def stringify_cfm(cfm: CFM) -> str: formatted_cfm = "CFM:\n" - if cfm is None: - return formatted_cfm + "None" - for feature in cfm.features: formatted_cfm += f"{feature}: instance [{feature.instance_cardinality}], group type [{feature.group_type_cardinality}], group instance [{feature.group_instance_cardinality}]\n" formatted_cfm += f"- parent: {feature.parent}\n" @@ -24,7 +21,7 @@ def stringify_cfm(cfm: CFM | None) -> str: @app.command() -def debug(model: CFM | None) -> CFM | None: +def debug(model: CFM) -> CFM: print(stringify_cfm(model), end="") return model diff --git a/cfmtoolbox/plugins/one_wise_sampling.py b/cfmtoolbox/plugins/one_wise_sampling.py index a01d53b..7ff9644 100644 --- a/cfmtoolbox/plugins/one_wise_sampling.py +++ b/cfmtoolbox/plugins/one_wise_sampling.py @@ -12,11 +12,8 @@ @app.command() def one_wise_sampling( - model: CFM | None, -) -> CFM | None: - if model is None: - raise typer.Abort("No model loaded.") - + model: CFM, +) -> CFM: if model.is_unbound(): raise typer.Abort("Model is unbound. Please apply big-m global bound first.") diff --git a/cfmtoolbox/plugins/random_sampling.py b/cfmtoolbox/plugins/random_sampling.py index 8b7e63a..df2afb2 100644 --- a/cfmtoolbox/plugins/random_sampling.py +++ b/cfmtoolbox/plugins/random_sampling.py @@ -8,11 +8,7 @@ @app.command() -def random_sampling(model: CFM | None, amount: int = 1) -> CFM | None: - if model is None: - print("No model loaded.") - return None - +def random_sampling(model: CFM, amount: int = 1) -> CFM: if model.is_unbound(): print("Model is unbound. Please apply big-m global bound first.") return model diff --git a/cfmtoolbox/toolbox.py b/cfmtoolbox/toolbox.py index 9e5d726..8cb0ed1 100644 --- a/cfmtoolbox/toolbox.py +++ b/cfmtoolbox/toolbox.py @@ -1,5 +1,4 @@ import inspect -from functools import partial from importlib.metadata import entry_points from pathlib import Path from types import ModuleType @@ -12,22 +11,23 @@ ) import typer +from rich.console import Console from cfmtoolbox.models import CFM Importer: TypeAlias = Callable[[bytes], CFM] Exporter: TypeAlias = Callable[[CFM], bytes] -CommandF = TypeVar("CommandF", bound=Callable[[CFM | None], CFM | None]) +CommandF = TypeVar("CommandF", bound=Callable[[CFM], CFM]) class CFMToolbox: def __init__(self) -> None: self.registered_importers: dict[str, Importer] = {} self.registered_exporters: dict[str, Exporter] = {} - self.model: CFM | None = None self.import_path: Path | None = None self.export_path: Path | None = None - self.typer = typer.Typer(callback=self.prepare, result_callback=self.cleanup) + self.typer = typer.Typer(callback=self.prepare) + self.err_console = Console(stderr=True) def __call__(self) -> None: return self.typer() @@ -39,30 +39,30 @@ def prepare( ) -> None: self.import_path = import_path self.export_path = export_path - self.import_model() - def cleanup(self, *args: tuple[object], **kwargs: dict[str, object]) -> None: - self.export_model() + def import_model(self) -> CFM | None: + if self.import_path is None: + return None - def import_model(self) -> None: - if self.import_path: - importer = self.registered_importers.get(self.import_path.suffix) + importer = self.registered_importers.get(self.import_path.suffix) - if importer is None: - message = f"Unsupported import format: {self.import_path.suffix}" - raise typer.Abort(message) + if importer is None: + message = f"Unsupported import format: {self.import_path.suffix}" + raise typer.Abort(message) - self.model = importer(self.import_path.read_bytes()) + return importer(self.import_path.read_bytes()) - def export_model(self) -> None: - if self.export_path and self.model: - exporter = self.registered_exporters.get(self.export_path.suffix) + def export_model(self, model: CFM) -> None: + if self.export_path is None: + return - if exporter is None: - message = f"Unsupported export format: {self.export_path.suffix}" - raise typer.Abort(message) + exporter = self.registered_exporters.get(self.export_path.suffix) - self.export_path.write_bytes(exporter(self.model)) + if exporter is None: + message = f"Unsupported export format: {self.export_path.suffix}" + raise typer.Abort(message) + + self.export_path.write_bytes(exporter(model)) def importer(self, extension: str) -> Callable[[Importer], Importer]: def decorator(func: Importer) -> Importer: @@ -79,21 +79,33 @@ def decorator(func: Exporter) -> Exporter: return decorator def command(self, *args, **kwargs) -> Callable[[CommandF], CommandF]: - def decorator(func: CommandF) -> CommandF: - partial_func = partial(func, self.model) - - def lazy_partial_func(*args, **kwargs): - self.model = func(self.model, *args, **kwargs) - - lazy_partial_func.__name__ = func.__name__ - lazy_partial_func.__module__ = func.__module__ - lazy_partial_func.__qualname__ = func.__qualname__ - lazy_partial_func.__doc__ = func.__doc__ - lazy_partial_func.__annotations__ = func.__annotations__ - setattr(lazy_partial_func, "__signature__", inspect.signature(partial_func)) - - self.typer.command(*args, **kwargs)(lazy_partial_func) - return func + def decorator(internal_function: CommandF) -> CommandF: + internal_signature = inspect.signature(internal_function) + internal_params = list(internal_signature.parameters.values()) + external_params = internal_params[1:] # Omit the first/CFM parameter + external_signature = internal_signature.replace(parameters=external_params) + + def external_function(*args, **kwargs): + original_model = self.import_model() + + if original_model is None: + self.err_console.print( + "Please provide a model via the --import option." + ) + raise typer.Exit(code=1) + + modified_model = internal_function(original_model, *args, **kwargs) + self.export_model(modified_model) + + external_function.__name__ = internal_function.__name__ + external_function.__module__ = internal_function.__module__ + external_function.__qualname__ = internal_function.__qualname__ + external_function.__doc__ = internal_function.__doc__ + external_function.__annotations__ = internal_function.__annotations__ + setattr(external_function, "__signature__", external_signature) + + self.typer.command(*args, **kwargs)(external_function) + return internal_function return decorator diff --git a/tests/plugins/test_big_m.py b/tests/plugins/test_big_m.py index bb3f941..4eb7c0b 100644 --- a/tests/plugins/test_big_m.py +++ b/tests/plugins/test_big_m.py @@ -18,10 +18,6 @@ def test_plugin_can_be_loaded(): assert big_m in app.load_plugins() -def test_apply_big_m_without_loaded_model(): - assert apply_big_m(None) is None - - def test_apply_big_m_with_loaded_model(model: CFM): assert model.is_unbound() new_model = apply_big_m(model) diff --git a/tests/plugins/test_debugging.py b/tests/plugins/test_debugging.py index 00b04b8..8fb91a4 100644 --- a/tests/plugins/test_debugging.py +++ b/tests/plugins/test_debugging.py @@ -114,7 +114,3 @@ def test_stringify_cfm(): ) assert cfm_str == stringify_cfm(cfm) - - -def test_stringify_cfm_can_stringify_none_cfm(): - assert stringify_cfm(None) == "CFM:\nNone" diff --git a/tests/plugins/test_one_wise_sampling.py b/tests/plugins/test_one_wise_sampling.py index 80c84a0..9ea4598 100644 --- a/tests/plugins/test_one_wise_sampling.py +++ b/tests/plugins/test_one_wise_sampling.py @@ -29,11 +29,6 @@ def test_plugin_can_be_loaded(): assert one_wise_sampling_plugin in app.load_plugins() -def test_one_wise_sampling_without_loaded_model(): - with pytest.raises(typer.Abort, match="No model loaded."): - one_wise_sampling(None) - - def test_one_wise_sampling_with_unbound_model(unbound_model: CFM, capsys): with pytest.raises( typer.Abort, match="Model is unbound. Please apply big-m global bound first." diff --git a/tests/plugins/test_random_sampling.py b/tests/plugins/test_random_sampling.py index 04ecc1a..7d778a0 100644 --- a/tests/plugins/test_random_sampling.py +++ b/tests/plugins/test_random_sampling.py @@ -31,11 +31,6 @@ def test_plugin_can_be_loaded(): assert random_sampling_plugin in app.load_plugins() -def test_random_sampling_without_loaded_model(): - assert random_sampling(None) is None - assert random_sampling(None, 3) is None - - def test_random_sampling_with_unbound_model(unbound_model: CFM, capsys): random_sampling(unbound_model) is unbound_model captured = capsys.readouterr() diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..536e564 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,12 @@ +from typer.testing import CliRunner + +from cfmtoolbox import app + +runner = CliRunner(mix_stderr=False) + + +def test_cli_aborts_when_commands_are_invoked_without_import_option(): + result = runner.invoke(app.typer, ["convert"]) + assert result.exit_code == 1 + assert result.stderr == "Please provide a model via the --import option.\n" + assert not result.stdout diff --git a/tests/test_toolbox.py b/tests/test_toolbox.py index f7f27d0..6630a22 100644 --- a/tests/test_toolbox.py +++ b/tests/test_toolbox.py @@ -1,3 +1,4 @@ +import inspect from pathlib import Path import pytest @@ -7,71 +8,77 @@ from cfmtoolbox.toolbox import CFMToolbox -def test_import_model_without_import_path(): +def test_import_model_does_nothing_without_import_path(): app = CFMToolbox() assert app.import_path is None + assert app.import_model() is None - app.import_model() - assert app.model is None +def test_import_model_reports_unsupported_formats(tmp_path): + import_path = tmp_path / "test.txt" + import_path.touch() -def test_import_model_without_matching_importer(tmp_path: Path): app = CFMToolbox() - app.import_path = tmp_path / "test.txt" - app.import_path.touch() + app.import_path = import_path with pytest.raises(typer.Abort, match="Unsupported import format"): app.import_model() - assert app.model is None - -def test_import_model_with_matching_importer(tmp_path: Path): - app = CFMToolbox() - app.import_path = tmp_path / "test.uvl" - app.import_path.touch() +def test_import_model_returns_cfm_of_supported_format(tmp_path: Path): + import_path = tmp_path / "test.uvl" + import_path.touch() cfm = CFM([], [], []) - assert app.model is not cfm + + app = CFMToolbox() + app.import_path = import_path @app.importer(".uvl") def import_uvl(data: bytes): return cfm - app.import_model() - assert app.model is cfm + assert app.import_model() is cfm + +def test_export_model_does_nothing_without_export_path(): + cfm = CFM([], [], []) -def test_export_model_without_export_path(): app = CFMToolbox() assert app.export_path is None - app.export_model() + app.export_model(cfm) -def test_export_model_without_matching_exporter(tmp_path: Path): +def test_export_model_reports_unsupported_formats(tmp_path): + cfm = CFM([], [], []) + export_path = tmp_path / "test.txt" + app = CFMToolbox() - app.model = CFM([], [], []) - app.export_path = tmp_path / "test.txt" + app.export_path = export_path with pytest.raises(typer.Abort, match="Unsupported export format"): - app.export_model() + app.export_model(cfm) + + assert not export_path.exists() -def test_export_model_with_matching_exporter(tmp_path: Path): +def test_export_model_stores_exported_model_in_supported_format(tmp_path): + cfm = CFM([], [], []) + export_path = tmp_path / "test.uvl" + app = CFMToolbox() - app.model = CFM([], [], []) - app.export_path = tmp_path / "test.uvl" + app.export_path = export_path @app.exporter(".uvl") def export_uvl(cfm: CFM): return "hello".encode() - app.export_model() - assert app.export_path.read_text() == "hello" + app.export_model(cfm) + assert export_path.read_text() == "hello" -def test_importer_registration(): +def test_importer_registers_the_decorated_importer(): app = CFMToolbox() @app.importer(".uvl") @@ -82,7 +89,7 @@ def import_uvl(data: bytes): assert app.registered_importers[".uvl"] == import_uvl -def test_exporter_registration(): +def test_exporter_registers_the_decorated_exporter(): app = CFMToolbox() @app.exporter(".uvl") @@ -93,7 +100,34 @@ def export_uvl(cfm: CFM): assert app.registered_exporters[".uvl"] == export_uvl -def test_load_plugins(): +def test_command_registers_the_decorated_command(): + app = CFMToolbox() + assert len(app.typer.registered_commands) == 0 + + @app.command() + def make_sandwich(cfm: CFM) -> CFM: + return cfm + + assert len(app.typer.registered_commands) == 1 + + command = app.typer.registered_commands[0] + assert getattr(command.callback, "__name__") == "make_sandwich" + + +def test_command_prevent_typer_from_including_the_cfm_argument_in_the_cli(): + app = CFMToolbox() + + @app.command() + def make_sandwich(cfm: CFM) -> CFM: + return cfm + + command = app.typer.registered_commands[0] + callback = command.callback + assert callback is not None + assert "cfm" not in inspect.signature(callback).parameters + + +def test_load_plugins_loads_all_core_plugins(): app = CFMToolbox() plugins = app.load_plugins() assert len(plugins) == 10