diff --git a/tests/unit/acquisition/test_rule.py b/tests/unit/acquisition/test_rule.py index 0e2b575538..7c8095334c 100644 --- a/tests/unit/acquisition/test_rule.py +++ b/tests/unit/acquisition/test_rule.py @@ -1274,6 +1274,41 @@ def test_trust_region_box_update_size(success: bool) -> None: npt.assert_allclose(trb.upper, np.minimum(trb.location + trb.eps, search_space.upper)) +# Check multi trust region works when no subspace is provided. +@pytest.mark.parametrize( + "rule, exp_num_subspaces", + [ + (EfficientGlobalOptimization(), 1), + (EfficientGlobalOptimization(ParallelContinuousThompsonSampling(), num_query_points=2), 2), + (RandomSampling(num_query_points=2), 1), + ], +) +def test_multi_trust_region_box_no_subspace( + rule: AcquisitionRule[TensorType, SearchSpace, ProbabilisticModel], + exp_num_subspaces: int, +) -> None: + search_space = Box([0.0, 0.0], [1.0, 1.0]) + mtb = BatchTrustRegionBox(rule=rule) + mtb.acquire(search_space, {}) + + assert mtb._tags is not None + assert mtb._init_subspaces is not None + assert len(mtb._init_subspaces) == exp_num_subspaces + for i, (subspace, tag) in enumerate(zip(mtb._init_subspaces, mtb._tags)): + assert isinstance(subspace, SingleObjectiveTrustRegionBox) + assert subspace.global_search_space == search_space + assert tag == f"{i}" + + +# Check multi trust region works when a single subspace is provided. +def test_multi_trust_region_box_single_subspace() -> None: + search_space = Box([0.0, 0.0], [1.0, 1.0]) + subspace = SingleObjectiveTrustRegionBox(search_space) + mtb = BatchTrustRegionBox(subspace) # type: ignore[var-annotated] + assert mtb._init_subspaces == (subspace,) + assert mtb._tags == ("0",) + + # When state is None, acquire returns a multi search space of the correct type. def test_multi_trust_region_box_acquire_no_state() -> None: search_space = Box([0.0, 0.0], [1.0, 1.0]) diff --git a/trieste/acquisition/rule.py b/trieste/acquisition/rule.py index e11977667b..3d1303f156 100644 --- a/trieste/acquisition/rule.py +++ b/trieste/acquisition/rule.py @@ -1166,19 +1166,30 @@ def __deepcopy__(self, memo: dict[int, object]) -> BatchTrustRegion.State: def __init__( self: "BatchTrustRegion[ProbabilisticModelType, UpdatableTrustRegionType]", - init_subspaces: Sequence[UpdatableTrustRegionType], + init_subspaces: Union[ + None, UpdatableTrustRegionType, Sequence[UpdatableTrustRegionType] + ] = None, rule: AcquisitionRule[TensorType, SearchSpace, ProbabilisticModelType] | None = None, ): """ - :param init_subspaces: The initial search spaces for each trust region. + :param init_subspaces: The initial search spaces for each trust region. If `None`, default + subspaces of type :class:`UpdatableTrustRegionType` will be created, with length + equal to the number of query points in the base `rule`. :param rule: The acquisition rule that defines how to search for a new query point in each subspace. Defaults to :class:`EfficientGlobalOptimization` with default arguments. """ if rule is None: rule = EfficientGlobalOptimization() - self._init_subspaces = tuple(init_subspaces) - self._tags = tuple([str(index) for index in range(len(init_subspaces))]) + # If init_subspaces are not provided, leave it to the subclasses to create them. + self._init_subspaces = None + self._tags = None + if init_subspaces is not None: + if not isinstance(init_subspaces, Sequence): + init_subspaces = [init_subspaces] + self._init_subspaces = tuple(init_subspaces) + self._tags = tuple([str(index) for index in range(len(init_subspaces))]) + self._rule = rule def __repr__(self) -> str: @@ -1202,6 +1213,11 @@ def state_func( Use the rule to acquire points from the acquisition space. """ + + # Subspaces should be set by the time we call `acquire`. + assert self._tags is not None + assert self._init_subspaces is not None + # If state is set, the tags should be the same as the tags of the acquisition space # in the state. if state is not None: @@ -1399,6 +1415,28 @@ class BatchTrustRegionBox(BatchTrustRegion[ProbabilisticModelType, SingleObjecti This is intended to be used for single-objective optimization with batching. """ + def acquire( + self, + search_space: SearchSpace, + models: Mapping[Tag, ProbabilisticModelType], + datasets: Optional[Mapping[Tag, Dataset]] = None, + ) -> types.State[BatchTrustRegion.State | None, TensorType]: + if self._init_subspaces is None: + # If no initial subspaces were provided, create N default subspaces, where N is the + # number of query points in the base-rule. + # Currently the detection for N is only implemented for EGO. + if isinstance(self._rule, EfficientGlobalOptimization): + num_query_points = self._rule._num_query_points + else: + num_query_points = 1 + + self._init_subspaces = tuple( + [SingleObjectiveTrustRegionBox(search_space) for _ in range(num_query_points)] + ) + self._tags = tuple([str(index) for index in range(len(self._init_subspaces))]) + + return super().acquire(search_space, models, datasets) + @inherit_check_shapes def get_initialize_subspaces_mask( self,