From 0f1cb27467db97ada9707389906e2ca92d630ca2 Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Tue, 28 May 2024 16:31:54 +0100 Subject: [PATCH] Use correct dtype for empty box samples --- tests/unit/test_space.py | 14 ++++++++++++++ trieste/space.py | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index 638093382f..3800e6436e 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -1577,3 +1577,17 @@ def test_nonlinear_constraints_multioutput_raises() -> None: ) with pytest.raises(TF_DEBUGGING_ERROR_TYPES): nlc.residual(points) + + +@pytest.mark.parametrize("dtype", [tf.float32, tf.float64]) +def test_box_empty_sobol_sampling_returns_correct_dtype(dtype: tf.DType) -> None: + box = Box(tf.zeros((3,), dtype=dtype), tf.ones((3,), dtype=dtype)) + sobol_samples = box.sample_sobol(0) + assert sobol_samples.dtype == dtype + + +@pytest.mark.parametrize("dtype", [tf.float32, tf.float64]) +def test_box_empty_halton_sampling_returns_correct_dtype(dtype: tf.DType) -> None: + box = Box(tf.zeros((3,), dtype=dtype), tf.ones((3,), dtype=dtype)) + sobol_samples = box.sample_halton(0) + assert sobol_samples.dtype == dtype diff --git a/trieste/space.py b/trieste/space.py index 3c750eb583..1c361dc6ef 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -627,7 +627,7 @@ def _sample_halton( # Internal common method to sample from the space using a Halton sequence. tf.debugging.assert_non_negative(num_samples) if num_samples == 0: - return tf.constant([]) + return tf.constant([], dtype=self._lower.dtype) if seed is not None: # ensure reproducibility tf.random.set_seed(seed) dim = tf.shape(self._lower)[-1] @@ -660,7 +660,7 @@ def sample_sobol(self, num_samples: int, skip: Optional[int] = None) -> TensorTy """ tf.debugging.assert_non_negative(num_samples) if num_samples == 0: - return tf.constant([]) + return tf.constant([], dtype=self._lower.dtype) if skip is None: # generate random skip skip = tf.random.uniform([1], maxval=2**16, dtype=tf.int32)[0] dim = tf.shape(self._lower)[-1]