Skip to content

Commit

Permalink
Support tf.Variables in dataset_len method (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
uri-granta authored Jun 4, 2024
1 parent 9da6a8a commit 7c18824
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
31 changes: 31 additions & 0 deletions tests/unit/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,3 +925,34 @@ def test_ask_tell_optimizer_calls_initialize_subspaces(
assert local_acquisition_rule._initialize_subspaces_calls == 0
optimizer(search_space, init_dataset, model, local_acquisition_rule, track_data=False)
assert local_acquisition_rule._initialize_subspaces_calls == 1


@pytest.mark.parametrize("variable", [False, True])
def test_ask_tell_optimizer_dataset_len_variables(
init_dataset: Dataset,
variable: bool,
) -> None:
if variable:
dataset = Dataset(
tf.Variable(
init_dataset.query_points, shape=[None, *init_dataset.query_points.shape[1:]]
),
tf.Variable(
init_dataset.observations, shape=[None, *init_dataset.observations.shape[1:]]
),
)
else:
dataset = init_dataset

assert AskTellOptimizer.dataset_len({"tag": dataset}) == 2


def test_ask_tell_optimizer_dataset_len_raises_on_inconsistently_sized_datasets(
init_dataset: Dataset,
) -> None:
with pytest.raises(ValueError):
AskTellOptimizer.dataset_len(
{"tag": init_dataset, "empty": Dataset(tf.zeros([0, 2]), tf.zeros([0, 2]))}
)
with pytest.raises(ValueError):
AskTellOptimizer.dataset_len({})
14 changes: 8 additions & 6 deletions trieste/ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,16 @@ def acquisition_state(self) -> StateType | None:
@classmethod
def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int:
"""Helper method for inferring the global dataset size."""
dataset_lens = {
len(dataset.query_points)
dataset_lens = [
tf.shape(dataset.query_points)[0]
for tag, dataset in datasets.items()
if not LocalizedTag.from_tag(tag).is_local
}
if len(dataset_lens) != 1:
raise ValueError(f"Expected unique global dataset size, got {dataset_lens}")
return next(iter(dataset_lens))
]
unique_lens, unique_idxs = tf.unique(dataset_lens)
if len(unique_idxs) == 1:
return int(unique_lens[0])
else:
raise ValueError(f"Expected unique global dataset size, got {unique_lens}")

@classmethod
def from_record(
Expand Down

0 comments on commit 7c18824

Please sign in to comment.