Skip to content

Commit

Permalink
Fix type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
staadecker committed May 23, 2024
1 parent 82a17ce commit 9882786
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/pyoframe/_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_indices(expr):

strat = (left.unmatched_strategy, right.unmatched_strategy)

propogate_strat = propogatation_strategies[strat]
propogate_strat = propogatation_strategies[strat] # type: ignore

if strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP):
left_data = left.data.join(get_indices(right), how="inner", on=dims)
Expand Down
4 changes: 2 additions & 2 deletions src/pyoframe/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def process(cls, status: str) -> "SolverStatus":
def from_termination_condition(
cls, termination_condition: "TerminationCondition"
) -> "SolverStatus":
for status in STATUS_TO_TERMINATION_CONDITION_MAP:
if termination_condition in STATUS_TO_TERMINATION_CONDITION_MAP[status]:
for status, termination_conditions in STATUS_TO_TERMINATION_CONDITION_MAP.items():
if termination_condition in termination_conditions:
return status
return cls("unknown")

Expand Down
5 changes: 4 additions & 1 deletion src/pyoframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ def value(self) -> pl.DataFrame:
>>> m.objective.value
63.0
"""
assert self._model is not None, "Expression must be added to the model to use .value"
if self._model.result is None or self._model.result.solution is None:
raise ValueError(
"Can't obtain value of expression since the model has not been solved."
Expand Down Expand Up @@ -924,6 +925,7 @@ def slack(self):
The first call to this property will load the slack values from the solver (lazy loading).
"""
if SLACK_COL not in self.data.columns:
assert self._model is not None, "Constraint must be added to a model to get the slack."
if self._model.solver is None:
raise ValueError("The model has not been solved yet.")
self._model.solver.load_slack()
Expand Down Expand Up @@ -1160,7 +1162,7 @@ def __init__(
lb: float | int | SupportsToExpr | None = None,
ub: float | int | SupportsToExpr | None = None,
vtype: VType | VTypeValue = VType.CONTINUOUS,
equals: SupportsToExpr = None,
equals: Optional[SupportsMath] = None,
):
if lb is None:
lb = float("-inf")
Expand Down Expand Up @@ -1222,6 +1224,7 @@ def RC(self):
The first call to this property will load the reduced costs from the solver (lazy loading).
"""
if RC_COL not in self.data.columns:
assert self._model is not None, "Variable must be added to a model to get the reduced cost."
if self._model.solver is None:
raise ValueError("The model has not been solved yet.")
self._model.solver.load_rc()
Expand Down
6 changes: 3 additions & 3 deletions src/pyoframe/io_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if TYPE_CHECKING: # pragma: no cover
from pyoframe.model import Variable
from pyoframe.core import Constraint
from pyoframe.util import CountableModelElement
from pyoframe.model_element import ModelElementWithId


@dataclass
Expand All @@ -29,7 +29,7 @@ class Mapper(ABC):

NAME_COL = "__name"

def __init__(self, cls: Type["CountableModelElement"]) -> None:
def __init__(self, cls: Type["ModelElementWithId"]) -> None:
self._ID_COL = cls.get_id_column_name()
self.mapping_registry = pl.DataFrame(
{self._ID_COL: [], Mapper.NAME_COL: []},
Expand All @@ -43,7 +43,7 @@ def _extend_registry(self, df: pl.DataFrame) -> None:
self.mapping_registry = pl.concat([self.mapping_registry, df])

@abstractmethod
def _element_to_map(self, element: "CountableModelElement") -> pl.DataFrame: ...
def _element_to_map(self, element: "ModelElementWithId") -> pl.DataFrame: ...

def apply(
self,
Expand Down
13 changes: 7 additions & 6 deletions src/pyoframe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Model(AttrContainerMixin):
"result",
"attr",
"sense",
"objective"
]

def __init__(self, min_or_max: Union[ObjSense, ObjSenseValue], name=None, **kwargs):
Expand Down Expand Up @@ -74,6 +75,12 @@ def constraints(self):
@property
def objective(self):
return self._objective

@objective.setter
def objective(self, value):
value = Objective(value)
self._objective = value
value.on_add_to_model(self, "objective")

def __setattr__(self, __name: str, __value: Any) -> None:
if __name not in Model._reserved_attributes and not isinstance(
Expand All @@ -87,9 +94,6 @@ def __setattr__(self, __name: str, __value: Any) -> None:
isinstance(__value, ModelElement)
and __name not in Model._reserved_attributes
):
if __name == "objective":
__value = Objective(__value)

if isinstance(__value, ModelElementWithId):
assert not hasattr(
self, __name
Expand All @@ -103,9 +107,6 @@ def __setattr__(self, __name: str, __value: Any) -> None:
self.var_map.add(__value)
elif isinstance(__value, Constraint):
self._constraints.append(__value)
elif isinstance(__value, Objective):
self._objective = __value
return
return super().__setattr__(__name, __value)

def __repr__(self) -> str:
Expand Down
8 changes: 6 additions & 2 deletions src/pyoframe/model_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _support_polars_method(method_name: str):
Wrapper to add a method to ModelElement that simply calls the underlying Polars method on the data attribute.
"""

