Skip to content

Commit

Permalink
Make CFMs required in the command interface
Browse files Browse the repository at this point in the history
  • Loading branch information
DoctorJohn committed Sep 15, 2024
1 parent a84884b commit b8e27b1
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 94 deletions.
6 changes: 1 addition & 5 deletions cfmtoolbox/plugins/big_m.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@


@app.command()
def apply_big_m(model: CFM | None) -> CFM | None:
if model is None:
print("No model loaded.")
return None

def apply_big_m(model: CFM) -> CFM:
global_upper_bound = get_global_upper_bound(model.features[0])

replace_infinite_upper_bound_with_global_upper_bound(
Expand Down
2 changes: 1 addition & 1 deletion cfmtoolbox/plugins/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@


@app.command()
def convert(cfm: CFM | None) -> CFM | None:
def convert(cfm: CFM) -> CFM:
print("Converting CFM...")
return cfm
7 changes: 2 additions & 5 deletions cfmtoolbox/plugins/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ def stringify_list(name: str, input: list) -> str:
return f"- {name}: {stringified_input}\n"


def stringify_cfm(cfm: CFM | None) -> str:
def stringify_cfm(cfm: CFM) -> str:
formatted_cfm = "CFM:\n"

if cfm is None:
return formatted_cfm + "None"

for feature in cfm.features:
formatted_cfm += f"{feature}: instance [{feature.instance_cardinality}], group type [{feature.group_type_cardinality}], group instance [{feature.group_instance_cardinality}]\n"
formatted_cfm += f"- parent: {feature.parent}\n"
Expand All @@ -24,7 +21,7 @@ def stringify_cfm(cfm: CFM | None) -> str:


@app.command()
def debug(model: CFM | None) -> CFM | None:
def debug(model: CFM) -> CFM:
print(stringify_cfm(model), end="")

return model
6 changes: 1 addition & 5 deletions cfmtoolbox/plugins/random_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@


@app.command()
def random_sampling(model: CFM | None, amount: int = 1) -> CFM | None:
if model is None:
print("No model loaded.")
return None

def random_sampling(model: CFM, amount: int = 1) -> CFM:
if model.is_unbound():
print("Model is unbound. Please apply big-m global bound first.")
return model
Expand Down
84 changes: 48 additions & 36 deletions cfmtoolbox/toolbox.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
from functools import partial
from importlib.metadata import entry_points
from pathlib import Path
from types import ModuleType
Expand All @@ -12,22 +11,23 @@
)

import typer
from rich.console import Console

from cfmtoolbox.models import CFM

Importer: TypeAlias = Callable[[bytes], CFM]
Exporter: TypeAlias = Callable[[CFM], bytes]
CommandF = TypeVar("CommandF", bound=Callable[[CFM | None], CFM | None])
CommandF = TypeVar("CommandF", bound=Callable[[CFM], CFM])


class CFMToolbox:
def __init__(self) -> None:
self.registered_importers: dict[str, Importer] = {}
self.registered_exporters: dict[str, Exporter] = {}
self.model: CFM | None = None
self.import_path: Path | None = None
self.export_path: Path | None = None
self.typer = typer.Typer(callback=self.prepare, result_callback=self.cleanup)
self.typer = typer.Typer(callback=self.prepare)
self.err_console = Console(stderr=True)

def __call__(self) -> None:
return self.typer()
Expand All @@ -39,30 +39,30 @@ def prepare(
) -> None:
self.import_path = import_path
self.export_path = export_path
self.import_model()

def cleanup(self, *args: tuple[object], **kwargs: dict[str, object]) -> None:
self.export_model()
def import_model(self) -> CFM | None:
if self.import_path is None:
return None

def import_model(self) -> None:
if self.import_path:
importer = self.registered_importers.get(self.import_path.suffix)
importer = self.registered_importers.get(self.import_path.suffix)

if importer is None:
message = f"Unsupported import format: {self.import_path.suffix}"
raise typer.Abort(message)
if importer is None:
message = f"Unsupported import format: {self.import_path.suffix}"
raise typer.Abort(message)

