From 452b7e21cd985de3dbb410affa8b0a4bfc949b3f Mon Sep 17 00:00:00 2001 From: uri-granta <50578464+uri-granta@users.noreply.github.com> Date: Tue, 28 May 2024 09:43:06 +0100 Subject: [PATCH] Add state handling to filter_datasets (#841) --- .../integration/test_ask_tell_optimization.py | 21 +- .../integration/test_bayesian_optimization.py | 7 +- tests/unit/acquisition/test_rule.py | 262 ++++++++-------- tests/unit/test_ask_tell_optimization.py | 79 +++-- tests/unit/test_bayesian_optimizer.py | 31 +- tests/util/misc.py | 54 +++- trieste/acquisition/rule.py | 283 ++++++++++-------- trieste/ask_tell_optimization.py | 64 +++- trieste/bayesian_optimizer.py | 36 ++- 9 files changed, 542 insertions(+), 295 deletions(-) diff --git a/tests/integration/test_ask_tell_optimization.py b/tests/integration/test_ask_tell_optimization.py index fd3f565095..a97bbae8d2 100644 --- a/tests/integration/test_ask_tell_optimization.py +++ b/tests/integration/test_ask_tell_optimization.py @@ -30,9 +30,11 @@ AsynchronousGreedy, AsynchronousRuleState, BatchTrustRegionBox, + BatchTrustRegionState, EfficientGlobalOptimization, SingleObjectiveTrustRegionBox, TREGOBox, + UpdatableTrustRegionBox, ) from trieste.acquisition.utils import copy_to_local_models from trieste.ask_tell_optimization import AskTellOptimizer, AskTellOptimizerState @@ -73,11 +75,6 @@ True, lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)), id="TREGO/reload_state", - # TODO: trust regions maintain internal state and do not fully support the functional - # API for reloading from acquisition state. So this test is skipped for now. - marks=pytest.mark.skip( - reason="Trust regions do not support reloading from acquisition state" - ), ), pytest.param( 10, @@ -136,7 +133,10 @@ Callable[ [], AcquisitionRule[ - State[TensorType, Union[AsynchronousRuleState, BatchTrustRegionBox.State]], + State[ + TensorType, + Union[AsynchronousRuleState, BatchTrustRegionState[UpdatableTrustRegionBox]], + ], Box, TrainableProbabilisticModel, ], @@ -220,7 +220,11 @@ def _test_ask_tell_optimization_finds_minima( if reload_state: state: AskTellOptimizerState[ - None | State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State], + None + | State[ + TensorType, + AsynchronousRuleState | BatchTrustRegionState[UpdatableTrustRegionBox], + ], GaussianProcessRegression, ] = ask_tell.to_state() written_state = pickle.dumps(state) @@ -257,7 +261,8 @@ def _test_ask_tell_optimization_finds_minima( ask_tell.tell(initial_dataset) result: OptimizationResult[ - None | State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State], + None + | State[TensorType, AsynchronousRuleState | BatchTrustRegionState[UpdatableTrustRegionBox]], GaussianProcessRegression, ] = ask_tell.to_result() dataset = result.try_get_final_dataset() diff --git a/tests/integration/test_bayesian_optimization.py b/tests/integration/test_bayesian_optimization.py index a039691927..6965253e6c 100644 --- a/tests/integration/test_bayesian_optimization.py +++ b/tests/integration/test_bayesian_optimization.py @@ -50,13 +50,14 @@ AsynchronousOptimization, AsynchronousRuleState, BatchHypervolumeSharpeRatioIndicator, - BatchTrustRegion, BatchTrustRegionBox, + BatchTrustRegionState, DiscreteThompsonSampling, EfficientGlobalOptimization, SingleObjectiveTrustRegionBox, TREGOBox, TURBOBox, + UpdatableTrustRegionBox, ) from trieste.acquisition.sampler import ThompsonSamplerFromTrajectory from trieste.acquisition.utils import copy_to_local_models @@ -287,7 +288,9 @@ def GPR_OPTIMIZER_PARAMS() -> Tuple[str, List[ParameterSet]]: AcquisitionRuleType = Union[ AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModelType], AcquisitionRule[ - State[TensorType, Union[AsynchronousRuleState, BatchTrustRegion.State]], + State[ + TensorType, Union[AsynchronousRuleState, BatchTrustRegionState[UpdatableTrustRegionBox]] + ], Box, TrainableProbabilisticModelType, ], diff --git a/tests/unit/acquisition/test_rule.py b/tests/unit/acquisition/test_rule.py index 5415c04a84..b44a56780b 100644 --- a/tests/unit/acquisition/test_rule.py +++ b/tests/unit/acquisition/test_rule.py @@ -15,7 +15,7 @@ import copy from collections.abc import Mapping -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Sequence, Union, cast from unittest.mock import ANY, MagicMock import gpflow @@ -50,6 +50,7 @@ BatchHypervolumeSharpeRatioIndicator, BatchTrustRegionBox, BatchTrustRegionProduct, + BatchTrustRegionState, DiscreteThompsonSampling, EfficientGlobalOptimization, FixedPointTrustRegionDiscrete, @@ -59,6 +60,7 @@ TREGOBox, TURBOBox, UpdatableTrustRegion, + UpdatableTrustRegionBox, UpdatableTrustRegionProduct, ) from trieste.acquisition.sampler import ( @@ -610,19 +612,16 @@ def test_trego_for_default_state( model = QuadraticMeanAndRBFKernel() state, query_point = tr.acquire_single(search_space, model, dataset=dataset)(None) - tr.filter_datasets({OBJECTIVE: model}, {OBJECTIVE: dataset}) + state, _ = tr.filter_datasets({OBJECTIVE: model}, {OBJECTIVE: dataset})(state) assert state is not None ret_subspace = state.acquisition_space.get_subspace("0") assert isinstance(ret_subspace, TREGOBox) npt.assert_array_almost_equal(ret_subspace.lower, lower_bound) npt.assert_array_almost_equal(ret_subspace.upper, upper_bound) - npt.assert_array_almost_equal(query_point, [expected_query_point], 5) - npt.assert_array_almost_equal(subspace.lower, lower_bound) - npt.assert_array_almost_equal(subspace.upper, upper_bound) - npt.assert_array_almost_equal(subspace._y_min, [0.012]) - assert subspace._is_global + npt.assert_array_almost_equal(ret_subspace._y_min, [0.012]) + assert ret_subspace._is_global def trego_create_subspace( @@ -667,27 +666,24 @@ def test_trego_successful_global_to_global_trust_region_unchanged( ) tr = BatchTrustRegionBox(subspace, rule) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([subspace])) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([subspace], ["0"]) model = {OBJECTIVE: QuadraticMeanAndRBFKernel()} current_state, query_point = tr.acquire( search_space, model, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(model, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(model, {OBJECTIVE: dataset})(current_state) assert current_state is not None - current_subspace = current_state.acquisition_space.get_subspace("0") + current_subspace = current_state.subspaces[0] assert isinstance(current_subspace, TREGOBox) + npt.assert_array_almost_equal(current_subspace._eps, eps) + assert current_subspace._is_global + npt.assert_array_almost_equal(query_point, [expected_query_point], 5) npt.assert_array_almost_equal(current_subspace.lower, lower_bound) npt.assert_array_almost_equal(current_subspace.upper, upper_bound) - npt.assert_array_almost_equal(subspace._eps, eps) - assert subspace._is_global - npt.assert_array_almost_equal(query_point, [expected_query_point], 5) - npt.assert_array_almost_equal(subspace.lower, lower_bound) - npt.assert_array_almost_equal(subspace.upper, upper_bound) - @pytest.mark.parametrize( "rule", @@ -714,25 +710,22 @@ def test_trego_for_unsuccessful_global_to_local_trust_region_unchanged( previous_subspace_copy = copy.deepcopy(subspace) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([subspace])) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([subspace], ["0"]) model = {OBJECTIVE: QuadraticMeanAndRBFKernel()} current_state, query_point = tr.acquire( search_space, model, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(model, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(model, {OBJECTIVE: dataset})(current_state) assert current_state is not None current_subspace = current_state.acquisition_space.get_subspace("0") assert isinstance(current_subspace, TREGOBox) - npt.assert_array_almost_equal(current_subspace.lower, acquisition_space.lower) - npt.assert_array_almost_equal(current_subspace.upper, acquisition_space.upper) - - npt.assert_array_almost_equal(subspace._eps, eps) - assert not subspace._is_global - npt.assert_array_less(lower_bound, subspace.lower) - npt.assert_array_less(subspace.upper, upper_bound) + npt.assert_array_almost_equal(current_subspace._eps, eps) + assert not current_subspace._is_global + npt.assert_array_less(lower_bound, current_subspace.lower) + npt.assert_array_less(current_subspace.upper, upper_bound) assert query_point[0][0] in previous_subspace_copy @@ -759,25 +752,22 @@ def test_trego_for_successful_local_to_global_trust_region_increased( ) tr = BatchTrustRegionBox(subspace, rule) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([subspace])) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([subspace], ["0"]) model = {OBJECTIVE: QuadraticMeanAndRBFKernel()} current_state, _ = tr.acquire( search_space, model, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(model, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(model, {OBJECTIVE: dataset})(current_state) assert current_state is not None current_subspace = current_state.acquisition_space.get_subspace("0") assert isinstance(current_subspace, TREGOBox) - npt.assert_array_almost_equal(current_subspace.lower, acquisition_space.lower) - npt.assert_array_almost_equal(current_subspace.upper, acquisition_space.upper) - - npt.assert_array_less(eps, subspace._eps) # current TR larger than previous - assert subspace._is_global - npt.assert_array_almost_equal(subspace.lower, lower_bound) - npt.assert_array_almost_equal(subspace.upper, upper_bound) + npt.assert_array_less(eps, current_subspace._eps) # current TR larger than previous + assert current_subspace._is_global + npt.assert_array_almost_equal(current_subspace.lower, lower_bound) + npt.assert_array_almost_equal(current_subspace.upper, upper_bound) @pytest.mark.parametrize( @@ -803,41 +793,42 @@ def test_trego_for_unsuccessful_local_to_global_trust_region_reduced( ) tr = BatchTrustRegionBox(subspace, rule) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([subspace])) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([subspace], ["0"]) model = {OBJECTIVE: QuadraticMeanAndRBFKernel()} current_state, _ = tr.acquire( search_space, model, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(model, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(model, {OBJECTIVE: dataset})(current_state) assert current_state is not None current_subspace = current_state.acquisition_space.get_subspace("0") assert isinstance(current_subspace, TREGOBox) - npt.assert_array_almost_equal(current_subspace.lower, acquisition_space.lower) - npt.assert_array_almost_equal(current_subspace.upper, acquisition_space.upper) - - npt.assert_array_less(subspace._eps, eps) # current TR smaller than previous - assert subspace._is_global - npt.assert_array_almost_equal(subspace.lower, lower_bound) - npt.assert_array_almost_equal(subspace.upper, upper_bound) + npt.assert_array_less(current_subspace._eps, eps) # current TR smaller than previous + assert current_subspace._is_global + npt.assert_array_almost_equal(current_subspace.lower, lower_bound) + npt.assert_array_almost_equal(current_subspace.upper, upper_bound) def test_trego_always_uses_global_dataset() -> None: search_space = Box([0.0, 0.0], [1.0, 1.0]) dataset = Dataset( - tf.constant([[0.1, 0.2], [-0.1, -0.2], [1.1, 2.3]]), tf.constant([[0.4], [0.5], [0.6]]) + tf.constant([[0.1, 0.2], [-0.1, -0.2], [1.1, 2.3]], dtype=tf.float64), + tf.constant([[0.4], [0.5], [0.6]], dtype=tf.float64), ) tr = BatchTrustRegionBox(TREGOBox(search_space)) # type: ignore[var-annotated] + state, _ = tr.acquire( + search_space, {OBJECTIVE: QuadraticMeanAndRBFKernel()}, {OBJECTIVE: dataset} + )(None) new_data = Dataset( - tf.constant([[0.5, -0.2], [0.7, 0.2], [1.1, 0.3], [0.5, 0.5]]), - tf.constant([[0.7], [0.8], [0.9], [1.0]]), + tf.constant([[0.5, -0.2], [0.7, 0.2], [1.1, 0.3], [0.5, 0.5]], dtype=tf.float64), + tf.constant([[0.7], [0.8], [0.9], [1.0]], dtype=tf.float64), ) - updated_datasets = tr.filter_datasets( + _, updated_datasets = tr.filter_datasets( {LocalizedTag(OBJECTIVE, 0): QuadraticMeanAndRBFKernel()}, {OBJECTIVE: dataset + new_data, LocalizedTag(OBJECTIVE, 0): dataset + new_data}, - ) + )(state) # Both the local and global datasets should match. assert updated_datasets.keys() == {OBJECTIVE, LocalizedTag(OBJECTIVE, 0)} @@ -859,7 +850,7 @@ def test_trego_state_deepcopy() -> None: tf.constant(7.8), False, ) - tr_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([subspace])) + tr_state = BatchTrustRegionState[UpdatableTrustRegionBox]([subspace], ["0"]) tr_state_copy = copy.deepcopy(tr_state) tr_subspace = tr_state.acquisition_space.get_subspace("0") tr_subspace_copy = tr_state_copy.acquisition_space.get_subspace("0") @@ -927,8 +918,8 @@ def test_turbo_heuristics_for_param_init_work() -> None: search_space = Box(lower_bound, upper_bound) rule = BatchTrustRegionBox(TURBOBox(search_space)) # type: ignore[var-annotated] rule.acquire(search_space, {OBJECTIVE: QuadraticMeanAndRBFKernel()}) - assert rule._subspaces is not None - region = rule._subspaces[0] + assert rule._init_subspaces is not None + region = rule._init_subspaces[0] assert isinstance(region, TURBOBox) assert region.L_init == 0.8 * 3.0 @@ -995,19 +986,17 @@ def test_turbo_for_default_state( lengthscales=tf.constant(lengthscales, dtype=tf.float64), variance=1e-5 ) # need a gpflow kernel for TURBOBox state, query_point = tr.acquire_single(search_space, model, dataset=dataset)(None) - tr.filter_datasets({OBJECTIVE: model}, {OBJECTIVE: dataset}) + state, _ = tr.filter_datasets({OBJECTIVE: model}, {OBJECTIVE: dataset})(state) assert state is not None state_region = state.acquisition_space.get_subspace("0") assert isinstance(state_region, TURBOBox) - npt.assert_array_almost_equal(state_region.lower, orig_region.lower) - npt.assert_array_almost_equal(state_region.upper, orig_region.upper) - npt.assert_array_almost_equal(region.lower, lower_bound) - npt.assert_array_almost_equal(region.upper, tf.constant(exp_upper, dtype=tf.float64)) - npt.assert_array_almost_equal(region.y_min, [0.012]) - npt.assert_array_almost_equal(region.L, tf.cast(0.8, dtype=tf.float64)) - assert region.success_counter == 0 - assert region.failure_counter == 0 + npt.assert_array_almost_equal(state_region.lower, lower_bound) + npt.assert_array_almost_equal(state_region.upper, tf.constant(exp_upper, dtype=tf.float64)) + npt.assert_array_almost_equal(state_region.y_min, [0.012]) + npt.assert_array_almost_equal(state_region.L, tf.cast(0.8, dtype=tf.float64)) + assert state_region.success_counter == 0 + assert state_region.failure_counter == 0 def turbo_create_region( @@ -1059,23 +1048,25 @@ def test_turbo_doesnt_change_size_unless_needed() -> None: success_counter, previous_y_min, ) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([region])) - tr._subspaces = (region,) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([region], ["0"]) + tr._init_subspaces = (region,) current_state, _ = tr.acquire( search_space, models, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(models, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(models, {OBJECTIVE: dataset})(current_state) assert current_state is not None state_region = current_state.acquisition_space.get_subspace("0") assert isinstance(state_region, TURBOBox) - npt.assert_array_almost_equal(region.L, tf.cast(0.8, dtype=tf.float64)) - npt.assert_array_almost_equal(region.lower, lower_bound) - npt.assert_array_almost_equal(region.upper, tf.constant([0.8, 0.2], dtype=tf.float64)) - assert region.success_counter == success_counter + 1 - assert region.failure_counter == 0 + npt.assert_array_almost_equal(state_region.L, tf.cast(0.8, dtype=tf.float64)) + npt.assert_array_almost_equal(state_region.lower, lower_bound) + npt.assert_array_almost_equal( + state_region.upper, tf.constant([0.8, 0.2], dtype=tf.float64) + ) + assert state_region.success_counter == success_counter + 1 + assert state_region.failure_counter == 0 # failure but not enough to trigger size change previous_y_min = dataset.observations[0] # force failure @@ -1088,23 +1079,23 @@ def test_turbo_doesnt_change_size_unless_needed() -> None: success_counter, previous_y_min, ) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([region])) - tr._subspaces = (region,) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([region], ["0"]) + tr._init_subspaces = (region,) current_state, _ = tr.acquire( search_space, models, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(models, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(models, {OBJECTIVE: dataset})(current_state) assert current_state is not None state_region = current_state.acquisition_space.get_subspace("0") assert isinstance(state_region, TURBOBox) - npt.assert_array_almost_equal(region.L, tf.cast(0.8, dtype=tf.float64)) - npt.assert_array_almost_equal(region.lower, lower_bound) - npt.assert_array_almost_equal(region.upper, tf.constant([0.8, 0.2], dtype=tf.float64)) - assert region.success_counter == 0 - assert region.failure_counter == 1 + npt.assert_array_almost_equal(state_region.L, tf.cast(0.8, dtype=tf.float64)) + npt.assert_array_almost_equal(state_region.lower, lower_bound) + npt.assert_array_almost_equal(state_region.upper, tf.constant([0.8, 0.2], dtype=tf.float64)) + assert state_region.success_counter == 0 + assert state_region.failure_counter == 1 def test_turbo_does_change_size_correctly_when_needed() -> None: @@ -1136,23 +1127,23 @@ def test_turbo_does_change_size_correctly_when_needed() -> None: 2, previous_y_min, ) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([region])) - tr._subspaces = (region,) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([region], ["0"]) + tr._init_subspaces = (region,) current_state, _ = tr.acquire( search_space, models, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(models, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(models, {OBJECTIVE: dataset})(current_state) assert current_state is not None state_region = current_state.acquisition_space.get_subspace("0") assert isinstance(state_region, TURBOBox) - npt.assert_array_almost_equal(region.L, tf.cast(1.6, dtype=tf.float64)) - npt.assert_array_almost_equal(region.lower, lower_bound) - npt.assert_array_almost_equal(region.upper, tf.constant([1.0, 0.4], dtype=tf.float64)) - assert region.success_counter == 0 - assert region.failure_counter == 0 + npt.assert_array_almost_equal(state_region.L, tf.cast(1.6, dtype=tf.float64)) + npt.assert_array_almost_equal(state_region.lower, lower_bound) + npt.assert_array_almost_equal(state_region.upper, tf.constant([1.0, 0.4], dtype=tf.float64)) + assert state_region.success_counter == 0 + assert state_region.failure_counter == 0 # hits failure limit previous_y_min = dataset.observations[0] # force failure for success_counter in [0, 1, 2]: @@ -1164,23 +1155,23 @@ def test_turbo_does_change_size_correctly_when_needed() -> None: success_counter, previous_y_min, ) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([region])) - tr._subspaces = (region,) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([region], ["0"]) + tr._init_subspaces = (region,) current_state, _ = tr.acquire( search_space, models, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(models, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(models, {OBJECTIVE: dataset})(current_state) assert current_state is not None state_region = current_state.acquisition_space.get_subspace("0") assert isinstance(state_region, TURBOBox) - npt.assert_array_almost_equal(region.L, tf.cast(0.4, dtype=tf.float64)) - npt.assert_array_almost_equal(region.lower, lower_bound) - npt.assert_array_almost_equal(region.upper, tf.constant([0.4, 0.1], dtype=tf.float64)) - assert region.success_counter == 0 - assert region.failure_counter == 0 + npt.assert_array_almost_equal(state_region.L, tf.cast(0.4, dtype=tf.float64)) + npt.assert_array_almost_equal(state_region.lower, lower_bound) + npt.assert_array_almost_equal(state_region.upper, tf.constant([0.4, 0.1], dtype=tf.float64)) + assert state_region.success_counter == 0 + assert state_region.failure_counter == 0 def test_turbo_restarts_tr_when_too_small() -> None: @@ -1210,51 +1201,51 @@ def test_turbo_restarts_tr_when_too_small() -> None: region = turbo_create_region( search_space, previous_search_space, L, failure_counter, success_counter, previous_y_min ) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([region])) - tr._subspaces = (region,) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([region], ["0"]) + tr._init_subspaces = (region,) current_state, _ = tr.acquire( search_space, models, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(models, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(models, {OBJECTIVE: dataset})(current_state) assert current_state is not None state_region = current_state.acquisition_space.get_subspace("0") assert isinstance(state_region, TURBOBox) - npt.assert_array_almost_equal(region.L, tf.cast(0.8, dtype=tf.float64)) - npt.assert_array_almost_equal(region.lower, lower_bound) - npt.assert_array_almost_equal(region.upper, tf.constant([0.8, 0.2], dtype=tf.float64)) - assert region.success_counter == 0 - assert region.failure_counter == 0 + npt.assert_array_almost_equal(state_region.L, tf.cast(0.8, dtype=tf.float64)) + npt.assert_array_almost_equal(state_region.lower, lower_bound) + npt.assert_array_almost_equal(state_region.upper, tf.constant([0.8, 0.2], dtype=tf.float64)) + assert state_region.success_counter == 0 + assert state_region.failure_counter == 0 # secondly check what happens if L is too small after triggering decreasing the region region = turbo_create_region( search_space, previous_search_space, 0.5**6 - 0.1, 1, success_counter, previous_y_min ) - previous_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([region])) - tr._subspaces = (region,) + previous_state = BatchTrustRegionState[UpdatableTrustRegionBox]([region], ["0"]) + tr._init_subspaces = (region,) current_state, _ = tr.acquire( search_space, models, datasets={OBJECTIVE: dataset}, )(previous_state) - tr.filter_datasets(models, {OBJECTIVE: dataset}) + current_state, _ = tr.filter_datasets(models, {OBJECTIVE: dataset})(current_state) assert current_state is not None state_region = current_state.acquisition_space.get_subspace("0") assert isinstance(state_region, TURBOBox) - npt.assert_array_almost_equal(region.L, tf.cast(0.8, dtype=tf.float64)) - npt.assert_array_almost_equal(region.lower, lower_bound) - npt.assert_array_almost_equal(region.upper, tf.constant([0.8, 0.2], dtype=tf.float64)) - assert region.success_counter == 0 - assert region.failure_counter == 0 + npt.assert_array_almost_equal(state_region.L, tf.cast(0.8, dtype=tf.float64)) + npt.assert_array_almost_equal(state_region.lower, lower_bound) + npt.assert_array_almost_equal(state_region.upper, tf.constant([0.8, 0.2], dtype=tf.float64)) + assert state_region.success_counter == 0 + assert state_region.failure_counter == 0 def test_turbo_state_deepcopy() -> None: search_space = Box(tf.constant([1.2]), tf.constant([3.4])) subspace = turbo_create_region(search_space, search_space, 0.8, 0, 0, tf.constant(7.8)) - tr_state = BatchTrustRegionBox.State(TaggedMultiSearchSpace([subspace])) + tr_state = BatchTrustRegionState[UpdatableTrustRegionBox]([subspace], ["0"]) tr_state_copy = copy.deepcopy(tr_state) tr_subspace = tr_state.acquisition_space.get_subspace("0") tr_subspace_copy = tr_state_copy.acquisition_space.get_subspace("0") @@ -1535,19 +1526,24 @@ def test_trust_region_box_update_size(success: bool) -> None: (RandomSampling(num_query_points=2), 1), ], ) +@pytest.mark.parametrize("acquire", [True, False]) def test_multi_trust_region_box_no_subspace( rule: AcquisitionRule[TensorType, SearchSpace, ProbabilisticModel], exp_num_subspaces: int, + acquire: bool, ) -> None: """Check multi trust region works when no subspace is provided.""" search_space = Box([0.0, 0.0], [1.0, 1.0]) mtb = BatchTrustRegionBox(rule=rule) - mtb.acquire(search_space, {}) + if acquire: + mtb.acquire(search_space, {}) + else: + mtb.initialize_subspaces(search_space) assert mtb._tags is not None - assert mtb._subspaces is not None - assert len(mtb._subspaces) == exp_num_subspaces - for i, (subspace, tag) in enumerate(zip(mtb._subspaces, mtb._tags)): + assert mtb._init_subspaces is not None + assert len(mtb._init_subspaces) == exp_num_subspaces + for i, (subspace, tag) in enumerate(zip(mtb._init_subspaces, mtb._tags)): assert isinstance(subspace, SingleObjectiveTrustRegionBox) assert subspace.global_search_space == search_space assert tag == f"{i}" @@ -1558,7 +1554,7 @@ def test_multi_trust_region_box_single_subspace() -> None: search_space = Box([0.0, 0.0], [1.0, 1.0]) subspace = SingleObjectiveTrustRegionBox(search_space) mtb = BatchTrustRegionBox(subspace) # type: ignore[var-annotated] - assert mtb._subspaces == (subspace,) + assert mtb._init_subspaces == (subspace,) assert mtb._tags == ("0",) @@ -1628,9 +1624,7 @@ def test_multi_trust_region_box_raises_on_mismatched_tags() -> None: subspaces = [SingleObjectiveTrustRegionBox(search_space) for _ in range(2)] mtb = BatchTrustRegionBox(subspaces, base_rule) - state = BatchTrustRegionBox.State( - acquisition_space=TaggedMultiSearchSpace(subspaces, tags=["a", "b"]) - ) + state = BatchTrustRegionState[UpdatableTrustRegionBox](subspaces, ["a", "b"]) models = {OBJECTIVE: QuadraticMeanAndRBFKernelWithSamplers(dataset)} state_func = mtb.acquire( search_space, @@ -1700,12 +1694,13 @@ def test_multi_trust_region_box_inits_regions_that_need_it() -> None: assert bool(subspaces[2].requires_initialization) is False mtb = BatchTrustRegionBox(subspaces) # type: ignore[var-annotated] - mtb.filter_datasets({OBJECTIVE: model}, {OBJECTIVE: dataset}) + state, _ = mtb.filter_datasets({OBJECTIVE: model}, {OBJECTIVE: dataset})(None) # Check that the second region was re-initialized. - assert subspaces[0].eps < 0.35 # Expect reduction. - assert subspaces[1].eps == 0.4 # Expect re-initialized value. - assert subspaces[2].eps < 0.32 # Expect reduction. + assert state is not None + assert cast(TestTrustRegionBox, state.subspaces[0]).eps < 0.35 # Expect reduction. + assert cast(TestTrustRegionBox, state.subspaces[1]).eps == 0.4 # Expect re-initialized value. + assert cast(TestTrustRegionBox, state.subspaces[2]).eps < 0.32 # Expect reduction. def test_multi_trust_region_box_acquire_with_state() -> None: @@ -1733,7 +1728,7 @@ def test_multi_trust_region_box_acquire_with_state() -> None: TestTrustRegionBox(tf.constant([0.3, 0.3], dtype=tf.float64) + 1e-7, search_space), ] mtb = BatchTrustRegionBox(subspaces, base_rule) - state = BatchTrustRegionBox.State(acquisition_space=TaggedMultiSearchSpace(subspaces)) + state = BatchTrustRegionState[UpdatableTrustRegionBox](subspaces, ["0", "1", "2"]) for subspace in subspaces: subspace.initialize(datasets={OBJECTIVE: init_dataset}) @@ -1743,7 +1738,7 @@ def test_multi_trust_region_box_acquire_with_state() -> None: ) state_func = mtb.acquire(search_space, models, {OBJECTIVE: dataset}) next_state, points = state_func(state) - mtb.filter_datasets(models, {OBJECTIVE: dataset}) + next_state, _ = mtb.filter_datasets(models, {OBJECTIVE: dataset})(next_state) assert next_state is not None assert points.shape == [1, 3, 2] @@ -1753,7 +1748,7 @@ def test_multi_trust_region_box_acquire_with_state() -> None: # subspace. for point, subspace, exp_obs, exp_eps in zip( points[0], - subspaces, + cast(Sequence[TestTrustRegionBox], next_state.subspaces), [dataset.observations[0], dataset.observations[2], dataset.observations[0]], [0.1, 0.1, 0.07], # First two regions updated, third region initialized. ): @@ -1928,7 +1923,7 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions( num_query_points=num_query_points_per_region, ) rule = BatchTrustRegionBox(subspaces, base_rule) - _, points = rule.acquire(search_space, models, datasets)(None) + state, points = rule.acquire(search_space, models, datasets)(None) observer = mk_batch_observer(quadratic) new_data = observer(points) assert not isinstance(new_data, Dataset) @@ -1938,7 +1933,7 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions( _, dataset = get_value_for_tag(datasets, *[tag, LocalizedTag.from_tag(tag).global_tag]) assert dataset is not None updated_datasets[tag] = dataset + new_data[tag] - filtered_datasets = rule.filter_datasets(models, updated_datasets) + _, filtered_datasets = rule.filter_datasets(models, updated_datasets)(state) # Check local datasets. for i, subspace in enumerate(subspaces): @@ -1997,7 +1992,7 @@ def test_multi_trust_region_box_state_deepcopy() -> None: ] for _subspace in subspaces: _subspace.initialize(datasets={OBJECTIVE: dataset}) - state = BatchTrustRegionBox.State(acquisition_space=TaggedMultiSearchSpace(subspaces)) + state = BatchTrustRegionState[UpdatableTrustRegionBox](subspaces, ["0", "1", "2"]) state_copy = copy.deepcopy(state) assert state_copy is not state @@ -2477,23 +2472,28 @@ def test_updatable_tr_product_datasets_filter_mask_value() -> None: (RandomSampling(num_query_points=2), 1), ], ) +@pytest.mark.parametrize("acquire", [True, False]) def test_batch_trust_region_product_no_subspace( discrete_search_space: DiscreteSearchSpace, continuous_search_space: Box, rule: AcquisitionRule[TensorType, SearchSpace, ProbabilisticModel], exp_num_subspaces: int, + acquire: bool, ) -> None: """Check batch trust region creates default subspaces when none are provided at init.""" search_space = TaggedProductSearchSpace( [discrete_search_space, continuous_search_space, discrete_search_space] ) tr_rule = BatchTrustRegionProduct(rule=rule) - tr_rule.acquire(search_space, {}) + if acquire: + tr_rule.acquire(search_space, {}) + else: + tr_rule.initialize_subspaces(search_space) assert tr_rule._tags is not None - assert tr_rule._subspaces is not None - assert len(tr_rule._subspaces) == exp_num_subspaces - for i, (subspace, tag) in enumerate(zip(tr_rule._subspaces, tr_rule._tags)): + assert tr_rule._init_subspaces is not None + assert len(tr_rule._init_subspaces) == exp_num_subspaces + for i, (subspace, tag) in enumerate(zip(tr_rule._init_subspaces, tr_rule._tags)): assert isinstance(subspace, UpdatableTrustRegionProduct) assert subspace.global_search_space == search_space assert tag == f"{i}" diff --git a/tests/unit/test_ask_tell_optimization.py b/tests/unit/test_ask_tell_optimization.py index 62c4d446fc..e6a168109a 100644 --- a/tests/unit/test_ask_tell_optimization.py +++ b/tests/unit/test_ask_tell_optimization.py @@ -44,7 +44,7 @@ from trieste.models.interfaces import ProbabilisticModel, TrainableProbabilisticModel from trieste.objectives.utils import mk_batch_observer from trieste.observer import OBJECTIVE -from trieste.space import Box, SearchSpace +from trieste.space import Box from trieste.types import State, Tag, TensorType from trieste.utils.misc import LocalizedTag @@ -83,7 +83,7 @@ def acquisition_rule() -> AcquisitionRule[TensorType, Box, ProbabilisticModel]: @pytest.fixture def local_acquisition_rule() -> LocalDatasetsAcquisitionRule[TensorType, Box, ProbabilisticModel]: - return FixedLocalAcquisitionRule([[0.0]]) + return FixedLocalAcquisitionRule([[0.0]], 3) @pytest.fixture @@ -217,12 +217,11 @@ def test_ask_tell_optimizer_loads_from_state( ], ) - ask_tell = optimizer.from_record( - old_state.record, + ask_tell = optimizer.from_state( + old_state, search_space, local_acquisition_rule, track_data=False, - local_data_ixs=old_state.local_data_ixs, ) new_state: AskTellOptimizerState[None, TrainableProbabilisticModel] = ask_tell.to_state() @@ -670,19 +669,6 @@ def update(self, dataset: Dataset) -> None: self.update_count += 1 -class LocalDatasetsFixedAcquisitionRule( - FixedAcquisitionRule, - LocalDatasetsAcquisitionRule[TensorType, SearchSpace, ProbabilisticModel], -): - def __init__(self, query_points: TensorType, num_local_datasets: int) -> None: - super().__init__(query_points) - self._num_local_datasets = num_local_datasets - - @property - def num_local_datasets(self) -> int: - return self._num_local_datasets - - # Check that the correct dataset is routed to the model. # Note: this test is almost identical to the one in test_bayesian_optimizer.py. @pytest.mark.parametrize("use_global_model", [True, False]) @@ -717,7 +703,7 @@ def test_ask_tell_optimizer_creates_correct_datasets_for_rank3_points( model._tag = tag observer = mk_batch_observer(lambda x: Dataset(x, x)) - rule = LocalDatasetsFixedAcquisitionRule(query_points, batch_size) + rule = FixedLocalAcquisitionRule(query_points, batch_size) ask_tell = AskTellOptimizer(search_space, init_data, models, rule) points = ask_tell.ask() @@ -814,3 +800,58 @@ def test_ask_tell_optimizer_raises_with_badly_shaped_new_data_idxs( ) with pytest.raises(ValueError, match="new_data_ixs has 1"): ask_tell.tell(init_dataset + new_data, new_data_ixs=[tf.constant([[4]])]) + + +@pytest.mark.parametrize("optimizer", OPTIMIZERS) +def test_ask_tell_optimizer_uses_pre_filter_state_in_to_record( + search_space: Box, + init_dataset: Dataset, + model: TrainableProbabilisticModel, + local_acquisition_rule: LocalDatasetsAcquisitionRule[ + TensorType, Box, TrainableProbabilisticModel + ], + optimizer: OptimizerType, +) -> None: + ask_tell = optimizer( + search_space, init_dataset, model, local_acquisition_rule, track_data=False + ) + new_data = mk_dataset( + [[x / 100] for x in range(75, 75 + 6)], [[x / 100] for x in range(75, 75 + 6)] + ) + + # the internal acquisition state is incremented every time we call either ask or tell + # and once at initialisation; however, the state reported in to_record() is only updated + # after calling ask + assert ask_tell.to_record().acquisition_state is None + ask_tell.ask() + assert ask_tell.to_record().acquisition_state == 2 + ask_tell.tell(init_dataset + new_data) + assert ask_tell.to_record().acquisition_state == 2 + ask_tell.ask() + assert ask_tell.to_record().acquisition_state == 4 + ask_tell.tell(init_dataset + new_data + new_data) + assert ask_tell.to_record().acquisition_state == 4 + + # the pattern continues for a copy made using the reported state + ask_tell_copy = optimizer.from_record( + ask_tell.to_record(), search_space, local_acquisition_rule, track_data=False + ) + assert ask_tell_copy.to_record().acquisition_state == 4 + ask_tell_copy.ask() + assert ask_tell_copy.to_record().acquisition_state == 6 + + +@pytest.mark.parametrize("optimizer", OPTIMIZERS) +def test_ask_tell_optimizer_calls_initialize_subspaces( + search_space: Box, + init_dataset: Dataset, + model: TrainableProbabilisticModel, + local_acquisition_rule: LocalDatasetsAcquisitionRule[ + TensorType, Box, TrainableProbabilisticModel + ], + optimizer: OptimizerType, +) -> None: + assert isinstance(local_acquisition_rule, FixedLocalAcquisitionRule) + assert local_acquisition_rule._initialize_subspaces_calls == 0 + optimizer(search_space, init_dataset, model, local_acquisition_rule, track_data=False) + assert local_acquisition_rule._initialize_subspaces_calls == 1 diff --git a/tests/unit/test_bayesian_optimizer.py b/tests/unit/test_bayesian_optimizer.py index 54362625ce..0701dd311b 100644 --- a/tests/unit/test_bayesian_optimizer.py +++ b/tests/unit/test_bayesian_optimizer.py @@ -23,9 +23,10 @@ import tensorflow as tf from check_shapes import inherit_check_shapes -from tests.unit.test_ask_tell_optimization import DatasetChecker, LocalDatasetsFixedAcquisitionRule +from tests.unit.test_ask_tell_optimization import DatasetChecker from tests.util.misc import ( FixedAcquisitionRule, + FixedLocalAcquisitionRule, assert_datasets_allclose, empty_dataset, mk_dataset, @@ -281,7 +282,7 @@ def test_bayesian_optimizer_creates_correct_datasets_for_rank3_points( model._tag = tag optimizer = BayesianOptimizer(lambda x: Dataset(x, x), search_space) - rule = LocalDatasetsFixedAcquisitionRule(query_points, batch_size) + rule = FixedLocalAcquisitionRule(query_points, batch_size) optimizer.optimize(1, init_data, models, rule).final_result.unwrap() @@ -715,3 +716,29 @@ def go(state: int | None) -> tuple[int | None, TensorType]: history[step].models[NA].predict(tf.constant([[0.0]], tf.float64)) ) npt.assert_allclose(variance_from_saved_model, 1.0 / (step + 1)) + + +def test_bayesian_optimizer_uses_pre_filter_state_in_history() -> None: + rule = FixedLocalAcquisitionRule([[0.0]], 3) + result = BayesianOptimizer(_quadratic_observer, Box([0], [1])).optimize( + 5, + {NA: mk_dataset([[0.0]], [[0.0]])}, + {NA: _PseudoTrainableQuadratic()}, + rule, + ) + # the states gets updated by both filter_datasets and acquire, but it's the post-acquire + # state that's returned in the history + acquisition_states = [record.acquisition_state for record in result.history] + assert acquisition_states == [None, 2, 4, 6, 8] + + +def test_bayesian_optimizer_calls_initialize_subspaces() -> None: + rule = FixedLocalAcquisitionRule([[0.0]], 3) + assert rule._initialize_subspaces_calls == 0 + BayesianOptimizer(_quadratic_observer, Box([0], [1])).optimize( + 5, + {NA: mk_dataset([[0.0]], [[0.0]])}, + {NA: _PseudoTrainableQuadratic()}, + rule, + ) + assert rule._initialize_subspaces_calls == 1 diff --git a/tests/util/misc.py b/tests/util/misc.py index a87230b672..59b2f2327e 100644 --- a/tests/util/misc.py +++ b/tests/util/misc.py @@ -17,20 +17,31 @@ import os import random from collections.abc import Container, Mapping -from typing import Any, Callable, NoReturn, Optional, Sequence, TypeVar, Union, cast, overload +from typing import ( + Any, + Callable, + NoReturn, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, + overload, +) import numpy as np import numpy.testing as npt import tensorflow as tf from typing_extensions import Final -from trieste.acquisition.rule import AcquisitionRule, LocalDatasetsAcquisitionRule +from trieste.acquisition.rule import AcquisitionRule, LocalDatasetsAcquisitionRule, SearchSpaceType from trieste.data import Dataset from trieste.models import ProbabilisticModel from trieste.objectives import Branin, Hartmann6 from trieste.objectives.utils import mk_observer from trieste.space import SearchSpace -from trieste.types import Tag, TensorType +from trieste.types import State, Tag, TensorType from trieste.utils import shapes_equal TF_DEBUGGING_ERROR_TYPES: Final[tuple[type[Exception], ...]] = ( @@ -195,13 +206,44 @@ def acquire( class FixedLocalAcquisitionRule( - LocalDatasetsAcquisitionRule[TensorType, SearchSpace, ProbabilisticModel], FixedAcquisitionRule + FixedAcquisitionRule, + LocalDatasetsAcquisitionRule[State[Optional[int], TensorType], SearchSpace, ProbabilisticModel], ): - """A local dataset acquisition rule that returns the same fixed value on every step.""" + """A local dataset acquisition rule that returns the same fixed value on every step and + keeps track of a counter internal State.""" + + def __init__(self, query_points: TensorType, num_local_datasets: int) -> None: + super().__init__(query_points) + self._num_local_datasets = num_local_datasets + self._initialize_subspaces_calls = 0 @property def num_local_datasets(self) -> int: - return 3 + return self._num_local_datasets + + def initialize_subspaces(self, search_space: SearchSpaceType) -> None: + self._initialize_subspaces_calls += 1 + + def acquire( + self, + search_space: SearchSpace, + models: Mapping[Tag, ProbabilisticModel], + datasets: Optional[Mapping[Tag, Dataset]] = None, + ) -> TensorType: + def state_func(state: int | None) -> tuple[int | None, TensorType]: + new_state = 1 if state is None else state + 1 + return new_state, self._qp + + return state_func + + def filter_datasets( + self, models: Mapping[Tag, ProbabilisticModel], datasets: Mapping[Tag, Dataset] + ) -> Mapping[Tag, Dataset] | State[int | None, Mapping[Tag, Dataset]]: + def state_func(state: int | None) -> Tuple[int | None, Mapping[Tag, Dataset]]: + new_state = 1 if state is None else state + 1 + return new_state, datasets + + return state_func ShapeLike = Union[tf.TensorShape, Sequence[int]] diff --git a/trieste/acquisition/rule.py b/trieste/acquisition/rule.py index 0fa03b4a24..1335a8013b 100644 --- a/trieste/acquisition/rule.py +++ b/trieste/acquisition/rule.py @@ -170,14 +170,15 @@ def acquire_single( datasets=None if dataset is None else {OBJECTIVE: dataset}, ) + # AcquisitionRule should really have been generic in StateType, but that's too big a change now def filter_datasets( self, models: Mapping[Tag, ProbabilisticModelType], datasets: Mapping[Tag, Dataset] - ) -> Mapping[Tag, Dataset]: + ) -> Mapping[Tag, Dataset] | State[Any | None, Mapping[Tag, Dataset]]: """ Filter the post-acquisition datasets before they are used for model training. For example, this can be used to remove points from the post-acquisition datasets that are no longer in the search space. - Some rules may also update their internal state. + Rules that need to update their internal state should return a State callable. :param models: The model for each tag. :param datasets: The updated datasets after previous acquisition step. @@ -198,6 +199,10 @@ class LocalDatasetsAcquisitionRule( def num_local_datasets(self) -> int: """The number of local datasets required by this rule.""" + @abstractmethod + def initialize_subspaces(self, search_space: SearchSpaceType) -> None: + """Create local subspaces for when no initial subspaces are provided.""" + class EfficientGlobalOptimization( AcquisitionRule[TensorType, SearchSpaceType, ProbabilisticModelType] @@ -1240,9 +1245,31 @@ def get_datasets_filter_mask( """ A type variable bound to :class:`UpdatableTrustRegion`. """ +@dataclass(frozen=True) +class BatchTrustRegionState(Generic[UpdatableTrustRegionType]): + """The acquisition state for the :class:`BatchTrustRegion` acquisition rule.""" + + subspaces: Sequence[UpdatableTrustRegionType] + """ The acquisition space's subspaces. """ + + subspace_tags: Sequence[str] + """ The subspaces' tags. """ + + def __deepcopy__( + self, memo: dict[int, object] + ) -> BatchTrustRegionState[UpdatableTrustRegionType]: + subspaces_copy = copy.deepcopy(self.subspaces) + return BatchTrustRegionState(subspaces_copy, self.subspace_tags) + + @property + def acquisition_space(self) -> TaggedMultiSearchSpace: + """The acquisition search space.""" + return TaggedMultiSearchSpace(self.subspaces, self.subspace_tags) + + class BatchTrustRegion( LocalDatasetsAcquisitionRule[ - types.State[Optional["BatchTrustRegion.State"], TensorType], + types.State[Optional[BatchTrustRegionState[UpdatableTrustRegionType]], TensorType], SearchSpace, ProbabilisticModelType, ], @@ -1257,17 +1284,6 @@ class BatchTrustRegion( object. """ - @dataclass(frozen=True) - class State: - """The acquisition state for the :class:`BatchTrustRegion` acquisition rule.""" - - acquisition_space: TaggedMultiSearchSpace - """ The search space. """ - - def __deepcopy__(self, memo: dict[int, object]) -> BatchTrustRegion.State: - acquisition_space_copy = copy.deepcopy(self.acquisition_space, memo) - return BatchTrustRegion.State(acquisition_space_copy) - def __init__( self: "BatchTrustRegion[ProbabilisticModelType, UpdatableTrustRegionType]", init_subspaces: Union[ @@ -1287,15 +1303,15 @@ def __init__( :class:`~trieste.acquisition.EfficientGlobalOptimization` otherwise. """ # If init_subspaces are not provided, leave it to the subclasses to create them. - self._subspaces = None + self._init_subspaces = None self._tags = None if init_subspaces is not None: if not isinstance(init_subspaces, Sequence): init_subspaces = [init_subspaces] - self._subspaces = tuple(init_subspaces) - for index, subspace in enumerate(self._subspaces): + self._init_subspaces = tuple(init_subspaces) + for index, subspace in enumerate(self._init_subspaces): subspace.region_index = index # Override the index. - self._tags = tuple(str(index) for index in range(len(init_subspaces))) + self._tags = tuple(str(index) for index, _ in enumerate(self._init_subspaces)) self._rule = rule # The rules for each subspace. These are only used when we want to run the base rule @@ -1306,19 +1322,19 @@ def __init__( def __repr__(self) -> str: """""" - return f"""{self.__class__.__name__}({self._subspaces!r}, {self._rule!r})""" + return f"""{self.__class__.__name__}({self._init_subspaces!r}, {self._rule!r})""" @property def num_local_datasets(self) -> int: - assert self._subspaces is not None, "the subspaces have not been initialized" - return len(self._subspaces) + assert self._init_subspaces is not None, "the subspaces have not been initialized" + return len(self._init_subspaces) def acquire( self, search_space: SearchSpace, models: Mapping[Tag, ProbabilisticModelType], datasets: Optional[Mapping[Tag, Dataset]] = None, - ) -> types.State[State | None, TensorType]: + ) -> types.State[BatchTrustRegionState[UpdatableTrustRegionType] | None, TensorType]: """ Use the ``rule`` specified at :meth:`~BatchTrustRegion.__init__` to find new query points. Return a function that constructs these points given a previous trust region @@ -1336,14 +1352,17 @@ def acquire( points from the previous acquisition state. """ - # Subspaces should be set by the time we call `acquire`. + # initialize subspaces + self.initialize_subspaces(search_space) + + # Subspaces should be initialised by the time we call `acquire`. assert self._tags is not None - assert self._subspaces is not None + assert self._init_subspaces is not None # Implement heuristic defaults for the rule if not specified by the user. if self._rule is None: # Use first subspace to determine the type of the base rule. - if isinstance(self._subspaces[0], TURBOBox): + if isinstance(self._init_subspaces[0], TURBOBox): # Default to Thompson sampling with batches of size 1. self._rule = DiscreteThompsonSampling( tf.minimum(100 * search_space.dimension, 5_000), 1 @@ -1380,31 +1399,33 @@ def acquire( self._rules = [copy.deepcopy(self._rule) for _ in range(num_subspaces)] def state_func( - state: BatchTrustRegion.State | None, - ) -> Tuple[BatchTrustRegion.State | None, TensorType]: + state: BatchTrustRegionState[UpdatableTrustRegionType] | None, + ) -> Tuple[BatchTrustRegionState[UpdatableTrustRegionType] | None, TensorType]: # Check again to keep mypy happy. assert self._tags is not None - assert self._subspaces is not None + assert self._init_subspaces is not None assert self._rule is not None - # If state is set, the tags should be the same as the tags of the acquisition space - # in the state. + # If state is set, the tags should be the same as the tags of the initial space. if state is not None: - assert ( - self._tags == state.acquisition_space.subspace_tags - ), f"""The tags of the state acquisition space - {state.acquisition_space.subspace_tags} should be the same as the tags of the - BatchTrustRegion acquisition rule {self._tags}""" - - # Never use the subspaces from the passed in state, as we may have modified the - # subspaces in filter_datasets. - acquisition_space = TaggedMultiSearchSpace(self._subspaces, self._tags) + assert self._tags == tuple(state.subspace_tags), ( + f"The tags of the state acquisition space " + f"{state.subspace_tags} should be the same as the tags of the " + f"BatchTrustRegion acquisition rule {self._tags}" + ) + assert len(state.subspaces) == len(state.subspace_tags), ( + f"Inconsistent number of subspaces: {len(state.subspaces)} subspaces" + f"and {len(state.subspace_tags)} tags" + ) + subspaces = state.subspaces + else: + subspaces = self._init_subspaces # If the base rule is a sequence, run it sequentially for each subspace. # See earlier comments. if self._rules is not None: _points = [] - for subspace, rule in zip(self._subspaces, self._rules): + for tag, subspace, rule in zip(self._tags, subspaces, self._rules): _models = subspace.select_in_region(models) _datasets = subspace.select_in_region(datasets) assert _models is not None @@ -1434,11 +1455,11 @@ def state_func( } else: _datasets = None + acquisition_space = TaggedMultiSearchSpace(subspaces, self._tags) points = self._rule.acquire(acquisition_space, models, _datasets) - # We may modify the regions in filter_datasets later, so return a copy. - state_ = BatchTrustRegion.State(copy.deepcopy(acquisition_space)) - return state_, tf.reshape(points, [-1, len(self._subspaces), points.shape[-1]]) + state_ = BatchTrustRegionState(subspaces, self._tags) + return state_, tf.reshape(points, [-1, len(subspaces), points.shape[-1]]) return state_func @@ -1488,46 +1509,70 @@ def get_initialize_subspaces_mask( def filter_datasets( self, models: Mapping[Tag, ProbabilisticModelType], datasets: Mapping[Tag, Dataset] - ) -> Mapping[Tag, Dataset]: - # Update subspaces with the latest datasets. - assert self._subspaces is not None - for subspace in self._subspaces: - # Re-initialize or update the subspace, depending on the property. - if subspace.requires_initialization: - subspace.initialize(models, datasets) + ) -> types.State[BatchTrustRegionState[UpdatableTrustRegionType] | None, Mapping[Tag, Dataset]]: + def state_func( + state: BatchTrustRegionState[UpdatableTrustRegionType] | None, + ) -> Tuple[BatchTrustRegionState[UpdatableTrustRegionType] | None, Mapping[Tag, Dataset]]: + if state is not None: + assert self._tags == state.subspace_tags, ( + f"The tags of the state acquisition space " + f"{state.subspace_tags} should be the same as the tags of the " + f"BatchTrustRegion acquisition rule {self._tags}" + ) + assert len(state.subspaces) == len(state.subspace_tags), ( + f"Inconsistent number of subspaces: {len(state.subspaces)} subspaces" + f"and {len(state.subspace_tags)} tags" + ) + subspaces = tuple(state.subspaces) else: - subspace.update(models, datasets) - self.maybe_initialize_subspaces(self._subspaces, models, datasets) - - # Filter out points that are not in any of the subspaces. This is done by creating a mask - # for each local dataset that is True for points that are in any subspace. - used_masks = { - tag: tf.zeros(dataset.query_points.shape[:-1], dtype=tf.bool) - for tag, dataset in datasets.items() - if LocalizedTag.from_tag(tag).is_local - } - - for subspace in self._subspaces: - in_region_masks = subspace.get_datasets_filter_mask(datasets) - if in_region_masks is not None: - for tag, in_region in in_region_masks.items(): - ltag = LocalizedTag.from_tag(tag) - assert ltag.is_local, f"can only filter local tags, got {tag}" - used_masks[tag] = tf.logical_or(used_masks[tag], in_region) - - filtered_datasets = {} - for tag, used_mask in used_masks.items(): - filtered_datasets[tag] = Dataset( - tf.boolean_mask(datasets[tag].query_points, used_mask), - tf.boolean_mask(datasets[tag].observations, used_mask), - ) + assert self._init_subspaces is not None, "the subspaces have not been initialized" + assert self._tags is not None + subspaces = self._init_subspaces + + # make a deepcopy to avoid modifying any user copies + subspaces = copy.deepcopy(subspaces) + + # Update subspaces with the latest datasets. + for subspace in subspaces: + # Re-initialize or update the subspace, depending on the property. + if subspace.requires_initialization: + subspace.initialize(models, datasets) + else: + subspace.update(models, datasets) + self.maybe_initialize_subspaces(subspaces, models, datasets) + + # Filter out points that are not in any of the subspaces. This is done by creating a + # mask for each local dataset that is True for points that are in any subspace. + used_masks = { + tag: tf.zeros(dataset.query_points.shape[:-1], dtype=tf.bool) + for tag, dataset in datasets.items() + if LocalizedTag.from_tag(tag).is_local + } + + for subspace in subspaces: + in_region_masks = subspace.get_datasets_filter_mask(datasets) + if in_region_masks is not None: + for tag, in_region in in_region_masks.items(): + ltag = LocalizedTag.from_tag(tag) + assert ltag.is_local, f"can only filter local tags, got {tag}" + used_masks[tag] = tf.logical_or(used_masks[tag], in_region) + + filtered_datasets = {} + for tag, used_mask in used_masks.items(): + filtered_datasets[tag] = Dataset( + tf.boolean_mask(datasets[tag].query_points, used_mask), + tf.boolean_mask(datasets[tag].observations, used_mask), + ) - # Include global datasets unmodified. - for tag, dataset in datasets.items(): - if not LocalizedTag.from_tag(tag).is_local: - filtered_datasets[tag] = dataset + # Include global datasets unmodified. + for tag, dataset in datasets.items(): + if not LocalizedTag.from_tag(tag).is_local: + filtered_datasets[tag] = dataset - return filtered_datasets + state_ = BatchTrustRegionState(subspaces, self._tags) + return state_, filtered_datasets + + return state_func class HypercubeTrustRegion(UpdatableTrustRegion): @@ -1819,18 +1864,13 @@ class BatchTrustRegionBox(BatchTrustRegion[ProbabilisticModelType, UpdatableTrus This is intended to be used for single-objective optimization with batching. """ - def acquire( - self, - search_space: SearchSpace, - models: Mapping[Tag, ProbabilisticModelType], - datasets: Optional[Mapping[Tag, Dataset]] = None, - ) -> types.State[BatchTrustRegion.State | None, TensorType]: - if self._subspaces is None: - # If no initial subspaces were provided, create N default subspaces, where N is the - # number of query points in the base-rule. - # Currently the detection for N is only implemented for EGO. - # Note: the reason we don't create the default subspaces in `__init__` is because we - # don't have the global search space at that point. + def initialize_subspaces(self, search_space: SearchSpace) -> None: + # If no initial subspaces were provided, create N default subspaces, where N is the + # number of query points in the base-rule. + # Currently the detection for N is only implemented for EGO. + # Note: the reason we don't create the default subspaces in `__init__` is because we + # don't have the global search space at that point. + if self._init_subspaces is None: if isinstance(self._rule, EfficientGlobalOptimization): num_query_points = self._rule._num_query_points else: @@ -1842,20 +1882,27 @@ def acquire( init_subspaces: Tuple[UpdatableTrustRegionBox, ...] = tuple( SingleObjectiveTrustRegionBox(search_space) for _ in range(num_query_points) ) - self._subspaces = init_subspaces - for index, subspace in enumerate(self._subspaces): + self._init_subspaces = init_subspaces + for index, subspace in enumerate(self._init_subspaces): subspace.region_index = index # Override the index. self._tags = tuple(str(index) for index in range(self.num_local_datasets)) + def acquire( + self, + search_space: SearchSpace, + models: Mapping[Tag, ProbabilisticModelType], + datasets: Optional[Mapping[Tag, Dataset]] = None, + ) -> types.State[BatchTrustRegionState[UpdatableTrustRegionBox] | None, TensorType]: # Ensure passed in global search space is always the same as the search space passed to # the subspaces. - for subspace in self._subspaces: - assert subspace.global_search_space == search_space, ( - "The global search space of the subspaces should be the same as the " - "search space passed to the BatchTrustRegionBox acquisition rule. " - "If you want to change the global search space, you should recreate the rule. " - "Note: all subspaces should be initialized with the same global search space." - ) + if self._init_subspaces is not None: + for subspace in self._init_subspaces: + assert subspace.global_search_space == search_space, ( + "The global search space of the subspaces should be the same as the " + "search space passed to the BatchTrustRegionBox acquisition rule. " + "If you want to change the global search space, you should recreate the rule. " + "Note: all subspaces should be initialized with the same global search space." + ) return super().acquire(search_space, models, datasets) @@ -2497,13 +2544,8 @@ class BatchTrustRegionProduct( spaces. This is intended to be used for single-objective optimization with batching. """ - def acquire( - self, - search_space: SearchSpace, - models: Mapping[Tag, ProbabilisticModelType], - datasets: Optional[Mapping[Tag, Dataset]] = None, - ) -> types.State[BatchTrustRegion.State | None, TensorType]: - if self._subspaces is None: + def initialize_subspaces(self, search_space: SearchSpaceType) -> None: + if self._init_subspaces is None: # If no initial subspaces were provided, create N default subspaces, where N is the # number of query points in the base-rule. # Currently the detection for N is only implemented for EGO. @@ -2536,20 +2578,27 @@ def create_subregions() -> Sequence[UpdatableTrustRegion]: init_subspaces: Tuple[UpdatableTrustRegionProduct, ...] = tuple( UpdatableTrustRegionProduct(create_subregions()) for _ in range(num_query_points) ) - self._subspaces = init_subspaces - for index, subspace in enumerate(self._subspaces): + self._init_subspaces = init_subspaces + for index, subspace in enumerate(self._init_subspaces): subspace.region_index = index # Override the index. self._tags = tuple(str(index) for index in range(self.num_local_datasets)) + def acquire( + self, + search_space: SearchSpace, + models: Mapping[Tag, ProbabilisticModelType], + datasets: Optional[Mapping[Tag, Dataset]] = None, + ) -> types.State[BatchTrustRegionState[UpdatableTrustRegionProduct] | None, TensorType]: # Ensure passed in global search space is always the same as the search space passed to # the subspaces. - for subspace in self._subspaces: - assert subspace.global_search_space == search_space, ( - "The global search space of the subspaces should be the same as the " - "search space passed to the BatchTrustRegionProduct acquisition rule. " - "If you want to change the global search space, you should recreate the rule. " - "Note: all subspaces should be initialized with the same global search space." - ) + if self._init_subspaces is not None: + for subspace in self._init_subspaces: + assert subspace.global_search_space == search_space, ( + "The global search space of the subspaces should be the same as the " + "search space passed to the BatchTrustRegionProduct acquisition rule. " + "If you want to change the global search space, you should recreate the rule. " + "Note: all subspaces should be initialized with the same global search space." + ) return super().acquire(search_space, models, datasets) diff --git a/trieste/ask_tell_optimization.py b/trieste/ask_tell_optimization.py index 9353622a13..3c709d93fa 100644 --- a/trieste/ask_tell_optimization.py +++ b/trieste/ask_tell_optimization.py @@ -225,7 +225,7 @@ def __init__( - default acquisition is used but incompatible with other inputs """ self._search_space = search_space - self._acquisition_state = acquisition_state + self._acquisition_record = self._acquisition_state = acquisition_state if not datasets or not models: raise ValueError("dicts of datasets and models must be populated.") @@ -288,7 +288,17 @@ def __init__( datasets = with_local_datasets( self._datasets, num_local_datasets, self._dataset_ixs ) - self._filtered_datasets = self._acquisition_rule.filter_datasets(self._models, datasets) + self._acquisition_rule.initialize_subspaces(search_space) + + filtered_datasets: Mapping[Tag, Dataset] | State[ + StateType | None, Mapping[Tag, Dataset] + ] = self._acquisition_rule.filter_datasets(self._models, datasets) + if callable(filtered_datasets): + self._acquisition_state, self._filtered_datasets = filtered_datasets( + self._acquisition_state + ) + else: + self._filtered_datasets = filtered_datasets if fit_model: with Timer() as initial_model_fitting_timer: @@ -420,6 +430,8 @@ def from_record( on each optimization step. Defaults to :class:`~trieste.acquisition.rule.EfficientGlobalOptimization` with default arguments. + :param track_data: Whether the optimizer tracks the changing datasets via a local copy. + :param local_data_ixs: Indices to local data for local rules with `track_data` False. :return: New instance of :class:`~AskTellOptimizer`. """ # we are recovering previously saved optimization state @@ -451,7 +463,10 @@ def to_record(self, copy: bool = True) -> Record[StateType, ProbabilisticModelTy try: datasets_copy = deepcopy(self._datasets) if copy else self._datasets models_copy = deepcopy(self._models) if copy else self._models - state_copy = deepcopy(self._acquisition_state) if copy else self._acquisition_state + # use the state as it was at acquisition time, not the one modified in + # filter_datasets in preparation for the next acquisition, so we can reinitialise + # the AskTellOptimizer using the record + state_copy = deepcopy(self._acquisition_record) if copy else self._acquisition_record except Exception as e: raise NotImplementedError( "Failed to copy the optimization state. Some models do not support " @@ -475,6 +490,39 @@ def to_result(self, copy: bool = True) -> OptimizationResult[StateType, Probabil record: Record[StateType, ProbabilisticModelType] = self.to_record(copy=copy) return OptimizationResult(Ok(record), []) + @classmethod + def from_state( + cls: Type[AskTellOptimizerType], + state: AskTellOptimizerState[StateType, ProbabilisticModelType], + search_space: SearchSpaceType, + acquisition_rule: AcquisitionRule[ + TensorType | State[StateType | None, TensorType], + SearchSpaceType, + ProbabilisticModelType, + ] + | None = None, + track_data: bool = True, + ) -> AskTellOptimizerType: + """Creates new :class:`~AskTellOptimizer` instance from provided AskTellOptimizer state. + Model training isn't triggered upon creation of the instance. + + :param state: AskTellOptimizer state. + :param search_space: The space over which to search for the next query point. + :param acquisition_rule: The acquisition rule, which defines how to search for a new point + on each optimization step. Defaults to + :class:`~trieste.acquisition.rule.EfficientGlobalOptimization` with default + arguments. + :param track_data: Whether the optimizer tracks the changing datasets via a local copy. + :return: New instance of :class:`~AskTellOptimizer`. + """ + return cls.from_record( # type: ignore + state.record, + search_space, + acquisition_rule, + track_data=track_data, + local_data_ixs=state.local_data_ixs, + ) + def to_state( self, copy: bool = False ) -> AskTellOptimizerState[StateType, ProbabilisticModelType]: @@ -511,6 +559,8 @@ def ask(self) -> TensorType: if callable(points_or_stateful): self._acquisition_state, query_points = points_or_stateful(self._acquisition_state) + # also keep a copy of the state to return in to_record + self._acquisition_record = self._acquisition_state else: query_points = points_or_stateful @@ -597,7 +647,13 @@ def tell( datasets = with_local_datasets(new_data, num_local_datasets, self._dataset_ixs) self._dataset_len = self.dataset_len(datasets) - self._filtered_datasets = self._acquisition_rule.filter_datasets(self._models, datasets) + filtered_datasets = self._acquisition_rule.filter_datasets(self._models, datasets) + if callable(filtered_datasets): + self._acquisition_state, self._filtered_datasets = filtered_datasets( + self._acquisition_state + ) + else: + self._filtered_datasets = filtered_datasets with Timer() as model_fitting_timer: for tag, model in self._models.items(): diff --git a/trieste/bayesian_optimizer.py b/trieste/bayesian_optimizer.py index 72fb5efac4..1f06d9c458 100644 --- a/trieste/bayesian_optimizer.py +++ b/trieste/bayesian_optimizer.py @@ -709,6 +709,7 @@ def optimize( FrozenRecord[StateType, TrainableProbabilisticModelType] | Record[StateType, TrainableProbabilisticModelType] ] = [] + history_state = acquisition_state query_plot_dfs: dict[int, pd.DataFrame] = {} observation_plot_dfs = observation_plot_init(datasets) @@ -734,15 +735,18 @@ def optimize( try: if track_state: try: + # note that we use the state as it was at acquisition time, not the one + # modified in filter_datasets in preparation for the next acquisition, + # so we can restart the optimization correctly (and also for plotting) if track_path is None: datasets_copy = copy.deepcopy(datasets) models_copy = copy.deepcopy(models) - acquisition_state_copy = copy.deepcopy(acquisition_state) - record = Record(datasets_copy, models_copy, acquisition_state_copy) + history_state_copy = copy.deepcopy(history_state) + record = Record(datasets_copy, models_copy, history_state_copy) history.append(record) else: track_path = Path(track_path) - record = Record(datasets, models, acquisition_state) + record = Record(datasets, models, history_state) file_name = OptimizationResult.step_filename(step, num_steps) history.append(record.save(track_path / file_name)) except Exception as e: @@ -763,7 +767,17 @@ def optimize( datasets = with_local_datasets( datasets, acquisition_rule.num_local_datasets ) - filtered_datasets = acquisition_rule.filter_datasets(models, datasets) + acquisition_rule.initialize_subspaces(self._search_space) + + filtered_datasets_or_callable: Mapping[Tag, Dataset] | State[ + StateType | None, Mapping[Tag, Dataset] + ] = acquisition_rule.filter_datasets(models, datasets) + if callable(filtered_datasets_or_callable): + acquisition_state, filtered_datasets = filtered_datasets_or_callable( + acquisition_state + ) + else: + filtered_datasets = filtered_datasets_or_callable if fit_model and fit_initial_model: with Timer() as initial_model_fitting_timer: @@ -789,6 +803,7 @@ def optimize( ) if callable(points_or_stateful): acquisition_state, query_points = points_or_stateful(acquisition_state) + history_state = acquisition_state else: query_points = points_or_stateful @@ -806,7 +821,16 @@ def optimize( for tag, new_dataset in tagged_output.items(): datasets[tag] += new_dataset - filtered_datasets = acquisition_rule.filter_datasets(models, datasets) + + filtered_datasets_or_callable = acquisition_rule.filter_datasets( + models, datasets + ) + if callable(filtered_datasets_or_callable): + acquisition_state, filtered_datasets = filtered_datasets_or_callable( + acquisition_state + ) + else: + filtered_datasets = filtered_datasets_or_callable with Timer() as model_fitting_timer: if fit_model: @@ -860,7 +884,7 @@ def optimize( tf.print("Optimization completed without errors", output_stream=absl.logging.INFO) - record = Record(datasets, models, acquisition_state) + record = Record(datasets, models, history_state) result = OptimizationResult(Ok(record), history) if track_state and track_path is not None: result.save_result(Path(track_path) / OptimizationResult.RESULTS_FILENAME)