def method(self: "SupportPolarsMethodMixin", *args, **kwargs):
def method(self: "SupportPolarsMethodMixin", *args, **kwargs) -> Any:
result_from_polars = getattr(self.data, method_name)(*args, **kwargs)
if isinstance(result_from_polars, pl.DataFrame):
return self._new(result_from_polars)
Expand All @@ -119,7 +119,7 @@ def method(self: "SupportPolarsMethodMixin", *args, **kwargs):
return method


class SupportPolarsMethodMixin:
class SupportPolarsMethodMixin(ABC):
rename = _support_polars_method("rename")
with_columns = _support_polars_method("with_columns")
filter = _support_polars_method("filter")
Expand All @@ -131,6 +131,10 @@ def _new(self, data: pl.DataFrame):
Used to create a new instance of the same class with the given data (for e.g. on .rename(), .with_columns(), etc.).
"""

@property
@abstractmethod
def data(self): ...


class ModelElementWithId(ModelElement, AttrContainerMixin):
"""
Expand Down
9 changes: 7 additions & 2 deletions src/pyoframe/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(self, model, log_to_console):
self.log_to_console = log_to_console

@abstractmethod
def create_solver_model(self) -> Any: ...
def create_solver_model(self, directory, use_var_names, env) -> Any: ...

@abstractmethod
def set_attr(self, element, param_name, param_value): ...
Expand Down Expand Up @@ -216,7 +216,7 @@ class GurobiSolver(FileBasedSolver):
17: "internal_solver_error",
}

def create_solver_model_from_lp(self, problem_fn, env) -> Result:
def create_solver_model_from_lp(self, problem_fn, env) -> Any:
"""
Solve a linear problem using the gurobi solver.
Expand All @@ -236,23 +236,27 @@ def create_solver_model_from_lp(self, problem_fn, env) -> Result:
return m

def set_param(self, param_name, param_value):
assert self.solver_model is not None
self.solver_model.setParam(param_name, param_value)

@lru_cache
def _get_var_mapping(self):
assert self.solver_model is not None
vars = self.solver_model.getVars()
return vars, pl.DataFrame(
{VAR_KEY: self.solver_model.getAttr("VarName", vars)}
).with_columns(i=pl.int_range(pl.len()))

@lru_cache
def _get_constraint_mapping(self):
assert self.solver_model is not None
constraints = self.solver_model.getConstrs()
return constraints, pl.DataFrame(
{CONSTRAINT_KEY: self.solver_model.getAttr("ConstrName", constraints)}
).with_columns(i=pl.int_range(pl.len()))

def set_attr_unmapped(self, element, param_name, param_value):
assert self.solver_model is not None
if isinstance(element, pf.Model):
self.solver_model.setAttr(param_name, param_value)
elif isinstance(element, pf.Variable):
Expand All @@ -277,6 +281,7 @@ def set_attr_unmapped(self, element, param_name, param_value):
raise ValueError(f"Element type {type(element)} not recognized.")

def solve(self, log_fn, warmstart_fn, basis_fn, solution_file) -> Result:
assert self.solver_model is not None
m = self.solver_model
if log_fn is not None:
m.setParam("logfile", _path_to_str(log_fn))
Expand Down

0 comments on commit 9882786

Please sign in to comment.