Skip to content

Commit

Permalink
Allow TRs to be created with empty/single region
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani committed Sep 12, 2023
1 parent faee54c commit 067e466
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
35 changes: 35 additions & 0 deletions tests/unit/acquisition/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,41 @@ def test_trust_region_box_update_size(success: bool) -> None:
npt.assert_allclose(trb.upper, np.minimum(trb.location + trb.eps, search_space.upper))


# Check multi trust region works when no subspace is provided.
@pytest.mark.parametrize(
"rule, exp_num_subspaces",
[
(EfficientGlobalOptimization(), 1),
(EfficientGlobalOptimization(ParallelContinuousThompsonSampling(), num_query_points=2), 2),
(RandomSampling(num_query_points=2), 1),
],
)
def test_multi_trust_region_box_no_subspace(
rule: AcquisitionRule[TensorType, SearchSpace, ProbabilisticModel],
exp_num_subspaces: int,
) -> None:
search_space = Box([0.0, 0.0], [1.0, 1.0])
mtb = BatchTrustRegionBox(rule=rule)
mtb.acquire(search_space, {})

assert mtb._tags is not None
assert mtb._init_subspaces is not None
assert len(mtb._init_subspaces) == exp_num_subspaces
for i, (subspace, tag) in enumerate(zip(mtb._init_subspaces, mtb._tags)):
assert isinstance(subspace, SingleObjectiveTrustRegionBox)
assert subspace.global_search_space == search_space
assert tag == f"{i}"


# Check multi trust region works when a single subspace is provided.
def test_multi_trust_region_box_single_subspace() -> None:
search_space = Box([0.0, 0.0], [1.0, 1.0])
subspace = SingleObjectiveTrustRegionBox(search_space)
mtb = BatchTrustRegionBox(subspace) # type: ignore[var-annotated]
assert mtb._init_subspaces == (subspace,)
assert mtb._tags == ("0",)


# When state is None, acquire returns a multi search space of the correct type.
def test_multi_trust_region_box_acquire_no_state() -> None:
search_space = Box([0.0, 0.0], [1.0, 1.0])
Expand Down
46 changes: 42 additions & 4 deletions trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,19 +1166,30 @@ def __deepcopy__(self, memo: dict[int, object]) -> BatchTrustRegion.State:

def __init__(
self: "BatchTrustRegion[ProbabilisticModelType, UpdatableTrustRegionType]",
init_subspaces: Sequence[UpdatableTrustRegionType],
init_subspaces: Union[
None, UpdatableTrustRegionType, Sequence[UpdatableTrustRegionType]
] = None,
rule: AcquisitionRule[TensorType, SearchSpace, ProbabilisticModelType] | None = None,
):
"""
:param init_subspaces: The initial search spaces for each trust region.
:param init_subspaces: The initial search spaces for each trust region. If `None`, default
subspaces of type :class:`UpdatableTrustRegionType` will be created, with length
equal to the number of query points in the base `rule`.
:param rule: The acquisition rule that defines how to search for a new query point in each
subspace. Defaults to :class:`EfficientGlobalOptimization` with default arguments.
"""
if rule is None:
rule = EfficientGlobalOptimization()

self._init_subspaces = tuple(init_subspaces)
self._tags = tuple([str(index) for index in range(len(init_subspaces))])
# If init_subspaces are not provided, leave it to the subclasses to create them.
self._init_subspaces = None
self._tags = None
if init_subspaces is not None:
if not isinstance(init_subspaces, Sequence):
init_subspaces = [init_subspaces]
self._init_subspaces = tuple(init_subspaces)
self._tags = tuple([str(index) for index in range(len(init_subspaces))])

self._rule = rule

def __repr__(self) -> str:
Expand All @@ -1202,6 +1213,11 @@ def state_func(
Use the rule to acquire points from the acquisition space.
"""

# Subspaces should be set by the time we call `acquire`.
assert self._tags is not None
assert self._init_subspaces is not None

# If state is set, the tags should be the same as the tags of the acquisition space
# in the state.
if state is not None:
Expand Down Expand Up @@ -1399,6 +1415,28 @@ class BatchTrustRegionBox(BatchTrustRegion[ProbabilisticModelType, SingleObjecti
This is intended to be used for single-objective optimization with batching.
"""

def acquire(
self,
search_space: SearchSpace,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
) -> types.State[BatchTrustRegion.State | None, TensorType]:
if self._init_subspaces is None:
# If no initial subspaces were provided, create N default subspaces, where N is the
# number of query points in the base-rule.
# Currently the detection for N is only implemented for EGO.
if isinstance(self._rule, EfficientGlobalOptimization):
num_query_points = self._rule._num_query_points
else:
num_query_points = 1

self._init_subspaces = tuple(
[SingleObjectiveTrustRegionBox(search_space) for _ in range(num_query_points)]
)
self._tags = tuple([str(index) for index in range(len(self._init_subspaces))])

return super().acquire(search_space, models, datasets)

@inherit_check_shapes
def get_initialize_subspaces_mask(
self,
Expand Down

0 comments on commit 067e466

Please sign in to comment.