From a9c522173ca37459d2b311de5caf2ed08c5b38ad Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Thu, 23 Jan 2025 08:58:31 -0700 Subject: [PATCH] fixing xhatters --- mpisppy/cylinders/spoke.py | 52 ++++++---------- mpisppy/cylinders/xhatlooper_bounder.py | 5 +- .../cylinders/xhatshufflelooper_bounder.py | 6 +- mpisppy/extensions/xhatbase.py | 2 + mpisppy/spbase.py | 60 +++++++++++++++++-- 5 files changed, 80 insertions(+), 45 deletions(-) diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index 7be2d5a62..92c673029 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -13,7 +13,6 @@ import os import math -from pyomo.environ import ComponentMap, Var from mpisppy import MPI from mpisppy.cylinders.spcommunicator import SPCommunicator, communicator_array @@ -328,43 +327,26 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options= self.best_inner_bound = math.inf if self.is_minimizing else -math.inf self.solver_options = None # can be overwritten by derived classes - # set up best solution cache - for k,s in self.opt.local_scenarios.items(): - s._mpisppy_data.best_solution_cache = None - - def update_if_improving(self, candidate_inner_bound): - if candidate_inner_bound is None: - return False - update = (candidate_inner_bound < self.best_inner_bound) \ - if self.is_minimizing else \ + def update_if_improving(self, candidate_inner_bound, update_cache=True): + if update_cache: + update = self.opt.update_best_solution_if_improving(candidate_inner_bound) + else: + update = ( (candidate_inner_bound < self.best_inner_bound) + if self.is_minimizing else (self.best_inner_bound < candidate_inner_bound) - if not update: - return False - - self.best_inner_bound = candidate_inner_bound - # send to hub - self.bound = candidate_inner_bound - self._cache_best_solution() - return True + ) + if update: + self.best_inner_bound = candidate_inner_bound + # send to hub + self.bound = candidate_inner_bound + return True + return False def finalize(self): - for k,s in self.opt.local_scenarios.items(): - if s._mpisppy_data.best_solution_cache is None: - return None - for var, value in s._mpisppy_data.best_solution_cache.items(): - var.set_value(value, skip_validation=True) - - self.opt.first_stage_solution_available = True - self.opt.tree_solution_available = True - self.final_bound = self.bound - return self.final_bound - - def _cache_best_solution(self): - for k,s in self.opt.local_scenarios.items(): - scenario_cache = ComponentMap() - for var in s.component_data_objects(Var): - scenario_cache[var] = var.value - s._mpisppy_data.best_solution_cache = scenario_cache + if self.opt.load_best_solution(): + self.final_bound = self.bound + return self.final_bound + return None class OuterBoundNonantSpoke(_BoundNonantSpoke): diff --git a/mpisppy/cylinders/xhatlooper_bounder.py b/mpisppy/cylinders/xhatlooper_bounder.py index 6cd386959..c2e70ff47 100644 --- a/mpisppy/cylinders/xhatlooper_bounder.py +++ b/mpisppy/cylinders/xhatlooper_bounder.py @@ -76,8 +76,9 @@ def main(self): # just for sending the values to other scenarios # so we don't need to tell persistent solvers self.opt._restore_nonants(update_persistent=False) - upperbound, srcsname = xhatter.xhat_looper(scen_limit=scen_limit, restore_nonants=False) + upperbound, srcsname = xhatter.xhat_looper(scen_limit=scen_limit, restore_nonants=True) # send a bound to the opt companion - self.update_if_improving(upperbound) + # the xhatter updates the cache in the opt object for us + self.update_if_improving(upperbound, update_cache=False) xh_iter += 1 diff --git a/mpisppy/cylinders/xhatshufflelooper_bounder.py b/mpisppy/cylinders/xhatshufflelooper_bounder.py index 44a8d00ba..cbd34c565 100644 --- a/mpisppy/cylinders/xhatshufflelooper_bounder.py +++ b/mpisppy/cylinders/xhatshufflelooper_bounder.py @@ -71,7 +71,7 @@ def try_scenario_dict(self, xhat_scenario_dict): obj = self.xhatter._try_one(snamedict, solver_options = self.solver_options, verbose=False, - restore_nonants=False, + restore_nonants=True, stage2EFsolvern=stage2EFsolvern, branching_factors=branching_factors) def _vb(msg): @@ -83,7 +83,8 @@ def _vb(msg): return False _vb(f" Feasible {snamedict}, obj: {obj}") - update = self.update_if_improving(obj) + # the xhatter updates the cache in the opt object for us + update = self.update_if_improving(obj, update_cache=False) logger.debug(f' bottom of try_scenario_dict on rank {self.global_rank}') return update @@ -143,6 +144,7 @@ def _vb(msg): # so we don't need to tell persistent solvers self.opt._restore_nonants(update_persistent=False) + _vb(" Begin epoch") scenario_cycler.begin_epoch() next_scendict = scenario_cycler.get_next() diff --git a/mpisppy/extensions/xhatbase.py b/mpisppy/extensions/xhatbase.py index 592e62dc6..542642a59 100644 --- a/mpisppy/extensions/xhatbase.py +++ b/mpisppy/extensions/xhatbase.py @@ -202,6 +202,7 @@ def _try_one(self, snamedict, solver_options=None, verbose=False, self.opt.local_scenarios[sname].pprint() # get the global obj obj = self.opt.Eobjective(verbose=verbose) + self.opt.update_best_solution_if_improving(obj) if restore_nonants: self.opt._restore_nonants() return obj @@ -228,6 +229,7 @@ def _try_one(self, snamedict, solver_options=None, verbose=False, print(" Feasible xhat found:") self.opt.local_scenarios[sname].pprint() obj = self.opt.Eobjective(verbose=verbose) + self.opt.update_best_solution_if_improving(obj) if restore_nonants: self.opt._restore_nonants() return obj diff --git a/mpisppy/spbase.py b/mpisppy/spbase.py index c801addba..43a605e46 100644 --- a/mpisppy/spbase.py +++ b/mpisppy/spbase.py @@ -91,6 +91,11 @@ def __init__( self.n_proc = self.mpicomm.Get_size() self.global_rank = MPI.COMM_WORLD.Get_rank() + # for writers, if the appropriate + # solution is loaded into the subproblems + self.tree_solution_available = False + self.first_stage_solution_available = False + self.best_solution_obj_val = None if options.get("toc", True): global_toc("Initializing SPBase") @@ -120,14 +125,11 @@ def __init__( self._verify_nonant_lengths() self._set_sense() self._use_variable_probability_setter() + self._set_best_solution_cache() ## SPCommunicator object self._spcomm = None - # for writers, if the appropriate - # solution is loaded into the subproblems - self.tree_solution_available = False - self.first_stage_solution_available = False def _set_sense(self, comm=None): """ Check to confirm that all the models constructed by scenario_crator @@ -540,6 +542,48 @@ def _options_check(self, required_options, given_options): if missing: raise ValueError(f"Missing the following required options: {', '.join(missing)}") + def _set_best_solution_cache(self): + # set up best solution cache + for k,s in self.local_scenarios.items(): + s._mpisppy_data.best_solution_cache = None + + def update_best_solution_if_improving(self, obj_val): + """ Call if the variable values have a nonanticipative solution + with associated obj_val. Will update the best_solution_cache + if the solution is better than the existing cached solution + """ + if obj_val is None: + return False + if self.best_solution_obj_val is None: + update = True + elif self.is_minimizing: + update = (obj_val < self.best_solution_obj_val) + else: + update = (self.best_solution_obj_val < obj_val) + if update: + self.best_solution_obj_val = obj_val + self._cache_best_solution() + return True + return False + + def _cache_best_solution(self): + for k,s in self.local_scenarios.items(): + scenario_cache = pyo.ComponentMap() + for var in s.component_data_objects(pyo.Var): + scenario_cache[var] = var.value + s._mpisppy_data.best_solution_cache = scenario_cache + + def load_best_solution(self): + for k,s in self.local_scenarios.items(): + if s._mpisppy_data.best_solution_cache is None: + return False + for var, value in s._mpisppy_data.best_solution_cache.items(): + var.set_value(value, skip_validation=True) + + self.first_stage_solution_available = True + self.tree_solution_available = True + return True + @property def spcomm(self): if self._spcomm is None: @@ -652,7 +696,9 @@ def write_first_stage_solution(self, file_name, """ if not self.first_stage_solution_available: - raise RuntimeError("No first stage solution available") + # try loading a solution + if not self.load_best_solution(): + raise RuntimeError("No first stage solution available") if self.cylinder_rank == 0: dirname = os.path.dirname(file_name) if dirname != '': @@ -670,7 +716,9 @@ def write_tree_solution(self, directory_name, scenario_tree_solution_writer (optional): custom scenario solution writer function """ if not self.tree_solution_available: - raise RuntimeError("No tree solution available") + # try loading a solution + if not self.load_best_solution(): + raise RuntimeError("No tree solution available") if self.cylinder_rank == 0: os.makedirs(directory_name, exist_ok=True) self.mpicomm.Barrier()