diff --git a/mpisppy/cylinders/xhatbase.py b/mpisppy/cylinders/xhatbase.py index b412d656f..8012cfc1e 100644 --- a/mpisppy/cylinders/xhatbase.py +++ b/mpisppy/cylinders/xhatbase.py @@ -33,7 +33,7 @@ def xhat_prep(self): ### begin iter0 stuff xhatter.pre_iter0() - if self.opt.extobject is not None: + if self.opt.extensions is not None: self.opt.extobject.pre_iter0() # for an extension self.opt._save_original_nonants() @@ -46,7 +46,7 @@ def xhat_prep(self): ### end iter0 stuff (but note: no need for iter 0 solves in an xhatter) xhatter.post_iter0() - if self.opt.extobject is not None: + if self.opt.extensions is not None: self.opt.extobject.post_iter0() # for an extension self.opt._save_nonants() # make the cache diff --git a/mpisppy/utils/xhat_eval.py b/mpisppy/utils/xhat_eval.py index fd93dd723..41217d303 100644 --- a/mpisppy/utils/xhat_eval.py +++ b/mpisppy/utils/xhat_eval.py @@ -43,7 +43,8 @@ def __init__( mpicomm=None, scenario_creator_kwargs=None, variable_probability=None, - ph_extensions=None, + extensions=None, + extension_kwargs=None, ): super().__init__( @@ -52,7 +53,8 @@ def __init__( scenario_creator, scenario_denouement=scenario_denouement, all_nodenames=all_nodenames, - extensions=ph_extensions, + extensions=extensions, + extension_kwargs=extension_kwargs, mpicomm=mpicomm, scenario_creator_kwargs=scenario_creator_kwargs, variable_probability=variable_probability,