Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing xhatters #481

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 17 additions & 35 deletions mpisppy/cylinders/spoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import os
import math

from pyomo.environ import ComponentMap, Var
from mpisppy import MPI
from mpisppy.cylinders.spcommunicator import SPCommunicator, communicator_array

Expand Down Expand Up @@ -328,43 +327,26 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options=
self.best_inner_bound = math.inf if self.is_minimizing else -math.inf
self.solver_options = None # can be overwritten by derived classes

# set up best solution cache
for k,s in self.opt.local_scenarios.items():
s._mpisppy_data.best_solution_cache = None

def update_if_improving(self, candidate_inner_bound):
if candidate_inner_bound is None:
return False
update = (candidate_inner_bound < self.best_inner_bound) \
if self.is_minimizing else \
def update_if_improving(self, candidate_inner_bound, update_cache=True):
if update_cache:
update = self.opt.update_best_solution_if_improving(candidate_inner_bound)
else:
update = ( (candidate_inner_bound < self.best_inner_bound)
if self.is_minimizing else
(self.best_inner_bound < candidate_inner_bound)
if not update:
return False

self.best_inner_bound = candidate_inner_bound
# send to hub
self.bound = candidate_inner_bound
self._cache_best_solution()
return True
)
if update:
self.best_inner_bound = candidate_inner_bound
# send to hub
self.bound = candidate_inner_bound
return True
return False

def finalize(self):
for k,s in self.opt.local_scenarios.items():
if s._mpisppy_data.best_solution_cache is None:
return None
for var, value in s._mpisppy_data.best_solution_cache.items():
var.set_value(value, skip_validation=True)

self.opt.first_stage_solution_available = True
self.opt.tree_solution_available = True
self.final_bound = self.bound
return self.final_bound

def _cache_best_solution(self):
for k,s in self.opt.local_scenarios.items():
scenario_cache = ComponentMap()
for var in s.component_data_objects(Var):
scenario_cache[var] = var.value
s._mpisppy_data.best_solution_cache = scenario_cache
if self.opt.load_best_solution():
self.final_bound = self.bound
return self.final_bound
return None


class OuterBoundNonantSpoke(_BoundNonantSpoke):
Expand Down
5 changes: 3 additions & 2 deletions mpisppy/cylinders/xhatlooper_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def main(self):
# just for sending the values to other scenarios
# so we don't need to tell persistent solvers
self.opt._restore_nonants(update_persistent=False)
upperbound, srcsname = xhatter.xhat_looper(scen_limit=scen_limit, restore_nonants=False)
upperbound, srcsname = xhatter.xhat_looper(scen_limit=scen_limit, restore_nonants=True)

# send a bound to the opt companion
self.update_if_improving(upperbound)
# XhatBase._try_one updates the solution cache on the opt object for us
self.update_if_improving(upperbound, update_cache=False)
xh_iter += 1
6 changes: 4 additions & 2 deletions mpisppy/cylinders/xhatshufflelooper_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def try_scenario_dict(self, xhat_scenario_dict):
obj = self.xhatter._try_one(snamedict,
solver_options = self.solver_options,
verbose=False,
restore_nonants=False,
restore_nonants=True,
stage2EFsolvern=stage2EFsolvern,
branching_factors=branching_factors)
def _vb(msg):
Expand All @@ -83,7 +83,8 @@ def _vb(msg):
return False
_vb(f" Feasible {snamedict}, obj: {obj}")

update = self.update_if_improving(obj)
# XhatBase._try_one updates the solution cache in the opt object for us
update = self.update_if_improving(obj, update_cache=False)
logger.debug(f' bottom of try_scenario_dict on rank {self.global_rank}')
return update

Expand Down Expand Up @@ -143,6 +144,7 @@ def _vb(msg):
# so we don't need to tell persistent solvers
self.opt._restore_nonants(update_persistent=False)

_vb(" Begin epoch")
scenario_cycler.begin_epoch()

