Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make CFMs required in the command interface #35

Merged
merged 1 commit into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions cfmtoolbox/plugins/big_m.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@


@app.command()
def apply_big_m(model: CFM | None) -> CFM | None:
if model is None:
print("No model loaded.")
return None

def apply_big_m(model: CFM) -> CFM:
global_upper_bound = get_global_upper_bound(model.features[0])

replace_infinite_upper_bound_with_global_upper_bound(
Expand Down
2 changes: 1 addition & 1 deletion cfmtoolbox/plugins/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@


@app.command()
def convert(cfm: CFM | None) -> CFM | None:
def convert(cfm: CFM) -> CFM:
print("Converting CFM...")
return cfm
7 changes: 2 additions & 5 deletions cfmtoolbox/plugins/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ def stringify_list(name: str, input: list) -> str:
return f"- {name}: {stringified_input}\n"


def stringify_cfm(cfm: CFM | None) -> str:
def stringify_cfm(cfm: CFM) -> str:
formatted_cfm = "CFM:\n"

if cfm is None:
return formatted_cfm + "None"

for feature in cfm.features:
formatted_cfm += f"{feature}: instance [{feature.instance_cardinality}], group type [{feature.group_type_cardinality}], group instance [{feature.group_instance_cardinality}]\n"
formatted_cfm += f"- parent: {feature.parent}\n"
Expand All @@ -24,7 +21,7 @@ def stringify_cfm(cfm: CFM | None) -> str:


@app.command()
def debug(model: CFM | None) -> CFM | None:
def debug(model: CFM) -> CFM:
print(stringify_cfm(model), end="")

return model
7 changes: 1 addition & 6 deletions cfmtoolbox/plugins/one_wise_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@


@app.command()
def one_wise_sampling(
model: CFM | None,
) -> CFM | None:
if model is None:
raise typer.Abort("No model loaded.")

def one_wise_sampling(model: CFM) -> CFM:
if model.is_unbound():
raise typer.Abort("Model is unbound. Please apply big-m global bound first.")

Expand Down
6 changes: 1 addition & 5 deletions cfmtoolbox/plugins/random_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@


@app.command()
def random_sampling(model: CFM | None, amount: int = 1) -> CFM | None:
if model is None:
print("No model loaded.")
return None

def random_sampling(model: CFM, amount: int = 1) -> CFM:
if model.is_unbound():
print("Model is unbound. Please apply big-m global bound first.")
return model
Expand Down
84 changes: 48 additions & 36 deletions cfmtoolbox/toolbox.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
from functools import partial
from importlib.metadata import entry_points
from pathlib import Path
from types import ModuleType
Expand All @@ -12,22 +11,23 @@
)

import typer
from rich.console import Console

from cfmtoolbox.models import CFM

Importer: TypeAlias = Callable[[bytes], CFM]
Exporter: TypeAlias = Callable[[CFM], bytes]
CommandF = TypeVar("CommandF", bound=Callable[[CFM | None], CFM | None])
CommandF = TypeVar("CommandF", bound=Callable[[CFM], CFM])


class CFMToolbox:
def __init__(self) -> None:
self.registered_importers: dict[str, Importer] = {}
self.registered_exporters: dict[str, Exporter] = {}
self.model: CFM | None = None
self.import_path: Path | None = None
self.export_path: Path | None = None
self.typer = typer.Typer(callback=self.prepare, result_callback=self.cleanup)
self.typer = typer.Typer(callback=self.prepare)
self.err_console = Console(stderr=True)

def __call__(self) -> None:
return self.typer()
Expand All @@ -39,30 +39,30 @@ def prepare(
) -> None:
self.import_path = import_path
self.export_path = export_path
self.import_model()

def cleanup(self, *args: tuple[object], **kwargs: dict[str, object]) -> None:
self.export_model()
def import_model(self) -> CFM | None:
if self.import_path is None:
return None

def import_model(self) -> None:
if self.import_path:
importer = self.registered_importers.get(self.import_path.suffix)
importer = self.registered_importers.get(self.import_path.suffix)

if importer is None:
message = f"Unsupported import format: {self.import_path.suffix}"
raise typer.Abort(message)
if importer is None:
message = f"Unsupported import format: {self.import_path.suffix}"
raise typer.Abort(message)

self.model = importer(self.import_path.read_bytes())
return importer(self.import_path.read_bytes())