self.model = importer(self.import_path.read_bytes())
return importer(self.import_path.read_bytes())

def export_model(self) -> None:
if self.export_path and self.model:
exporter = self.registered_exporters.get(self.export_path.suffix)
def export_model(self, model: CFM) -> None:
if self.export_path is None:
return

if exporter is None:
message = f"Unsupported export format: {self.export_path.suffix}"
raise typer.Abort(message)
exporter = self.registered_exporters.get(self.export_path.suffix)

self.export_path.write_bytes(exporter(self.model))
if exporter is None:
message = f"Unsupported export format: {self.export_path.suffix}"
raise typer.Abort(message)

self.export_path.write_bytes(exporter(model))

def importer(self, extension: str) -> Callable[[Importer], Importer]:
def decorator(func: Importer) -> Importer:
Expand All @@ -79,21 +79,33 @@ def decorator(func: Exporter) -> Exporter:
return decorator

def command(self, *args, **kwargs) -> Callable[[CommandF], CommandF]:
def decorator(func: CommandF) -> CommandF:
partial_func = partial(func, self.model)

def lazy_partial_func(*args, **kwargs):
self.model = func(self.model, *args, **kwargs)

lazy_partial_func.__name__ = func.__name__
lazy_partial_func.__module__ = func.__module__
lazy_partial_func.__qualname__ = func.__qualname__
lazy_partial_func.__doc__ = func.__doc__
lazy_partial_func.__annotations__ = func.__annotations__
setattr(lazy_partial_func, "__signature__", inspect.signature(partial_func))

self.typer.command(*args, **kwargs)(lazy_partial_func)
return func
def decorator(internal_function: CommandF) -> CommandF:
internal_signature = inspect.signature(internal_function)
internal_params = list(internal_signature.parameters.values())
external_params = internal_params[1:] # Omit the first/CFM parameter
external_signature = internal_signature.replace(parameters=external_params)

def external_function(*args, **kwargs):
original_model = self.import_model()

if original_model is None:
self.err_console.print(
"Please provide a model via the --import option."
)
raise typer.Exit(code=1)

modified_model = internal_function(original_model, *args, **kwargs)
self.export_model(modified_model)

external_function.__name__ = internal_function.__name__
external_function.__module__ = internal_function.__module__
external_function.__qualname__ = internal_function.__qualname__
external_function.__doc__ = internal_function.__doc__
external_function.__annotations__ = internal_function.__annotations__
setattr(external_function, "__signature__", external_signature)

self.typer.command(*args, **kwargs)(external_function)
return internal_function

return decorator

Expand Down
4 changes: 0 additions & 4 deletions tests/plugins/test_big_m.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ def test_plugin_can_be_loaded():
assert big_m in app.load_plugins()


def test_apply_big_m_without_loaded_model():
assert apply_big_m(None) is None


def test_apply_big_m_with_loaded_model(model: CFM):
assert model.is_unbound()
new_model = apply_big_m(model)
Expand Down
4 changes: 0 additions & 4 deletions tests/plugins/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,3 @@ def test_stringify_cfm():
)

assert cfm_str == stringify_cfm(cfm)


def test_stringify_cfm_can_stringify_none_cfm():
assert stringify_cfm(None) == "CFM:\nNone"
5 changes: 0 additions & 5 deletions tests/plugins/test_random_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ def test_plugin_can_be_loaded():
assert random_sampling_plugin in app.load_plugins()


def test_random_sampling_without_loaded_model():
assert random_sampling(None) is None
assert random_sampling(None, 3) is None


def test_random_sampling_with_unbound_model(unbound_model: CFM, capsys):
random_sampling(unbound_model) is unbound_model
captured = capsys.readouterr()
Expand Down
12 changes: 12 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typer.testing import CliRunner

from cfmtoolbox import app

runner = CliRunner(mix_stderr=False)


def test_cli_aborts_when_commands_are_invoked_without_import_option():
result = runner.invoke(app.typer, ["convert"])
assert result.exit_code == 1
assert result.stderr == "Please provide a model via the --import option.\n"
assert not result.stdout
92 changes: 63 additions & 29 deletions tests/test_toolbox.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from pathlib import Path

