Skip to content

Commit

Permalink
fixing xhatters
Browse files Browse the repository at this point in the history
  • Loading branch information
bknueven committed Jan 23, 2025
1 parent e97735a commit bc88698
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 44 deletions.
51 changes: 17 additions & 34 deletions mpisppy/cylinders/spoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,43 +328,26 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options=
self.best_inner_bound = math.inf if self.is_minimizing else -math.inf
self.solver_options = None # can be overwritten by derived classes

# set up best solution cache
for k,s in self.opt.local_scenarios.items():
s._mpisppy_data.best_solution_cache = None

def update_if_improving(self, candidate_inner_bound):
if candidate_inner_bound is None:
return False
update = (candidate_inner_bound < self.best_inner_bound) \
if self.is_minimizing else \
def update_if_improving(self, candidate_inner_bound, update_cache=True):
if update_cache:
update = self.opt.update_best_solution_if_improving(candidate_inner_bound)
else:
update = ( (candidate_inner_bound < self.best_inner_bound)
if self.is_minimizing else
(self.best_inner_bound < candidate_inner_bound)
if not update:
return False

self.best_inner_bound = candidate_inner_bound
# send to hub
self.bound = candidate_inner_bound
self._cache_best_solution()
return True
)
if update:
self.best_inner_bound = candidate_inner_bound
# send to hub
self.bound = candidate_inner_bound
return True
return False

def finalize(self):
for k,s in self.opt.local_scenarios.items():
if s._mpisppy_data.best_solution_cache is None:
return None
for var, value in s._mpisppy_data.best_solution_cache.items():
var.set_value(value, skip_validation=True)

self.opt.first_stage_solution_available = True
self.opt.tree_solution_available = True
self.final_bound = self.bound
return self.final_bound

def _cache_best_solution(self):
for k,s in self.opt.local_scenarios.items():
scenario_cache = ComponentMap()
for var in s.component_data_objects(Var):
scenario_cache[var] = var.value
s._mpisppy_data.best_solution_cache = scenario_cache
if self.opt.load_best_solution():
self.final_bound = self.bound
return self.final_bound
return None


class OuterBoundNonantSpoke(_BoundNonantSpoke):
Expand Down
5 changes: 3 additions & 2 deletions mpisppy/cylinders/xhatlooper_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def main(self):
# just for sending the values to other scenarios
# so we don't need to tell persistent solvers
self.opt._restore_nonants(update_persistent=False)
upperbound, srcsname = xhatter.xhat_looper(scen_limit=scen_limit, restore_nonants=False)
upperbound, srcsname = xhatter.xhat_looper(scen_limit=scen_limit, restore_nonants=True)

# send a bound to the opt companion
self.update_if_improving(upperbound)
# the xhatter updates the cache in the opt object for us
self.update_if_improving(upperbound, update_cache=False)
xh_iter += 1
6 changes: 4 additions & 2 deletions mpisppy/cylinders/xhatshufflelooper_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def try_scenario_dict(self, xhat_scenario_dict):
obj = self.xhatter._try_one(snamedict,
solver_options = self.solver_options,
verbose=False,
restore_nonants=False,
restore_nonants=True,
stage2EFsolvern=stage2EFsolvern,
branching_factors=branching_factors)
def _vb(msg):
Expand All @@ -83,7 +83,8 @@ def _vb(msg):
return False
_vb(f" Feasible {snamedict}, obj: {obj}")

update = self.update_if_improving(obj)
# the xhatter updates the cache in the opt object for us
update = self.update_if_improving(obj, update_cache=False)
logger.debug(f' bottom of try_scenario_dict on rank {self.global_rank}')
return update

Expand Down Expand Up @@ -143,6 +144,7 @@ def _vb(msg):
# so we don't need to tell persistent solvers
self.opt._restore_nonants(update_persistent=False)

_vb(f" Begin epoch")
scenario_cycler.begin_epoch()

