diff --git a/trieste/acquisition/rule.py b/trieste/acquisition/rule.py index 3c4376d727..dc172ab2a6 100644 --- a/trieste/acquisition/rule.py +++ b/trieste/acquisition/rule.py @@ -1010,8 +1010,8 @@ def update( def _get_tags(self, tags: Set[Tag]) -> Tuple[Set[Tag], Set[Tag]]: # Separate tags into local (matching index) and global tags (without matching # local tag). - local_gtags = set() - global_tags = set() + local_gtags = set() # Set of global part of all local tags. + global_tags = set() # Set of all global tags. for tag in tags: ltag = LocalizedTag.from_tag(tag) if not ltag.is_local: diff --git a/trieste/objectives/utils.py b/trieste/objectives/utils.py index d765feb13f..b074738ed1 100644 --- a/trieste/objectives/utils.py +++ b/trieste/objectives/utils.py @@ -22,13 +22,12 @@ from collections.abc import Callable from typing import Mapping, Optional, Union, overload -import tensorflow as tf from check_shapes import check_shapes from ..data import Dataset from ..observer import OBJECTIVE, MultiObserver, Observer, SingleObserver from ..types import Tag, TensorType -from ..utils.misc import LocalizedTag +from ..utils.misc import LocalizedTag, flatten_leading_dims @overload @@ -83,7 +82,7 @@ def _observer(qps: TensorType) -> Mapping[Tag, Dataset]: # Call objective with rank 2 query points by flattening batch dimension. # Some objectives might only expect rank 2 query points, so this is safer. batch_size = qps.shape[1] - flat_qps = tf.reshape(qps, [-1, qps.shape[-1]]) + flat_qps, unflatten = flatten_leading_dims(qps) obs_or_dataset = objective_or_observer(flat_qps) if not isinstance(obs_or_dataset, (Mapping, Dataset)): @@ -98,8 +97,8 @@ def _observer(qps: TensorType) -> Mapping[Tag, Dataset]: for key, dataset in obs_or_dataset.items(): # Include overall dataset and per batch dataset. flat_obs = dataset.observations - qps = tf.reshape(flat_qps, [-1, batch_size, flat_qps.shape[-1]]) - obs = tf.reshape(flat_obs, [-1, batch_size, flat_obs.shape[-1]]) + qps = unflatten(flat_qps) + obs = unflatten(flat_obs) datasets[key] = dataset for i in range(batch_size): datasets[LocalizedTag(key, i)] = Dataset(qps[:, i], obs[:, i])