import pytest
Expand All @@ -7,71 +8,77 @@
from cfmtoolbox.toolbox import CFMToolbox


def test_import_model_without_import_path():
def test_import_model_does_nothing_without_import_path():
app = CFMToolbox()
assert app.import_path is None
assert app.import_model() is None

app.import_model()
assert app.model is None

def test_import_model_reports_unsupported_formats(tmp_path):
import_path = tmp_path / "test.txt"
import_path.touch()

def test_import_model_without_matching_importer(tmp_path: Path):
app = CFMToolbox()
app.import_path = tmp_path / "test.txt"
app.import_path.touch()
app.import_path = import_path

with pytest.raises(typer.Abort, match="Unsupported import format"):
app.import_model()

assert app.model is None


def test_import_model_with_matching_importer(tmp_path: Path):
app = CFMToolbox()
app.import_path = tmp_path / "test.uvl"
app.import_path.touch()
def test_import_model_returns_cfm_of_supported_format(tmp_path: Path):
import_path = tmp_path / "test.uvl"
import_path.touch()

cfm = CFM([], [], [])
assert app.model is not cfm

app = CFMToolbox()
app.import_path = import_path

@app.importer(".uvl")
def import_uvl(data: bytes):
return cfm

app.import_model()
assert app.model is cfm
assert app.import_model() is cfm


def test_export_model_does_nothing_without_export_path():
cfm = CFM([], [], [])

def test_export_model_without_export_path():
app = CFMToolbox()
assert app.export_path is None

app.export_model()
app.export_model(cfm)


def test_export_model_without_matching_exporter(tmp_path: Path):
def test_export_model_reports_unsupported_formats(tmp_path):
cfm = CFM([], [], [])
export_path = tmp_path / "test.txt"

app = CFMToolbox()
app.model = CFM([], [], [])
app.export_path = tmp_path / "test.txt"
app.export_path = export_path

with pytest.raises(typer.Abort, match="Unsupported export format"):
app.export_model()
app.export_model(cfm)

assert not export_path.exists()


def test_export_model_with_matching_exporter(tmp_path: Path):
def test_export_model_stores_exported_model_in_supported_format(tmp_path):
cfm = CFM([], [], [])
export_path = tmp_path / "test.uvl"

app = CFMToolbox()
app.model = CFM([], [], [])
app.export_path = tmp_path / "test.uvl"
app.export_path = export_path

@app.exporter(".uvl")
def export_uvl(cfm: CFM):
return "hello".encode()

app.export_model()
assert app.export_path.read_text() == "hello"
app.export_model(cfm)
assert export_path.read_text() == "hello"


def test_importer_registration():
def test_importer_registers_the_decorated_importer():
app = CFMToolbox()

@app.importer(".uvl")
Expand All @@ -82,7 +89,7 @@ def import_uvl(data: bytes):
assert app.registered_importers[".uvl"] == import_uvl


def test_exporter_registration():
def test_exporter_registers_the_decorated_exporter():
app = CFMToolbox()

@app.exporter(".uvl")
Expand All @@ -93,7 +100,34 @@ def export_uvl(cfm: CFM):
assert app.registered_exporters[".uvl"] == export_uvl


def test_load_plugins():
def test_command_registers_the_decorated_command():
app = CFMToolbox()
assert len(app.typer.registered_commands) == 0

@app.command()
def make_sandwich(cfm: CFM) -> CFM:
return cfm

assert len(app.typer.registered_commands) == 1

command = app.typer.registered_commands[0]
assert getattr(command.callback, "__name__") == "make_sandwich"


def test_command_prevent_typer_from_including_the_cfm_argument_in_the_cli():
app = CFMToolbox()

@app.command()
def make_sandwich(cfm: CFM) -> CFM:
return cfm

command = app.typer.registered_commands[0]
callback = command.callback
assert callback is not None
assert "cfm" not in inspect.signature(callback).parameters


def test_load_plugins_loads_all_core_plugins():
app = CFMToolbox()
plugins = app.load_plugins()
assert len(plugins) == 9

0 comments on commit b8e27b1

Please sign in to comment.