From 9882786ff04a7d661a5b197f9b43bf91f106fb6d Mon Sep 17 00:00:00 2001 From: Martin Staadecker Date: Thu, 23 May 2024 23:03:35 +0200 Subject: [PATCH] Fix type errors --- src/pyoframe/_arithmetic.py | 2 +- src/pyoframe/constants.py | 4 ++-- src/pyoframe/core.py | 5 ++++- src/pyoframe/io_mappers.py | 6 +++--- src/pyoframe/model.py | 13 +++++++------ src/pyoframe/model_element.py | 8 ++++++-- src/pyoframe/solvers.py | 9 +++++++-- 7 files changed, 30 insertions(+), 17 deletions(-) diff --git a/src/pyoframe/_arithmetic.py b/src/pyoframe/_arithmetic.py index 4cb5b25..00b10c9 100644 --- a/src/pyoframe/_arithmetic.py +++ b/src/pyoframe/_arithmetic.py @@ -106,7 +106,7 @@ def get_indices(expr): strat = (left.unmatched_strategy, right.unmatched_strategy) - propogate_strat = propogatation_strategies[strat] + propogate_strat = propogatation_strategies[strat] # type: ignore if strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP): left_data = left.data.join(get_indices(right), how="inner", on=dims) diff --git a/src/pyoframe/constants.py b/src/pyoframe/constants.py index 806573a..cbff7e1 100644 --- a/src/pyoframe/constants.py +++ b/src/pyoframe/constants.py @@ -132,8 +132,8 @@ def process(cls, status: str) -> "SolverStatus": def from_termination_condition( cls, termination_condition: "TerminationCondition" ) -> "SolverStatus": - for status in STATUS_TO_TERMINATION_CONDITION_MAP: - if termination_condition in STATUS_TO_TERMINATION_CONDITION_MAP[status]: + for status, termination_conditions in STATUS_TO_TERMINATION_CONDITION_MAP.items(): + if termination_condition in termination_conditions: return status return cls("unknown") diff --git a/src/pyoframe/core.py b/src/pyoframe/core.py index cc2cf62..edd3ca7 100644 --- a/src/pyoframe/core.py +++ b/src/pyoframe/core.py @@ -709,6 +709,7 @@ def value(self) -> pl.DataFrame: >>> m.objective.value 63.0 """ + assert self._model is not None, "Expression must be added to the model to use .value" if self._model.result is None or self._model.result.solution is None: raise ValueError( "Can't obtain value of expression since the model has not been solved." @@ -924,6 +925,7 @@ def slack(self): The first call to this property will load the slack values from the solver (lazy loading). """ if SLACK_COL not in self.data.columns: + assert self._model is not None, "Constraint must be added to a model to get the slack." if self._model.solver is None: raise ValueError("The model has not been solved yet.") self._model.solver.load_slack() @@ -1160,7 +1162,7 @@ def __init__( lb: float | int | SupportsToExpr | None = None, ub: float | int | SupportsToExpr | None = None, vtype: VType | VTypeValue = VType.CONTINUOUS, - equals: SupportsToExpr = None, + equals: Optional[SupportsMath] = None, ): if lb is None: lb = float("-inf") @@ -1222,6 +1224,7 @@ def RC(self): The first call to this property will load the reduced costs from the solver (lazy loading). """ if RC_COL not in self.data.columns: + assert self._model is not None, "Variable must be added to a model to get the reduced cost." if self._model.solver is None: raise ValueError("The model has not been solved yet.") self._model.solver.load_rc() diff --git a/src/pyoframe/io_mappers.py b/src/pyoframe/io_mappers.py index d030f24..0c3ad83 100644 --- a/src/pyoframe/io_mappers.py +++ b/src/pyoframe/io_mappers.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover from pyoframe.model import Variable from pyoframe.core import Constraint - from pyoframe.util import CountableModelElement + from pyoframe.model_element import ModelElementWithId @dataclass @@ -29,7 +29,7 @@ class Mapper(ABC): NAME_COL = "__name" - def __init__(self, cls: Type["CountableModelElement"]) -> None: + def __init__(self, cls: Type["ModelElementWithId"]) -> None: self._ID_COL = cls.get_id_column_name() self.mapping_registry = pl.DataFrame( {self._ID_COL: [], Mapper.NAME_COL: []}, @@ -43,7 +43,7 @@ def _extend_registry(self, df: pl.DataFrame) -> None: self.mapping_registry = pl.concat([self.mapping_registry, df]) @abstractmethod - def _element_to_map(self, element: "CountableModelElement") -> pl.DataFrame: ... + def _element_to_map(self, element: "ModelElementWithId") -> pl.DataFrame: ... def apply( self, diff --git a/src/pyoframe/model.py b/src/pyoframe/model.py index 5c864c7..cd8e72b 100644 --- a/src/pyoframe/model.py +++ b/src/pyoframe/model.py @@ -37,6 +37,7 @@ class Model(AttrContainerMixin): "result", "attr", "sense", + "objective" ] def __init__(self, min_or_max: Union[ObjSense, ObjSenseValue], name=None, **kwargs): @@ -74,6 +75,12 @@ def constraints(self): @property def objective(self): return self._objective + + @objective.setter + def objective(self, value): + value = Objective(value) + self._objective = value + value.on_add_to_model(self, "objective") def __setattr__(self, __name: str, __value: Any) -> None: if __name not in Model._reserved_attributes and not isinstance( @@ -87,9 +94,6 @@ def __setattr__(self, __name: str, __value: Any) -> None: isinstance(__value, ModelElement) and __name not in Model._reserved_attributes ): - if __name == "objective": - __value = Objective(__value) - if isinstance(__value, ModelElementWithId): assert not hasattr( self, __name @@ -103,9 +107,6 @@ def __setattr__(self, __name: str, __value: Any) -> None: self.var_map.add(__value) elif isinstance(__value, Constraint): self._constraints.append(__value) - elif isinstance(__value, Objective): - self._objective = __value - return return super().__setattr__(__name, __value) def __repr__(self) -> str: diff --git a/src/pyoframe/model_element.py b/src/pyoframe/model_element.py index c67affc..fb4f0e4 100644 --- a/src/pyoframe/model_element.py +++ b/src/pyoframe/model_element.py @@ -109,7 +109,7 @@ def _support_polars_method(method_name: str): Wrapper to add a method to ModelElement that simply calls the underlying Polars method on the data attribute. """ - def method(self: "SupportPolarsMethodMixin", *args, **kwargs): + def method(self: "SupportPolarsMethodMixin", *args, **kwargs) -> Any: result_from_polars = getattr(self.data, method_name)(*args, **kwargs) if isinstance(result_from_polars, pl.DataFrame): return self._new(result_from_polars) @@ -119,7 +119,7 @@ def method(self: "SupportPolarsMethodMixin", *args, **kwargs): return method -class SupportPolarsMethodMixin: +class SupportPolarsMethodMixin(ABC): rename = _support_polars_method("rename") with_columns = _support_polars_method("with_columns") filter = _support_polars_method("filter") @@ -131,6 +131,10 @@ def _new(self, data: pl.DataFrame): Used to create a new instance of the same class with the given data (for e.g. on .rename(), .with_columns(), etc.). """ + @property + @abstractmethod + def data(self): ... + class ModelElementWithId(ModelElement, AttrContainerMixin): """ diff --git a/src/pyoframe/solvers.py b/src/pyoframe/solvers.py index 24f93b1..5cc504b 100644 --- a/src/pyoframe/solvers.py +++ b/src/pyoframe/solvers.py @@ -105,7 +105,7 @@ def __init__(self, model, log_to_console): self.log_to_console = log_to_console @abstractmethod - def create_solver_model(self) -> Any: ... + def create_solver_model(self, directory, use_var_names, env) -> Any: ... @abstractmethod def set_attr(self, element, param_name, param_value): ... @@ -216,7 +216,7 @@ class GurobiSolver(FileBasedSolver): 17: "internal_solver_error", } - def create_solver_model_from_lp(self, problem_fn, env) -> Result: + def create_solver_model_from_lp(self, problem_fn, env) -> Any: """ Solve a linear problem using the gurobi solver. @@ -236,10 +236,12 @@ def create_solver_model_from_lp(self, problem_fn, env) -> Result: return m def set_param(self, param_name, param_value): + assert self.solver_model is not None self.solver_model.setParam(param_name, param_value) @lru_cache def _get_var_mapping(self): + assert self.solver_model is not None vars = self.solver_model.getVars() return vars, pl.DataFrame( {VAR_KEY: self.solver_model.getAttr("VarName", vars)} @@ -247,12 +249,14 @@ def _get_var_mapping(self): @lru_cache def _get_constraint_mapping(self): + assert self.solver_model is not None constraints = self.solver_model.getConstrs() return constraints, pl.DataFrame( {CONSTRAINT_KEY: self.solver_model.getAttr("ConstrName", constraints)} ).with_columns(i=pl.int_range(pl.len())) def set_attr_unmapped(self, element, param_name, param_value): + assert self.solver_model is not None if isinstance(element, pf.Model): self.solver_model.setAttr(param_name, param_value) elif isinstance(element, pf.Variable): @@ -277,6 +281,7 @@ def set_attr_unmapped(self, element, param_name, param_value): raise ValueError(f"Element type {type(element)} not recognized.") def solve(self, log_fn, warmstart_fn, basis_fn, solution_file) -> Result: + assert self.solver_model is not None m = self.solver_model if log_fn is not None: m.setParam("logfile", _path_to_str(log_fn))