Skip to content

Commit

Permalink
Add state handling to filter_datasets (#841)
Browse files Browse the repository at this point in the history
  • Loading branch information
uri-granta authored May 28, 2024
1 parent adcd419 commit 452b7e2
Show file tree
Hide file tree
Showing 9 changed files with 542 additions and 295 deletions.
21 changes: 13 additions & 8 deletions tests/integration/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
AsynchronousGreedy,
AsynchronousRuleState,
BatchTrustRegionBox,
BatchTrustRegionState,
EfficientGlobalOptimization,
SingleObjectiveTrustRegionBox,
TREGOBox,
UpdatableTrustRegionBox,
)
from trieste.acquisition.utils import copy_to_local_models
from trieste.ask_tell_optimization import AskTellOptimizer, AskTellOptimizerState
Expand Down Expand Up @@ -73,11 +75,6 @@
True,
lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)),
id="TREGO/reload_state",
# TODO: trust regions maintain internal state and do not fully support the functional
# API for reloading from acquisition state. So this test is skipped for now.
marks=pytest.mark.skip(
reason="Trust regions do not support reloading from acquisition state"
),
),
pytest.param(
10,
Expand Down Expand Up @@ -136,7 +133,10 @@
Callable[
[],
AcquisitionRule[
State[TensorType, Union[AsynchronousRuleState, BatchTrustRegionBox.State]],
State[
TensorType,
Union[AsynchronousRuleState, BatchTrustRegionState[UpdatableTrustRegionBox]],
],
Box,
TrainableProbabilisticModel,
],
Expand Down Expand Up @@ -220,7 +220,11 @@ def _test_ask_tell_optimization_finds_minima(

if reload_state:
state: AskTellOptimizerState[
None | State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State],
None
| State[
TensorType,
AsynchronousRuleState | BatchTrustRegionState[UpdatableTrustRegionBox],
],
GaussianProcessRegression,
] = ask_tell.to_state()
written_state = pickle.dumps(state)
Expand Down Expand Up @@ -257,7 +261,8 @@ def _test_ask_tell_optimization_finds_minima(
ask_tell.tell(initial_dataset)

result: OptimizationResult[
None | State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State],
None
| State[TensorType, AsynchronousRuleState | BatchTrustRegionState[UpdatableTrustRegionBox]],
GaussianProcessRegression,
] = ask_tell.to_result()
dataset = result.try_get_final_dataset()
Expand Down
7 changes: 5 additions & 2 deletions tests/integration/test_bayesian_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@
AsynchronousOptimization,
AsynchronousRuleState,
BatchHypervolumeSharpeRatioIndicator,
BatchTrustRegion,
BatchTrustRegionBox,
BatchTrustRegionState,
DiscreteThompsonSampling,
EfficientGlobalOptimization,
SingleObjectiveTrustRegionBox,
TREGOBox,
TURBOBox,
UpdatableTrustRegionBox,
)
from trieste.acquisition.sampler import ThompsonSamplerFromTrajectory
from trieste.acquisition.utils import copy_to_local_models
Expand Down Expand Up @@ -287,7 +288,9 @@ def GPR_OPTIMIZER_PARAMS() -> Tuple[str, List[ParameterSet]]:
AcquisitionRuleType = Union[
AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModelType],
AcquisitionRule[
State[TensorType, Union[AsynchronousRuleState, BatchTrustRegion.State]],
State[
TensorType, Union[AsynchronousRuleState, BatchTrustRegionState[UpdatableTrustRegionBox]]
],
Box,
TrainableProbabilisticModelType,
],
Expand Down
Loading

0 comments on commit 452b7e2

Please sign in to comment.