next_scendict = scenario_cycler.get_next()
Expand Down
2 changes: 2 additions & 0 deletions mpisppy/extensions/xhatbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def _try_one(self, snamedict, solver_options=None, verbose=False,
self.opt.local_scenarios[sname].pprint()
# get the global obj
obj = self.opt.Eobjective(verbose=verbose)
self.opt.update_best_solution_if_improving(obj)
if restore_nonants:
self.opt._restore_nonants()
return obj
Expand All @@ -228,6 +229,7 @@ def _try_one(self, snamedict, solver_options=None, verbose=False,
print(" Feasible xhat found:")
self.opt.local_scenarios[sname].pprint()
obj = self.opt.Eobjective(verbose=verbose)
self.opt.update_best_solution_if_improving(obj)
if restore_nonants:
self.opt._restore_nonants()
return obj
Expand Down
60 changes: 54 additions & 6 deletions mpisppy/spbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def __init__(
self.n_proc = self.mpicomm.Get_size()
self.global_rank = MPI.COMM_WORLD.Get_rank()

# for writers, if the appropriate
# solution is loaded into the subproblems
self.tree_solution_available = False
self.first_stage_solution_available = False
self.best_solution_obj_val = None

if options.get("toc", True):
global_toc("Initializing SPBase")
Expand Down Expand Up @@ -120,14 +125,11 @@ def __init__(
self._verify_nonant_lengths()
self._set_sense()
self._use_variable_probability_setter()
self._set_best_solution_cache()

## SPCommunicator object
self._spcomm = None

# for writers, if the appropriate
# solution is loaded into the subproblems
self.tree_solution_available = False
self.first_stage_solution_available = False

def _set_sense(self, comm=None):
""" Check to confirm that all the models constructed by scenario_crator
Expand Down Expand Up @@ -540,6 +542,48 @@ def _options_check(self, required_options, given_options):
if missing:
raise ValueError(f"Missing the following required options: {', '.join(missing)}")

def _set_best_solution_cache(self):
# set up best solution cache
for k,s in self.local_scenarios.items():
s._mpisppy_data.best_solution_cache = None

def update_best_solution_if_improving(self, obj_val):
""" Call if the variable values have a nonanticipative solution
with associated obj_val. Will update the best_solution_cache
if the solution is better than the existing cached solution
"""
if obj_val is None:
return False
if self.best_solution_obj_val is None:
update = True
elif self.is_minimizing:
update = (obj_val < self.best_solution_obj_val)
else:
update = (self.best_solution_obj_val < obj_val)
if update:
self.best_solution_obj_val = obj_val
self._cache_best_solution()
return True
return False

def _cache_best_solution(self):
for k,s in self.local_scenarios.items():
scenario_cache = pyo.ComponentMap()
for var in s.component_data_objects(pyo.Var):
scenario_cache[var] = var.value
s._mpisppy_data.best_solution_cache = scenario_cache

def load_best_solution(self):
for k,s in self.local_scenarios.items():
if s._mpisppy_data.best_solution_cache is None:
return False
for var, value in s._mpisppy_data.best_solution_cache.items():
var.set_value(value, skip_validation=True)

self.first_stage_solution_available = True
self.tree_solution_available = True
return True

@property
def spcomm(self):
if self._spcomm is None:
Expand Down Expand Up @@ -652,7 +696,9 @@ def write_first_stage_solution(self, file_name,
"""

if not self.first_stage_solution_available:
raise RuntimeError("No first stage solution available")
# try loading a solution
if not self.load_best_solution():
raise RuntimeError("No first stage solution available")
if self.cylinder_rank == 0:
dirname = os.path.dirname(file_name)
if dirname != '':
Expand All @@ -670,7 +716,9 @@ def write_tree_solution(self, directory_name,
scenario_tree_solution_writer (optional): custom scenario solution writer function
"""
if not self.tree_solution_available:
raise RuntimeError("No tree solution available")
# try loading a solution
if not self.load_best_solution():
raise RuntimeError("No tree solution available")
if self.cylinder_rank == 0:
os.makedirs(directory_name, exist_ok=True)
self.mpicomm.Barrier()
Expand Down

0 comments on commit bc88698

Please sign in to comment.