From beedf930d95e9a7230f28553837ea0863a702de7 Mon Sep 17 00:00:00 2001 From: uri-granta <50578464+uri-granta@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:06:56 +0000 Subject: [PATCH] Use different seeds for subspace sampling (#885) Co-authored-by: Uri Granta --- tests/unit/test_space.py | 12 ++++++++++++ trieste/space.py | 6 +++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index 47e4d8759..55679d0cf 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -973,6 +973,18 @@ def test_collection_space_sampling_raises_for_invalid_sample_size( collection_space.sample(num_samples) +@pytest.mark.parametrize("search_space_type", [TaggedMultiSearchSpace, TaggedProductSearchSpace]) +def test_collection_space_sampling_uses_different_seeds_for_subspaces( + search_space_type: Type[CollectionSearchSpace], +) -> None: + box = Box([0], [1]) + collection_space = search_space_type(spaces=[box, box]) + samples = collection_space.sample(5, seed=42) + # check that all the points are unique despite the seed + flattened = samples.numpy().flatten() + assert len(flattened) == len(set(flattened)) + + @pytest.mark.parametrize( "search_space_type, space_A", [ diff --git a/trieste/space.py b/trieste/space.py index 0de9214b5..1ff8a09e2 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -1229,7 +1229,11 @@ def subspace_sample(self, num_samples: int, seed: Optional[int] = None) -> Seque tf.debugging.assert_non_negative(num_samples) if seed is not None: # ensure reproducibility tf.random.set_seed(seed) - return [self._spaces[tag].sample(num_samples, seed=seed) for tag in self._tags] + return [ + # ensure subspaces (which may be identical) don't all use the same seed + self._spaces[tag].sample(num_samples, seed=None if seed is None else seed + i) + for i, tag in enumerate(self._tags) + ] def __eq__(self, other: object) -> bool: """