next_scendict = scenario_cycler.get_next()
Expand Down
2 changes: 2 additions & 0 deletions mpisppy/extensions/xhatbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def _try_one(self, snamedict, solver_options=None, verbose=False,
self.opt.local_scenarios[sname].pprint()
# get the global obj
obj = self.opt.Eobjective(verbose=verbose)
self.opt.update_best_solution_if_improving(obj)
if restore_nonants:
self.opt._restore_nonants()
return obj
Expand All @@ -228,6 +229,7 @@ def _try_one(self, snamedict, solver_options=None, verbose=False,
print(" Feasible xhat found:")
self.opt.local_scenarios[sname].pprint()
obj = self.opt.Eobjective(verbose=verbose)
self.opt.update_best_solution_if_improving(obj)
if restore_nonants:
self.opt._restore_nonants()
return obj
Expand Down
60 changes: 54 additions & 6 deletions mpisppy/spbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def __init__(
self.n_proc = self.mpicomm.Get_size()
self.global_rank = MPI.COMM_WORLD.Get_rank()

# for writers, if the appropriate
# solution is loaded into the subproblems
self.tree_solution_available = False
self.first_stage_solution_available = False
self.best_solution_obj_val = None

if options.get("toc", True):
global_toc("Initializing SPBase")
Expand Down Expand Up @@ -120,14 +125,11 @@ def __init__(
self._verify_nonant_lengths()
self._set_sense()
self._use_variable_probability_setter()
self._set_best_solution_cache()

## SPCommunicator object
self._spcomm = None

# for writers, if the appropriate
# solution is loaded into the subproblems
self.tree_solution_available = False
self.first_stage_solution_available = False

def _set_sense(self, comm=None):
""" Check to confirm that all the models constructed by scenario_crator
Expand Down Expand Up @@ -540,6 +542,48 @@ def _options_check(self, required_options, given_options):
if missing:
raise ValueError(f"Missing the following required options: {', '.join(missing)}")

def _set_best_solution_cache(self):
# set up best solution cache
for k,s in self.local_scenarios.items():
s._mpisppy_data.best_solution_cache = None

def update_best_solution_if_improving(self, obj_val):
""" Call if the variable values have a nonanticipative solution
with associated obj_val. Will update the best_solution_cache
if the solution is better than the existing cached solution
"""
if obj_val is None:
return False
if self.best_solution_obj_val is None:
update = True
elif self.is_minimizing:
update = (obj_val < self.best_solution_obj_val)
else:
update = (self.best_solution_obj_val < obj_val)
if update:
self.best_solution_obj_val = obj_val
self._cache_best_solution()
return True
return False

def _cache_best_solution(self):
for k,s in self.local_scenarios.items():
scenario_cache = pyo.ComponentMap()
for var in s.component_data_objects(pyo.Var):
scenario_cache[var] = var.value
s._mpisppy_data.best_solution_cache = scenario_cache

def load_best_solution(self):
for k,s in self.local_scenarios.items():
if s._mpisppy_data.best_solution_cache is None:
return False
for var, value in s._mpisppy_data.best_solution_cache.items():
var.set_value(value, skip_validation=True)

self.first_stage_solution_available = True
self.tree_solution_available = True
return True

@property
def spcomm(self):
if self._spcomm is None:
Expand Down Expand Up @@ -652,7 +696,9 @@ def write_first_stage_solution(self, file_name,
"""

if not self.first_stage_solution_available:
raise RuntimeError("No first stage solution available")
# try loading a solution
if not self.load_best_solution():
raise RuntimeError("No first stage solution available")
if self.cylinder_rank == 0:
dirname = os.path.dirname(file_name)
if dirname != '':
Expand All @@ -670,7 +716,9 @@ def write_tree_solution(self, directory_name,
scenario_tree_solution_writer (optional): custom scenario solution writer function
"""
if not self.tree_solution_available:
raise RuntimeError("No tree solution available")
# try loading a solution
if not self.load_best_solution():
raise RuntimeError("No tree solution available")
if self.cylinder_rank == 0:
os.makedirs(directory_name, exist_ok=True)
self.mpicomm.Barrier()
Expand Down
Loading