def export_model(self) -> None:
if self.export_path and self.model:
exporter = self.registered_exporters.get(self.export_path.suffix)
def export_model(self, model: CFM) -> None:
if self.export_path is None:
return

if exporter is None:
message = f"Unsupported export format: {self.export_path.suffix}"
raise typer.Abort(message)
exporter = self.registered_exporters.get(self.export_path.suffix)

self.export_path.write_bytes(exporter(self.model))
if exporter is None:
message = f"Unsupported export format: {self.export_path.suffix}"
raise typer.Abort(message)

self.export_path.write_bytes(exporter(model))

def importer(self, extension: str) -> Callable[[Importer], Importer]:
def decorator(func: Importer) -> Importer:
Expand All @@ -79,21 +79,33 @@ def decorator(func: Exporter) -> Exporter:
return decorator

def command(self, *args, **kwargs) -> Callable[[CommandF], CommandF]:
def decorator(func: CommandF) -> CommandF:
partial_func = partial(func, self.model)

def lazy_partial_func(*args, **kwargs):
self.model = func(self.model, *args, **kwargs)

lazy_partial_func.__name__ = func.__name__
lazy_partial_func.__module__ = func.__module__
lazy_partial_func.__qualname__ = func.__qualname__
lazy_partial_func.__doc__ = func.__doc__
lazy_partial_func.__annotations__ = func.__annotations__
setattr(lazy_partial_func, "__signature__", inspect.signature(partial_func))

self.typer.command(*args, **kwargs)(lazy_partial_func)
return func
def decorator(internal_function: CommandF) -> CommandF:
internal_signature = inspect.signature(internal_function)
internal_params = list(internal_signature.parameters.values())
external_params = internal_params[1:] # Omit the first/CFM parameter
external_signature = internal_signature.replace(parameters=external_params)

def external_function(*args, **kwargs):
original_model = self.import_model()

if original_model is None:
self.err_console.print(
"Please provide a model via the --import option."
)
raise typer.Exit(code=1)

modified_model = internal_function(original_model, *args, **kwargs)
self.export_model(modified_model)

external_function.__name__ = internal_function.__name__
external_function.__module__ = internal_function.__module__
external_function.__qualname__ = internal_function.__qualname__
external_function.__doc__ = internal_function.__doc__
external_function.__annotations__ = internal_function.__annotations__
setattr(external_function, "__signature__", external_signature)

self.typer.command(*args, **kwargs)(external_function)
return internal_function

return decorator

Expand Down
4 changes: 0 additions & 4 deletions tests/plugins/test_big_m.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ def test_plugin_can_be_loaded():
assert big_m in app.load_plugins()


def test_apply_big_m_without_loaded_model():
assert apply_big_m(None) is None


def test_apply_big_m_with_loaded_model(model: CFM):
assert model.is_unbound()
new_model = apply_big_m(model)
Expand Down
4 changes: 0 additions & 4 deletions tests/plugins/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,3 @@ def test_stringify_cfm():
)

assert cfm_str == stringify_cfm(cfm)


def test_stringify_cfm_can_stringify_none_cfm():
assert stringify_cfm(None) == "CFM:\nNone"
5 changes: 0 additions & 5 deletions tests/plugins/test_one_wise_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ def test_plugin_can_be_loaded():
assert one_wise_sampling_plugin in app.load_plugins()


def test_one_wise_sampling_without_loaded_model():
with pytest.raises(typer.Abort, match="No model loaded."):
one_wise_sampling(None)


def test_one_wise_sampling_with_unbound_model(unbound_model: CFM, capsys):
with pytest.raises(
typer.Abort, match="Model is unbound. Please apply big-m global bound first."
Expand Down
5 changes: 0 additions & 5 deletions tests/plugins/test_random_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ def test_plugin_can_be_loaded():
assert random_sampling_plugin in app.load_plugins()


def test_random_sampling_without_loaded_model():
assert random_sampling(None) is None
assert random_sampling(None, 3) is None


def test_random_sampling_with_unbound_model(unbound_model: CFM, capsys):
random_sampling(unbound_model) is unbound_model
captured = capsys.readouterr()
Expand Down
12 changes: 12 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typer.testing import CliRunner

from cfmtoolbox import app

runner = CliRunner(mix_stderr=False)


def test_cli_aborts_when_commands_are_invoked_without_import_option():
result = runner.invoke(app.typer, ["convert"])
assert result.exit_code == 1
assert result.stderr == "Please provide a model via the --import option.\n"
assert not result.stdout
Loading
Loading