diff --git a/tests/integration/test_mixed_space_bayesian_optimization.py b/tests/integration/test_mixed_space_bayesian_optimization.py index 616eb16508..84cbe3bf29 100644 --- a/tests/integration/test_mixed_space_bayesian_optimization.py +++ b/tests/integration/test_mixed_space_bayesian_optimization.py @@ -39,10 +39,17 @@ from trieste.bayesian_optimizer import BayesianOptimizer from trieste.models import TrainableProbabilisticModel from trieste.models.gpflow import GaussianProcessRegression, build_gpr -from trieste.objectives import ScaledBranin +from trieste.objectives import ScaledBranin, SingleObjectiveTestProblem +from trieste.objectives.single_objectives import scaled_branin from trieste.objectives.utils import mk_observer from trieste.observer import OBJECTIVE -from trieste.space import Box, DiscreteSearchSpace, TaggedProductSearchSpace +from trieste.space import ( + Box, + CategoricalSearchSpace, + DiscreteSearchSpace, + TaggedProductSearchSpace, + one_hot_encoder, +) from trieste.types import TensorType @@ -190,3 +197,85 @@ def test_optimizer_finds_minima_of_the_scaled_branin_function( acquisition_function = acquisition_rule._acquisition_function if isinstance(acquisition_function, AcquisitionFunctionClass): assert acquisition_function.__call__._get_tracing_count() <= 4 # type: ignore + + +def categorical_scaled_branin( + categories_to_points: TensorType, +) -> SingleObjectiveTestProblem[TaggedProductSearchSpace]: + """ + Generate a Scaled Branin test problem defined on the product of a categorical space and a + continuous space, with categories mapped to points using the given 1D tensor. + """ + categorical_space = CategoricalSearchSpace([str(float(v)) for v in categories_to_points]) + continuous_space = Box([0], [1]) + search_space = TaggedProductSearchSpace( + spaces=[categorical_space, continuous_space], + tags=["discrete", "continuous"], + ) + + def objective(x: TensorType) -> TensorType: + points = tf.gather(categories_to_points, tf.cast(x[..., 0], tf.int32)) + x_mapped = tf.concat([tf.expand_dims(points, -1), x[..., 1:]], axis=-1) + return scaled_branin(x_mapped) + + minimizer_indices = [] + for minimizer0 in ScaledBranin.minimizers[..., 0]: + indices = tf.where(tf.equal(categories_to_points, minimizer0)) + minimizer_indices.append(indices[0][0]) + category_indices = tf.expand_dims(tf.convert_to_tensor(minimizer_indices, dtype=tf.float64), -1) + minimizers = tf.concat([category_indices, ScaledBranin.minimizers[..., 1:]], axis=-1) + + return SingleObjectiveTestProblem( + name="Categorical scaled Branin", + objective=objective, + search_space=search_space, + minimizers=minimizers, + minimum=ScaledBranin.minimum, + ) + + +@random_seed +@pytest.mark.parametrize( + "num_steps, acquisition_rule", + [ + pytest.param(25, EfficientGlobalOptimization(), id="EfficientGlobalOptimization"), + ], +) +def test_optimizer_finds_minima_of_the_categorical_scaled_branin_function( + num_steps: int, + acquisition_rule: AcquisitionRule[ + TensorType, TaggedProductSearchSpace, TrainableProbabilisticModel + ], +) -> None: + # 6 categories mapping to 3 random points plus the 3 minimizer points + points = tf.concat( + [tf.random.uniform([3], dtype=tf.float64), ScaledBranin.minimizers[..., 0]], 0 + ) + problem = categorical_scaled_branin(tf.random.shuffle(points)) + initial_query_points = problem.search_space.sample(5) + observer = mk_observer(problem.objective) + initial_data = observer(initial_query_points) + + # model uses one-hot encoding for the categorical inputs + encoder = one_hot_encoder(problem.search_space) + model = GaussianProcessRegression( + build_gpr(initial_data, problem.search_space, likelihood_variance=1e-8), + encoder=encoder, + ) + + dataset = ( + BayesianOptimizer(observer, problem.search_space) + .optimize(num_steps, initial_data, model, acquisition_rule) + .try_get_final_dataset() + ) + + arg_min_idx = tf.squeeze(tf.argmin(dataset.observations, axis=0)) + + best_y = dataset.observations[arg_min_idx] + best_x = dataset.query_points[arg_min_idx] + + relative_minimizer_err = tf.abs((best_x - problem.minimizers) / problem.minimizers) + assert tf.reduce_any( + tf.reduce_all(relative_minimizer_err < 0.1, axis=-1), axis=0 + ), relative_minimizer_err + npt.assert_allclose(best_y, problem.minimum, rtol=0.005) diff --git a/tests/integration/test_multifidelity_bayesian_optimization.py b/tests/integration/test_multifidelity_bayesian_optimization.py index 39202a6b4c..ac25dca63a 100644 --- a/tests/integration/test_multifidelity_bayesian_optimization.py +++ b/tests/integration/test_multifidelity_bayesian_optimization.py @@ -38,11 +38,13 @@ ) from trieste.objectives.utils import mk_observer from trieste.observer import SingleObserver -from trieste.space import TaggedProductSearchSpace +from trieste.space import SearchSpaceType, TaggedProductSearchSpace from trieste.types import TensorType -def _build_observer(problem: SingleObjectiveMultifidelityTestProblem) -> SingleObserver: +def _build_observer( + problem: SingleObjectiveMultifidelityTestProblem[SearchSpaceType], +) -> SingleObserver: objective_function = problem.objective def noisy_objective(x: TensorType) -> TensorType: @@ -57,7 +59,7 @@ def noisy_objective(x: TensorType) -> TensorType: def _build_nested_multifidelity_dataset( - problem: SingleObjectiveMultifidelityTestProblem, observer: SingleObserver + problem: SingleObjectiveMultifidelityTestProblem[SearchSpaceType], observer: SingleObserver ) -> Dataset: num_fidelities = problem.num_fidelities initial_sample_sizes = [10 + 2 * (num_fidelities - i) for i in range(num_fidelities)] @@ -83,7 +85,7 @@ def _build_nested_multifidelity_dataset( @random_seed @pytest.mark.parametrize("problem", ((Linear2Fidelity), (Linear3Fidelity), (Linear5Fidelity))) def test_multifidelity_bo_finds_minima_of_linear_problem( - problem: SingleObjectiveMultifidelityTestProblem, + problem: SingleObjectiveMultifidelityTestProblem[SearchSpaceType], ) -> None: observer = _build_observer(problem) initial_data = _build_nested_multifidelity_dataset(problem, observer) diff --git a/tests/unit/models/conftest.py b/tests/unit/models/conftest.py index 9a41a18316..107a5e7193 100644 --- a/tests/unit/models/conftest.py +++ b/tests/unit/models/conftest.py @@ -43,6 +43,7 @@ VariationalGaussianProcess, ) from trieste.models.optimizer import DatasetTransformer, Optimizer +from trieste.space import EncoderFunction from trieste.types import TensorType @@ -58,12 +59,15 @@ ) def _gpflow_interface_factory(request: Any) -> ModelFactoryType: def model_interface_factory( - x: TensorType, y: TensorType, optimizer: Optimizer | None = None + x: TensorType, + y: TensorType, + optimizer: Optimizer | None = None, + encoder: EncoderFunction | None = None, ) -> tuple[GPflowPredictor, Callable[[TensorType, TensorType], GPModel]]: model_interface: Callable[..., GPflowPredictor] = request.param[0] base_model: GaussianProcessRegression = request.param[1](x, y) reference_model: Callable[[TensorType, TensorType], GPModel] = request.param[1] - return model_interface(base_model, optimizer=optimizer), reference_model + return model_interface(base_model, optimizer=optimizer, encoder=encoder), reference_model return model_interface_factory diff --git a/tests/unit/models/gpflow/test_interface.py b/tests/unit/models/gpflow/test_interface.py index d1f117d790..e55929bb17 100644 --- a/tests/unit/models/gpflow/test_interface.py +++ b/tests/unit/models/gpflow/test_interface.py @@ -24,6 +24,7 @@ from tests.util.misc import random_seed from trieste.data import Dataset from trieste.models.gpflow import BatchReparametrizationSampler, GPflowPredictor +from trieste.space import CategoricalSearchSpace, one_hot_encoder class _QuadraticPredictor(GPflowPredictor): @@ -31,10 +32,10 @@ class _QuadraticPredictor(GPflowPredictor): def model(self) -> GPModel: return _QuadraticGPModel() - def optimize(self, dataset: Dataset) -> None: + def optimize_encoded(self, dataset: Dataset) -> None: self.optimizer.optimize(self.model, dataset) - def update(self, dataset: Dataset) -> None: + def update_encoded(self, dataset: Dataset) -> None: return def log(self, dataset: Optional[Dataset] = None) -> None: @@ -112,3 +113,14 @@ def test_gpflow_reparam_sampler_returns_reparam_sampler_with_correct_samples() - linear_error = 1 / tf.sqrt(tf.cast(num_samples, tf.float32)) npt.assert_allclose(sample_mean, [[6.25]], rtol=linear_error) npt.assert_allclose(sample_variance, 1.0, rtol=2 * linear_error) + + +def test_gpflow_categorical_predict() -> None: + search_space = CategoricalSearchSpace(["Red", "Green", "Blue"]) + query_points = search_space.sample(10) + model = _QuadraticPredictor(encoder=one_hot_encoder(search_space)) + mean, variance = model.predict(query_points) + assert mean.shape == [10, 1] + assert variance.shape == [10, 1] + npt.assert_allclose(mean, [[1.0]] * 10, rtol=0.01) + npt.assert_allclose(variance, [[1.0]] * 10, rtol=0.01) diff --git a/tests/unit/models/test_interfaces.py b/tests/unit/models/test_interfaces.py index 6213f7e723..b650f62f16 100644 --- a/tests/unit/models/test_interfaces.py +++ b/tests/unit/models/test_interfaces.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence +from typing import Optional import gpflow import numpy as np @@ -35,12 +36,17 @@ from trieste.data import Dataset from trieste.models import TrainableModelStack, TrainableProbabilisticModel from trieste.models.interfaces import ( + EncodedProbabilisticModel, + EncodedSupportsPredictJoint, + EncodedSupportsPredictY, + EncodedTrainableProbabilisticModel, TrainablePredictJointReparamModelStack, TrainablePredictYModelStack, TrainableSupportsPredictJoint, TrainableSupportsPredictJointHasReparamSampler, ) from trieste.models.utils import get_last_optimization_result, optimize_model_and_save_result +from trieste.space import EncoderFunction from trieste.types import TensorType @@ -216,3 +222,93 @@ def test_model_stack_reparam_sampler() -> None: npt.assert_allclose(var[..., :2], var01, rtol=0.04) npt.assert_allclose(var[..., 2:3], var2, rtol=0.04) npt.assert_allclose(var[..., 3:], var3, rtol=0.04) + + +class _EncodedModel( + EncodedTrainableProbabilisticModel, + EncodedSupportsPredictJoint, + EncodedSupportsPredictY, + EncodedProbabilisticModel, +): + def __init__(self, encoder: EncoderFunction | None = None) -> None: + self.dataset: Dataset | None = None + self._encoder = (lambda x: x + 1) if encoder is None else encoder + + @property + def encoder(self) -> EncoderFunction | None: + return self._encoder + + def predict_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + return query_points, query_points + + def sample_encoded(self, query_points: TensorType, num_samples: int) -> TensorType: + return tf.tile(tf.expand_dims(query_points, 0), [num_samples, 1, 1]) + + def log(self, dataset: Optional[Dataset] = None) -> None: + pass + + def update_encoded(self, dataset: Dataset) -> None: + self.dataset = dataset + + def optimize_encoded(self, dataset: Dataset) -> None: + self.dataset = dataset + + def predict_joint_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + b, d = query_points.shape + return query_points, tf.zeros([d, b, b]) + + def predict_y_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + return self.predict_encoded(query_points) + + +def test_encoded_probabilistic_model() -> None: + model = _EncodedModel() + query_points = tf.random.uniform([3, 5]) + mean, var = model.predict(query_points) + npt.assert_allclose(mean, query_points + 1) + npt.assert_allclose(var, query_points + 1) + samples = model.sample(query_points, 10) + assert len(samples) == 10 + for i in range(10): + npt.assert_allclose(samples[i], query_points + 1) + + +def test_encoded_trainable_probabilistic_model() -> None: + model = _EncodedModel() + assert model.dataset is None + for method in model.update, model.optimize: + query_points = tf.random.uniform([3, 5]) + observations = tf.random.uniform([3, 1]) + dataset = Dataset(query_points, observations) + method(dataset) + assert model.dataset is not None + # no idea why mypy thinks model.dataset couldn't have changed here + npt.assert_allclose( # type: ignore[unreachable] + model.dataset.query_points, query_points + 1 + ) + npt.assert_allclose(model.dataset.observations, observations) + + +def test_encoded_supports_predict_joint() -> None: + model = _EncodedModel() + query_points = tf.random.uniform([3, 5]) + mean, var = model.predict_joint(query_points) + npt.assert_allclose(mean, query_points + 1) + npt.assert_allclose(var, tf.zeros([5, 3, 3])) + + +def test_encoded_supports_predict_y() -> None: + model = _EncodedModel() + query_points = tf.random.uniform([3, 5]) + mean, var = model.predict_y(query_points) + npt.assert_allclose(mean, query_points + 1) + npt.assert_allclose(var, query_points + 1) + + +def test_encoded_probabilistic_model_keras_embedding() -> None: + encoder = tf.keras.layers.Embedding(3, 2) + model = _EncodedModel(encoder=encoder) + query_points = tf.random.uniform([3, 5], minval=0, maxval=3, dtype=tf.int32) + mean, var = model.predict(query_points) + assert mean.shape == (3, 5, 2) + npt.assert_allclose(mean, encoder(query_points)) diff --git a/tests/unit/objectives/test_multi_objectives.py b/tests/unit/objectives/test_multi_objectives.py index c3063be5bb..67d1fe98df 100644 --- a/tests/unit/objectives/test_multi_objectives.py +++ b/tests/unit/objectives/test_multi_objectives.py @@ -19,6 +19,7 @@ from check_shapes.exceptions import ShapeMismatchError from trieste.objectives.multi_objectives import DTLZ1, DTLZ2, VLMOP2, MultiObjectiveTestProblem +from trieste.space import SearchSpaceType from trieste.types import TensorType @@ -117,7 +118,7 @@ def test_dtlz2_has_expected_output( ], ) def test_gen_pareto_front_is_equal_to_math_defined( - obj_type: Callable[[int, int], MultiObjectiveTestProblem], + obj_type: Callable[[int, int], MultiObjectiveTestProblem[SearchSpaceType]], input_dim: int, num_obj: int, gen_pf_num: int, @@ -140,7 +141,7 @@ def test_gen_pareto_front_is_equal_to_math_defined( ], ) def test_func_raises_specified_input_dim_not_align_with_actual_input_dim( - obj_inst: MultiObjectiveTestProblem, actual_x: TensorType + obj_inst: MultiObjectiveTestProblem[SearchSpaceType], actual_x: TensorType ) -> None: with pytest.raises(ShapeMismatchError): obj_inst.objective(actual_x) @@ -160,7 +161,7 @@ def test_func_raises_specified_input_dim_not_align_with_actual_input_dim( @pytest.mark.parametrize("num_obs", [1, 5, 10]) @pytest.mark.parametrize("dtype", [tf.float32, tf.float64]) def test_objective_has_correct_shape_and_dtype( - problem: MultiObjectiveTestProblem, + problem: MultiObjectiveTestProblem[SearchSpaceType], input_dim: int, num_obj: int, num_obs: int, diff --git a/tests/unit/objectives/test_single_objectives.py b/tests/unit/objectives/test_single_objectives.py index fbe0c1e9b1..882f56336d 100644 --- a/tests/unit/objectives/test_single_objectives.py +++ b/tests/unit/objectives/test_single_objectives.py @@ -36,6 +36,7 @@ SingleObjectiveTestProblem, Trid10, ) +from trieste.space import Box, SearchSpaceType @pytest.fixture( @@ -58,12 +59,12 @@ Levy8, ], ) -def _problem_fixture(request: Any) -> Tuple[SingleObjectiveTestProblem, int]: +def _problem_fixture(request: Any) -> Tuple[SingleObjectiveTestProblem[SearchSpaceType], int]: return request.param def test_objective_maps_minimizers_to_minimum( - problem: SingleObjectiveTestProblem, + problem: SingleObjectiveTestProblem[SearchSpaceType], ) -> None: objective = problem.objective minimizers = problem.minimizers @@ -74,7 +75,7 @@ def test_objective_maps_minimizers_to_minimum( def test_no_function_values_are_less_than_global_minimum( - problem: SingleObjectiveTestProblem, + problem: SingleObjectiveTestProblem[Box], ) -> None: objective = problem.objective space = problem.search_space @@ -86,7 +87,7 @@ def test_no_function_values_are_less_than_global_minimum( @pytest.mark.parametrize("num_obs", [5, 1]) @pytest.mark.parametrize("dtype", [tf.float32, tf.float64]) def test_objective_has_correct_shape_and_dtype( - problem: SingleObjectiveTestProblem, + problem: SingleObjectiveTestProblem[SearchSpaceType], num_obs: int, dtype: tf.DType, ) -> None: @@ -120,7 +121,7 @@ def test_objective_has_correct_shape_and_dtype( ) @pytest.mark.parametrize("num_obs", [5, 1]) def test_search_space_has_correct_shape_and_default_dtype( - problem: SingleObjectiveTestProblem, + problem: SingleObjectiveTestProblem[SearchSpaceType], input_dim: int, num_obs: int, ) -> None: diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index b2caae1612..6fb473d130 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -1756,51 +1756,51 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None: [ ( CategoricalSearchSpace(["V"]), - tf.constant([[0], [0]]), - tf.constant([[1], [1]], dtype=tf.float32), + tf.constant([[0], [0]], dtype=tf.float64), + tf.constant([[1], [1]], dtype=tf.float64), ), ( - CategoricalSearchSpace(["R", "G", "B"]), - tf.constant([[0], [2], [1]]), + CategoricalSearchSpace(["R", "G", "B"], dtype=tf.float32), + tf.constant([[0], [2], [1]], dtype=tf.float32), tf.constant([[1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=tf.float32), ), ( CategoricalSearchSpace(["R", "G", "B"]), - tf.constant([[[[[0]]]]]), - tf.constant([[[[[1, 0, 0]]]]], dtype=tf.float32), + tf.constant([[[[[0]]]]], dtype=tf.float64), + tf.constant([[[[[1, 0, 0]]]]], dtype=tf.float64), ), ( - CategoricalSearchSpace(["R", "G", "B", "A"]), - tf.constant([[0], [2], [2]]), + CategoricalSearchSpace(["R", "G", "B", "A"], dtype=tf.float32), + tf.constant([[0], [2], [2]], dtype=tf.float32), tf.constant([[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]], dtype=tf.float32), ), ( CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), - tf.constant([[0, 0], [2, 0], [1, 1]]), - tf.constant([[1, 0, 0, 1, 0], [0, 0, 1, 1, 0], [0, 1, 0, 0, 1]], dtype=tf.float32), + tf.constant([[0, 0], [2, 0], [1, 1]], dtype=tf.float64), + tf.constant([[1, 0, 0, 1, 0], [0, 0, 1, 1, 0], [0, 1, 0, 0, 1]], dtype=tf.float64), ), ( CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), - tf.constant([[[0, 0], [0, 0]], [[2, 0], [1, 1]]]), + tf.constant([[[0, 0], [0, 0]], [[2, 0], [1, 1]]], dtype=tf.float64), tf.constant( [[[1, 0, 0, 1, 0], [1, 0, 0, 1, 0]], [[0, 0, 1, 1, 0], [0, 1, 0, 0, 1]]], - dtype=tf.float32, + dtype=tf.float64, ), ), ( TaggedProductSearchSpace([Box([0.0], [1.0]), CategoricalSearchSpace(["R", "G", "B"])]), - tf.constant([[0.5, 0], [0.3, 2]]), - tf.constant([[0.5, 1, 0, 0], [0.3, 0, 0, 1]], dtype=tf.float32), + tf.constant([[0.5, 0], [0.3, 2]], dtype=tf.float64), + tf.constant([[0.5, 1, 0, 0], [0.3, 0, 0, 1]], dtype=tf.float64), ), ( TaggedProductSearchSpace([Box([0.0], [1.0]), CategoricalSearchSpace(["R", "G", "B"])]), - tf.constant([[[0.5, 0]], [[0.3, 2]]]), - tf.constant([[[0.5, 1, 0, 0]], [[0.3, 0, 0, 1]]], dtype=tf.float32), + tf.constant([[[0.5, 0]], [[0.3, 2]]], dtype=tf.float64), + tf.constant([[[0.5, 1, 0, 0]], [[0.3, 0, 0, 1]]], dtype=tf.float64), ), ( Box([0.0], [1.0]), - tf.constant([[0.5], [0.3]]), - tf.constant([[0.5], [0.3]], dtype=tf.float32), + tf.constant([[0.5], [0.3]], dtype=tf.float64), + tf.constant([[0.5], [0.3]], dtype=tf.float64), ), ], ) diff --git a/trieste/acquisition/optimizer.py b/trieste/acquisition/optimizer.py index 01c33ece7e..d4ab1600d5 100644 --- a/trieste/acquisition/optimizer.py +++ b/trieste/acquisition/optimizer.py @@ -34,7 +34,7 @@ Box, CollectionSearchSpace, Constraint, - DiscreteSearchSpace, + GeneralDiscreteSearchSpace, SearchSpace, SearchSpaceType, TaggedMultiSearchSpace, @@ -101,7 +101,7 @@ def automatic_optimizer_selector( :return: The batch of points in ``space`` that maximises ``target_func``, with shape [1, D]. """ - if isinstance(space, DiscreteSearchSpace): + if isinstance(space, GeneralDiscreteSearchSpace): return optimize_discrete(space, target_func) elif isinstance(space, (Box, CollectionSearchSpace)): @@ -151,11 +151,11 @@ def _get_max_discrete_points( def optimize_discrete( - space: DiscreteSearchSpace, + space: GeneralDiscreteSearchSpace, target_func: Union[AcquisitionFunction, Tuple[AcquisitionFunction, int]], ) -> TensorType: """ - An :const:`AcquisitionOptimizer` for :class:'DiscreteSearchSpace' spaces. + An :const:`AcquisitionOptimizer` for :class:'GeneralDiscreteSearchSpace' spaces. When this functions receives an acquisition-integer tuple as its `target_func`,it evaluates all the points in the search space for each of the individual V functions making @@ -734,7 +734,7 @@ def get_bounds_of_box_relaxation_around_point( space_with_fixed_discrete = space for tag in space.subspace_tags: if isinstance( - space.get_subspace(tag), DiscreteSearchSpace + space.get_subspace(tag), GeneralDiscreteSearchSpace ): # convert discrete subspaces to box spaces. subspace_value = space.get_subspace_component(tag, current_point) space_with_fixed_discrete = space_with_fixed_discrete.fix_subspace(tag, subspace_value) diff --git a/trieste/models/gpflow/builders.py b/trieste/models/gpflow/builders.py index 638ae55b9d..813cef6bdf 100644 --- a/trieste/models/gpflow/builders.py +++ b/trieste/models/gpflow/builders.py @@ -21,7 +21,7 @@ from __future__ import annotations import math -from typing import Optional, Sequence, Type +from typing import Callable, Optional, Sequence, Type import gpflow import tensorflow as tf @@ -30,9 +30,10 @@ from gpflow.models import GPR, SGPR, SVGP, VGP, GPModel from ...data import Dataset, split_dataset_by_fidelity -from ...space import Box, SearchSpace +from ...space import Box, EncoderFunction, SearchSpace, one_hot_encoded_space, one_hot_encoder from ...types import TensorType from ..gpflow.models import GaussianProcessRegression +from ..interfaces import encode_dataset # NOTE: As a static non-Tensor, this should really be a tf.constant (like the other constants). # However, changing it breaks serialisation during the expected_improvement.pct.py notebook. @@ -88,6 +89,8 @@ def build_gpr( likelihood_variance: Optional[float] = None, trainable_likelihood: bool = False, kernel: Optional[gpflow.kernels.Kernel] = None, + encoder: EncoderFunction | None = None, + space_encoder: Callable[[SearchSpace], SearchSpace] | None = None, ) -> GPR: """ Build a :class:`~gpflow.models.GPR` model with sensible initial parameters and @@ -118,8 +121,20 @@ def build_gpr( non-trainable. By default set to `False`. :param kernel: The kernel to use in the model, defaults to letting the function set up a :class:`~gpflow.kernels.Matern52` kernel. + :param encoder: Encoder with which to transform the dataset before training. Defaults to + one_hot_encoder if the search_space is specified. + :param space_encoder: Encoder with which to transform search_space before generating a kernel. + Defaults to one_hot_encoded_space. :return: A :class:`~gpflow.models.GPR` model. """ + if search_space is not None: + encoder = one_hot_encoder(search_space) if encoder is None else encoder + space_encoder = one_hot_encoded_space if space_encoder is None else space_encoder + search_space = space_encoder(search_space) + + if encoder is not None: + data = encode_dataset(data, encoder) + empirical_mean, empirical_variance, _ = _get_data_stats(data) if kernel is None: diff --git a/trieste/models/gpflow/interface.py b/trieste/models/gpflow/interface.py index 5fad89f679..9448333365 100644 --- a/trieste/models/gpflow/interface.py +++ b/trieste/models/gpflow/interface.py @@ -19,22 +19,24 @@ import gpflow import tensorflow as tf -from check_shapes import inherit_check_shapes from gpflow.models import GPModel from gpflow.posteriors import BasePosterior, PrecomputeCacheType -from typing_extensions import Protocol +from typing_extensions import Protocol, final from ... import logging from ...data import Dataset +from ...space import EncoderFunction from ...types import TensorType from ..interfaces import ( + EncodedProbabilisticModel, + EncodedSupportsPredictJoint, + EncodedSupportsPredictY, + EncodedTrainableProbabilisticModel, HasReparamSampler, ReparametrizationSampler, SupportsGetKernel, SupportsGetObservationNoise, SupportsPredictJoint, - SupportsPredictY, - TrainableProbabilisticModel, ) from ..optimizer import Optimizer from ..utils import ( @@ -46,27 +48,39 @@ class GPflowPredictor( - SupportsPredictJoint, + EncodedSupportsPredictJoint, SupportsGetKernel, SupportsGetObservationNoise, - SupportsPredictY, + EncodedSupportsPredictY, HasReparamSampler, - TrainableProbabilisticModel, + EncodedTrainableProbabilisticModel, + EncodedProbabilisticModel, ABC, ): """A trainable wrapper for a GPflow Gaussian process model.""" - def __init__(self, optimizer: Optimizer | None = None): + def __init__(self, optimizer: Optimizer | None = None, encoder: EncoderFunction | None = None): """ :param optimizer: The optimizer with which to train the model. Defaults to :class:`~trieste.models.optimizer.Optimizer` with :class:`~gpflow.optimizers.Scipy`. + :param encoder: Optional encoder with which to transform query points before + generating predictions. """ if optimizer is None: optimizer = Optimizer(gpflow.optimizers.Scipy(), compile=True) self._optimizer = optimizer + self._encoder = encoder self._posterior: Optional[BasePosterior] = None + @property + def encoder(self) -> EncoderFunction | None: + return self._encoder + + @encoder.setter + def encoder(self, encoder: EncoderFunction | None) -> None: + self._encoder = encoder + @property def optimizer(self) -> Optimizer: """The optimizer with which to train the model.""" @@ -102,16 +116,14 @@ def update_posterior_cache(self) -> None: def model(self) -> GPModel: """The underlying GPflow model.""" - @inherit_check_shapes - def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + def predict_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: mean, cov = (self._posterior or self.model).predict_f(query_points) # posterior predict can return negative variance values [cf GPFlow issue #1813] if self._posterior is not None: cov = tf.clip_by_value(cov, 1e-12, cov.dtype.max) return mean, cov - @inherit_check_shapes - def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + def predict_joint_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: mean, cov = (self._posterior or self.model).predict_f(query_points, full_cov=True) # posterior predict can return negative variance values [cf GPFlow issue #1813] if self._posterior is not None: @@ -120,12 +132,10 @@ def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorTyp ) return mean, cov - @inherit_check_shapes - def sample(self, query_points: TensorType, num_samples: int) -> TensorType: + def sample_encoded(self, query_points: TensorType, num_samples: int) -> TensorType: return self.model.predict_f_samples(query_points, num_samples) - @inherit_check_shapes - def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + def predict_y_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: return self.model.predict_y(query_points) def get_kernel(self) -> gpflow.kernels.Kernel: @@ -206,3 +216,21 @@ def covariance_between_points( (L being the number of latent GPs = number of output dimensions) """ raise NotImplementedError + + +class EncodedSupportsCovarianceBetweenPoints( + EncodedProbabilisticModel, SupportsCovarianceBetweenPoints +): + @abstractmethod + def covariance_between_points_encoded( + self, query_points_1: TensorType, query_points_2: TensorType + ) -> TensorType: + """Implementation of covariance_between_points on encoded query points.""" + + @final + def covariance_between_points( + self, query_points_1: TensorType, query_points_2: TensorType + ) -> TensorType: + return self.covariance_between_points_encoded( + self.encode(query_points_1), self.encode(query_points_2) + ) diff --git a/trieste/models/gpflow/models.py b/trieste/models/gpflow/models.py index 235325e399..51a29e71ba 100644 --- a/trieste/models/gpflow/models.py +++ b/trieste/models/gpflow/models.py @@ -37,11 +37,12 @@ check_and_extract_fidelity_query_points, split_dataset_by_fidelity, ) +from ...space import EncoderFunction from ...types import TensorType from ...utils import DEFAULTS, jit from ...utils.misc import flatten_leading_dims from ..interfaces import ( - FastUpdateModel, + EncodedFastUpdateModel, HasTrajectorySampler, SupportsCovarianceWithTopFidelity, SupportsGetInducingVariables, @@ -52,7 +53,7 @@ ) from ..optimizer import BatchOptimizer, Optimizer, OptimizeResult from .inducing_point_selectors import InducingPointSelector -from .interface import GPflowPredictor, SupportsCovarianceBetweenPoints +from .interface import EncodedSupportsCovarianceBetweenPoints, GPflowPredictor from .sampler import DecoupledTrajectorySampler, RandomFourierFeatureTrajectorySampler from .utils import ( _covariance_between_points_for_variational_models, @@ -66,8 +67,8 @@ class GaussianProcessRegression( GPflowPredictor, - FastUpdateModel, - SupportsCovarianceBetweenPoints, + EncodedFastUpdateModel, + EncodedSupportsCovarianceBetweenPoints, SupportsGetInternalData, HasTrajectorySampler, ): @@ -90,6 +91,7 @@ def __init__( num_kernel_samples: int = 10, num_rff_features: int = 1000, use_decoupled_sampler: bool = True, + encoder: EncoderFunction | None = None, ): """ :param model: The GPflow model to wrap. @@ -105,8 +107,10 @@ def __init__( :param use_decoupled_sampler: If True use a decoupled random Fourier feature sampler, else just use a random Fourier feature sampler. The decoupled sampler suffers less from overestimating variance and can typically get away with a lower num_rff_features. + :param encoder: Optional encoder with which to transform query points before + generating predictions. """ - super().__init__(optimizer) + super().__init__(optimizer, encoder) self._model = model check_optimizer(self.optimizer) @@ -159,12 +163,11 @@ def _ensure_variable_model_data(self) -> None: ), ) - @inherit_check_shapes - def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: - f_mean, f_var = self.predict(query_points) + def predict_y_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + f_mean, f_var = self.predict_encoded(query_points) return self.model.likelihood.predict_mean_and_var(query_points, f_mean, f_var) - def update(self, dataset: Dataset) -> None: + def update_encoded(self, dataset: Dataset) -> None: self._ensure_variable_model_data() x, y = self.model.data[0].value(), self.model.data[1].value() @@ -181,7 +184,7 @@ def update(self, dataset: Dataset) -> None: self.model.data[1].assign(dataset.observations) self.update_posterior_cache() - def covariance_between_points( + def covariance_between_points_encoded( self, query_points_1: TensorType, query_points_2: TensorType ) -> TensorType: r""" @@ -249,7 +252,7 @@ def covariance_between_points( return cov - def optimize(self, dataset: Dataset) -> OptimizeResult: + def optimize_encoded(self, dataset: Dataset) -> OptimizeResult: """ Optimize the model with the specified `dataset`. @@ -269,7 +272,6 @@ def optimize(self, dataset: Dataset) -> OptimizeResult: :param dataset: The data with which to optimize the `model`. """ - num_trainable_params_with_priors_or_constraints = tf.reduce_sum( [ tf.size(param) @@ -349,7 +351,7 @@ def get_internal_data(self) -> Dataset: """ return Dataset(self.model.data[0], self.model.data[1]) - def conditional_predict_f( + def conditional_predict_f_encoded( self, query_points: TensorType, additional_data: Dataset ) -> tuple[TensorType, TensorType]: """ @@ -374,10 +376,10 @@ def conditional_predict_f( "should have shape [M, D]", ) - mean_add, cov_add = self.predict_joint( + mean_add, cov_add = self.predict_joint_encoded( additional_data.query_points ) # [..., N, L], [..., L, N, N] - mean_qp, var_qp = self.predict(query_points) # [M, L], [M, L] + mean_qp, var_qp = self.predict_encoded(query_points) # [M, L], [M, L] cov_cross = self.covariance_between_points( additional_data.query_points, query_points @@ -414,7 +416,7 @@ def conditional_predict_f( return mean_qp_new, var_qp_new - def conditional_predict_joint( + def conditional_predict_joint_encoded( self, query_points: TensorType, additional_data: Dataset ) -> tuple[TensorType, TensorType]: """ @@ -445,7 +447,7 @@ def conditional_predict_joint( query_points_r = tf.broadcast_to(query_points, new_shape) # [..., M, D] points = tf.concat([additional_data.query_points, query_points_r], axis=-2) # [..., N+M, D] - mean, cov = self.predict_joint(points) # [..., N+M, L], [..., L, N+M, N+M] + mean, cov = self.predict_joint_encoded(points) # [..., N+M, L], [..., L, N+M, N+M] N = tf.shape(additional_data.query_points)[-2] @@ -484,7 +486,7 @@ def conditional_predict_joint( return mean_qp_new, cov_qp_new - def conditional_predict_f_sample( + def conditional_predict_f_sample_encoded( self, query_points: TensorType, additional_data: Dataset, num_samples: int ) -> TensorType: """ @@ -505,7 +507,7 @@ def conditional_predict_f_sample( ) # [..., (S), P, N] return tf.linalg.adjoint(samples) # [..., (S), N, L] - def conditional_predict_y( + def conditional_predict_y_encoded( self, query_points: TensorType, additional_data: Dataset ) -> tuple[TensorType, TensorType]: """ @@ -524,7 +526,7 @@ def conditional_predict_y( class SparseGaussianProcessRegression( GPflowPredictor, - SupportsCovarianceBetweenPoints, + EncodedSupportsCovarianceBetweenPoints, SupportsGetInducingVariables, SupportsGetInternalData, HasTrajectorySampler, @@ -551,6 +553,7 @@ def __init__( inducing_point_selector: Optional[ InducingPointSelector[SparseGaussianProcessRegression] ] = None, + encoder: EncoderFunction | None = None, ): """ :param model: The GPflow model to wrap. @@ -566,8 +569,10 @@ def __init__( :raise NotImplementedError (or ValueError): If we try to use a model with invalid ``num_rff_features``, or an ``inducing_point_selector`` with a model that has more than one set of inducing points. + :param encoder: Optional encoder with which to transform query points before + generating predictions. """ - super().__init__(optimizer) + super().__init__(optimizer, encoder) self._model = model check_optimizer(self.optimizer) @@ -609,9 +614,8 @@ def inducing_point_selector( ) -> Optional[InducingPointSelector[SparseGaussianProcessRegression]]: return self._inducing_point_selector - @inherit_check_shapes - def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: - f_mean, f_var = self.predict(query_points) + def predict_y_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + f_mean, f_var = self.predict_encoded(query_points) return self.model.likelihood.predict_mean_and_var(query_points, f_mean, f_var) def _ensure_variable_model_data(self) -> None: @@ -637,7 +641,7 @@ def _ensure_variable_model_data(self) -> None: if not is_variable(self._model.num_data): self._model.num_data = tf.Variable(self._model.num_data, trainable=False) - def optimize(self, dataset: Dataset) -> OptimizeResult: + def optimize_encoded(self, dataset: Dataset) -> OptimizeResult: """ Optimize the model with the specified `dataset`. @@ -647,7 +651,7 @@ def optimize(self, dataset: Dataset) -> OptimizeResult: self.update_posterior_cache() return result - def update(self, dataset: Dataset) -> None: + def update_encoded(self, dataset: Dataset) -> None: self._ensure_variable_model_data() x, y = self.model.data[0].value(), self.model.data[1].value() @@ -779,7 +783,7 @@ def get_inducing_variables( return inducing_points, q_mu, q_sqrt, whiten - def covariance_between_points( + def covariance_between_points_encoded( self, query_points_1: TensorType, query_points_2: TensorType ) -> TensorType: r""" @@ -837,7 +841,7 @@ def get_internal_data(self) -> Dataset: class SparseVariational( GPflowPredictor, - SupportsCovarianceBetweenPoints, + EncodedSupportsCovarianceBetweenPoints, SupportsGetInducingVariables, HasTrajectorySampler, ): @@ -858,6 +862,7 @@ def __init__( optimizer: Optimizer | None = None, num_rff_features: int = 1000, inducing_point_selector: Optional[InducingPointSelector[SparseVariational]] = None, + encoder: EncoderFunction | None = None, ): """ :param model: The underlying GPflow sparse variational model. @@ -874,6 +879,8 @@ def __init__( the optimization progresses. :raise NotImplementedError: If we try to use an inducing_point_selector with a model that has more than one set of inducing points. + :param encoder: Optional encoder with which to transform query points before + generating predictions. """ tf.debugging.assert_rank( @@ -883,7 +890,7 @@ def __init__( if optimizer is None: optimizer = BatchOptimizer(tf.optimizers.Adam(), batch_size=100, compile=True) - super().__init__(optimizer) + super().__init__(optimizer, encoder) self._model = model if num_rff_features <= 0: @@ -932,12 +939,11 @@ def model(self) -> SVGP: def inducing_point_selector(self) -> Optional[InducingPointSelector[SparseVariational]]: return self._inducing_point_selector - @inherit_check_shapes - def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: - f_mean, f_var = self.predict(query_points) + def predict_y_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + f_mean, f_var = self.predict_encoded(query_points) return self.model.likelihood.predict_mean_and_var(query_points, f_mean, f_var) - def update(self, dataset: Dataset) -> None: + def update_encoded(self, dataset: Dataset) -> None: self._ensure_variable_model_data() # Hard-code asserts from _assert_data_is_compatible because model doesn't store dataset @@ -979,7 +985,7 @@ def update(self, dataset: Dataset) -> None: self._update_inducing_variables(new_inducing_points) self.update_posterior_cache() - def optimize(self, dataset: Dataset) -> OptimizeResult: + def optimize_encoded(self, dataset: Dataset) -> OptimizeResult: """ Optimize the model with the specified `dataset`. @@ -1017,7 +1023,9 @@ def _update_inducing_variables(self, new_inducing_points: TensorType) -> None: if whiten: new_q_mu, new_q_sqrt = _whiten_points(self, new_inducing_points) else: - new_q_mu, new_f_cov = self.predict_joint(new_inducing_points) # [N, L], [L, N, N] + new_q_mu, new_f_cov = self.predict_joint_encoded( + new_inducing_points + ) # [N, L], [L, N, N] new_q_mu -= self.model.mean_function(new_inducing_points) jitter_mat = DEFAULTS.JITTER * tf.eye( tf.shape(new_inducing_points)[0], dtype=new_f_cov.dtype @@ -1062,7 +1070,7 @@ def get_inducing_variables( return inducing_points, self.model.q_mu, self.model.q_sqrt, self.model.whiten - def covariance_between_points( + def covariance_between_points_encoded( self, query_points_1: TensorType, query_points_2: TensorType ) -> TensorType: r""" @@ -1099,7 +1107,7 @@ def trajectory_sampler(self) -> TrajectorySampler[SparseVariational]: class VariationalGaussianProcess( GPflowPredictor, - SupportsCovarianceBetweenPoints, + EncodedSupportsCovarianceBetweenPoints, SupportsGetInducingVariables, HasTrajectorySampler, ): @@ -1132,6 +1140,7 @@ def __init__( use_natgrads: bool = False, natgrad_gamma: Optional[float] = None, num_rff_features: int = 1000, + encoder: EncoderFunction | None = None, ): """ :param model: The GPflow :class:`~gpflow.models.VGP`. @@ -1150,6 +1159,8 @@ def __init__( :raise ValueError (or InvalidArgumentError): If ``model``'s :attr:`q_sqrt` is not rank 3 or if attempting to combine natural gradients with a :class:`~gpflow.optimizers.Scipy` optimizer. + :param encoder: Optional encoder with which to transform query points before + generating predictions. """ tf.debugging.assert_rank(model.q_sqrt, 3) @@ -1158,7 +1169,7 @@ def __init__( elif optimizer is None and use_natgrads: optimizer = BatchOptimizer(tf.optimizers.Adam(), batch_size=100, compile=True) - super().__init__(optimizer) + super().__init__(optimizer, encoder) check_optimizer(self.optimizer) @@ -1245,12 +1256,11 @@ def __repr__(self) -> str: def model(self) -> VGP: return self._model - @inherit_check_shapes - def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: - f_mean, f_var = self.predict(query_points) + def predict_y_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + f_mean, f_var = self.predict_encoded(query_points) return self.model.likelihood.predict_mean_and_var(query_points, f_mean, f_var) - def update(self, dataset: Dataset, *, jitter: float = DEFAULTS.JITTER) -> None: + def update_encoded(self, dataset: Dataset, *, jitter: float = DEFAULTS.JITTER) -> None: """ Update the model given the specified ``dataset``. Does not train the model. @@ -1262,7 +1272,7 @@ def update(self, dataset: Dataset, *, jitter: float = DEFAULTS.JITTER) -> None: update_vgp_data(self.model, (dataset.query_points, dataset.observations)) self.update_posterior_cache() - def optimize(self, dataset: Dataset) -> Optional[OptimizeResult]: + def optimize_encoded(self, dataset: Dataset) -> Optional[OptimizeResult]: """ :class:`VariationalGaussianProcess` has a custom `optimize` method that (optionally) permits alternating between standard optimization steps (for kernel parameters) and natural gradient @@ -1343,7 +1353,7 @@ def trajectory_sampler(self) -> TrajectorySampler[VariationalGaussianProcess]: return DecoupledTrajectorySampler(self, self._num_rff_features) - def covariance_between_points( + def covariance_between_points_encoded( self, query_points_1: TensorType, query_points_2: TensorType ) -> TensorType: r""" diff --git a/trieste/models/interfaces.py b/trieste/models/interfaces.py index ae265d4f85..223408ce1d 100644 --- a/trieste/models/interfaces.py +++ b/trieste/models/interfaces.py @@ -15,14 +15,15 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Callable, Generic, Optional, Sequence, TypeVar +from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, overload import gpflow import tensorflow as tf -from check_shapes import check_shapes -from typing_extensions import Protocol, runtime_checkable +from check_shapes import check_shapes, inherit_check_shapes +from typing_extensions import Protocol, final, runtime_checkable from ..data import Dataset +from ..space import EncoderFunction from ..types import TensorType from ..utils import DEFAULTS @@ -742,3 +743,164 @@ def covariance_with_top_fidelity(self, query_points: TensorType) -> TensorType: :return: The covariance with the top fidelity for the `query_points`, of shape [N, P] """ raise NotImplementedError + + +def encode_dataset(dataset: Dataset, encoder: EncoderFunction) -> Dataset: + """Return a new Dataset with the query points encoded using the given encoder.""" + return Dataset(encoder(dataset.query_points), dataset.observations) + + +class EncodedProbabilisticModel(ProbabilisticModel): + """A probabilistic model with an associated query point encoder. + + Classes that inherit from this (or the other associated mixins below) should implement the + relevant _encoded methods (e.g. predict_encoded instead of predict), to which the public + methods delegate after encoding their input. Take care to use the correct methods internally + to avoid encoding twice accidentally. + """ + + @property + @abstractmethod + def encoder(self) -> EncoderFunction | None: + """Query point encoder.""" + + @overload + def encode(self, points: TensorType) -> TensorType: + ... + + @overload + def encode(self, points: Dataset) -> Dataset: + ... + + def encode(self, points: Dataset | TensorType) -> Dataset | TensorType: + """Encode points or a Dataset using the query point encoder.""" + if self.encoder is None: + return points + elif isinstance(points, Dataset): + return encode_dataset(points, self.encoder) + else: + return self.encoder(points) + + @abstractmethod + def predict_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + """Implementation of predict on encoded query points.""" + + @abstractmethod + def sample_encoded(self, query_points: TensorType, num_samples: int) -> TensorType: + """Implementation of sample on encoded query points.""" + + @final + @inherit_check_shapes + def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + return self.predict_encoded(self.encode(query_points)) + + @final + @inherit_check_shapes + def sample(self, query_points: TensorType, num_samples: int) -> TensorType: + return self.sample_encoded(self.encode(query_points), num_samples) + + +class EncodedTrainableProbabilisticModel(EncodedProbabilisticModel, TrainableProbabilisticModel): + """A trainable probabilistic model with an associated query point encoder.""" + + @abstractmethod + def update_encoded(self, dataset: Dataset) -> None: + """Implementation of update on the encoded dataset.""" + + @abstractmethod + def optimize_encoded(self, dataset: Dataset) -> Any: + """Implementation of optimize on the encoded dataset.""" + + @final + def update(self, dataset: Dataset) -> None: + return self.update_encoded(self.encode(dataset)) + + @final + def optimize(self, dataset: Dataset) -> Any: + return self.optimize_encoded(self.encode(dataset)) + + +class EncodedSupportsPredictJoint(EncodedProbabilisticModel, SupportsPredictJoint): + """A probabilistic model that supports predict_joint with an associated query point encoder.""" + + @abstractmethod + def predict_joint_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + """Implementation of predict_joint on encoded query points.""" + + @final + @inherit_check_shapes + def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + return self.predict_joint_encoded(self.encode(query_points)) + + +class EncodedSupportsPredictY(EncodedProbabilisticModel, SupportsPredictY): + """A probabilistic model that supports predict_y with an associated query point encoder.""" + + @abstractmethod + def predict_y_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + """Implementation of predict_y on encoded query points.""" + + @final + @inherit_check_shapes + def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + return self.predict_y_encoded(self.encode(query_points)) + + +class EncodedFastUpdateModel(EncodedProbabilisticModel, FastUpdateModel): + """A fast update model with an associated query point encoder.""" + + @abstractmethod + def conditional_predict_f_encoded( + self, query_points: TensorType, additional_data: Dataset + ) -> tuple[TensorType, TensorType]: + """Implementation of conditional_predict_f on encoded query points.""" + + @abstractmethod + def conditional_predict_joint_encoded( + self, query_points: TensorType, additional_data: Dataset + ) -> tuple[TensorType, TensorType]: + """Implementation of conditional_predict_joint on encoded query points.""" + + @abstractmethod + def conditional_predict_f_sample_encoded( + self, query_points: TensorType, additional_data: Dataset, num_samples: int + ) -> TensorType: + """Implementation of conditional_predict_f_sample on encoded query points.""" + + @abstractmethod + def conditional_predict_y_encoded( + self, query_points: TensorType, additional_data: Dataset + ) -> tuple[TensorType, TensorType]: + """Implementation of conditional_predict_y on encoded query points.""" + + @final + def conditional_predict_f( + self, query_points: TensorType, additional_data: Dataset + ) -> tuple[TensorType, TensorType]: + return self.conditional_predict_f_encoded( + self.encode(query_points), self.encode(additional_data) + ) + + @final + def conditional_predict_joint( + self, query_points: TensorType, additional_data: Dataset + ) -> tuple[TensorType, TensorType]: + return self.conditional_predict_joint_encoded( + self.encode(query_points), self.encode(additional_data) + ) + + @final + def conditional_predict_f_sample( + self, query_points: TensorType, additional_data: Dataset, num_samples: int + ) -> TensorType: + return self.conditional_predict_f_sample_encoded( + self.encode(query_points), self.encode(additional_data), num_samples + ) + + @final + def conditional_predict_y( + self, query_points: TensorType, additional_data: Dataset + ) -> tuple[TensorType, TensorType]: + return self.conditional_predict_y_encoded( + self.encode(query_points), self.encode(additional_data) + ) diff --git a/trieste/models/keras/interface.py b/trieste/models/keras/interface.py index 18a4afbc48..6868193aa7 100644 --- a/trieste/models/keras/interface.py +++ b/trieste/models/keras/interface.py @@ -19,31 +19,38 @@ import tensorflow as tf import tensorflow_probability as tfp -from check_shapes import inherit_check_shapes from typing_extensions import Protocol, runtime_checkable from ...data import Dataset +from ...space import EncoderFunction from ...types import TensorType -from ..interfaces import ProbabilisticModel +from ..interfaces import EncodedProbabilisticModel, ProbabilisticModel from ..optimizer import KerasOptimizer -class KerasPredictor(ProbabilisticModel, ABC): +class KerasPredictor(EncodedProbabilisticModel, ABC): """ This is an interface for trainable wrappers of TensorFlow and Keras neural network models. """ - def __init__(self, optimizer: Optional[KerasOptimizer] = None): + def __init__( + self, + optimizer: Optional[KerasOptimizer] = None, + encoder: EncoderFunction | None = None, + ): """ :param optimizer: The optimizer wrapper containing the optimizer with which to train the model and arguments for the wrapper and the optimizer. The optimizer must be an instance of a :class:`~tf.optimizers.Optimizer`. Defaults to :class:`~tf.optimizers.Adam` optimizer with default parameters. + :param encoder: Optional encoder with which to transform query points before + generating predictions. :raise ValueError: If the optimizer is not an instance of :class:`~tf.optimizers.Optimizer`. """ if optimizer is None: optimizer = KerasOptimizer(tf.optimizers.Adam()) self._optimizer = optimizer + self._encoder = encoder if not isinstance(optimizer.optimizer, tf.optimizers.Optimizer): raise ValueError( @@ -62,12 +69,18 @@ def optimizer(self) -> KerasOptimizer: """The optimizer wrapper for training the model.""" return self._optimizer - @inherit_check_shapes - def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + @property + def encoder(self) -> EncoderFunction | None: + return self._encoder + + @encoder.setter + def encoder(self, encoder: EncoderFunction | None) -> None: + self._encoder = encoder + + def predict_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: return self.model.predict(query_points) - @inherit_check_shapes - def sample(self, query_points: TensorType, num_samples: int) -> TensorType: + def sample_encoded(self, query_points: TensorType, num_samples: int) -> TensorType: raise NotImplementedError( """ KerasPredictor does not implement sampling. Acquisition diff --git a/trieste/models/keras/models.py b/trieste/models/keras/models.py index 10e27acc6f..3e6cabc9c9 100644 --- a/trieste/models/keras/models.py +++ b/trieste/models/keras/models.py @@ -22,14 +22,14 @@ import tensorflow as tf import tensorflow_probability as tfp import tensorflow_probability.python.distributions as tfd -from check_shapes import inherit_check_shapes from tensorflow.python.keras.callbacks import Callback from ... import logging from ...data import Dataset +from ...space import EncoderFunction from ...types import TensorType from ...utils import flatten_leading_dims -from ..interfaces import HasTrajectorySampler, TrainableProbabilisticModel, TrajectorySampler +from ..interfaces import EncodedTrainableProbabilisticModel, HasTrajectorySampler, TrajectorySampler from ..optimizer import KerasOptimizer from ..utils import write_summary_data_based_metrics from .architectures import KerasEnsemble, MultivariateNormalTriL @@ -39,7 +39,10 @@ class DeepEnsemble( - KerasPredictor, TrainableProbabilisticModel, DeepEnsembleModel, HasTrajectorySampler + KerasPredictor, + EncodedTrainableProbabilisticModel, + DeepEnsembleModel, + HasTrajectorySampler, ): """ A :class:`~trieste.model.TrainableProbabilisticModel` wrapper for deep ensembles built using @@ -75,7 +78,7 @@ class DeepEnsemble( behaviour you would like, you will need to subclass the model and overwrite the :meth:`optimize` method. - Currently we do not support setting up the model with dictionary config. + Currently, we do not support setting up the model with dictionary config. """ def __init__( @@ -86,6 +89,7 @@ def __init__( diversify: bool = False, continuous_optimisation: bool = True, compile_args: Optional[Mapping[str, Any]] = None, + encoder: EncoderFunction | None = None, ) -> None: """ :param model: A Keras ensemble model with probabilistic networks as ensemble members. The @@ -98,12 +102,12 @@ def __init__( See https://keras.io/api/models/model_training_apis/#fit-method for a list of possible arguments. :param bootstrap: Sample with replacement data for training each network in the ensemble. - By default set to `False`. + By default, set to `False`. :param diversify: Whether to use quantiles from the approximate Gaussian distribution of the ensemble as trajectories instead of mean predictions when calling :meth:`trajectory_sampler`. This mode can be used to increase the diversity in case of optimizing very large batches of trajectories. By - default set to `False`. + default, set to `False`. :param continuous_optimisation: If True (default), the optimizer will keep track of the number of epochs across BO iterations and use this number as initial_epoch. This is essential to allow monitoring of model training across BO iterations. @@ -112,6 +116,8 @@ def __init__( See https://keras.io/api/models/model_training_apis/#compile-method for a list of possible arguments. The ``optimizer``, ``loss`` and ``metrics`` arguments must not be included. + :param encoder: Optional encoder with which to transform query points before + generating predictions. :raise ValueError: If ``model`` is not an instance of :class:`~trieste.models.keras.KerasEnsemble`, or ensemble has less than two base learners (networks), or `compile_args` contains disallowed arguments. @@ -119,7 +125,7 @@ def __init__( if model.ensemble_size < 2: raise ValueError(f"Ensemble size must be greater than 1 but got {model.ensemble_size}.") - super().__init__(optimizer) + super().__init__(optimizer, encoder) if compile_args is None: compile_args = {} @@ -244,8 +250,7 @@ def ensemble_distributions(self, query_points: TensorType) -> tuple[tfd.Distribu x_transformed: dict[str, TensorType] = self.prepare_query_points(query_points) return self._model.model(x_transformed) - @inherit_check_shapes - def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: + def predict_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]: r""" Returns mean and variance at ``query_points`` for the whole ensemble. @@ -308,14 +313,13 @@ def predict_ensemble(self, query_points: TensorType) -> tuple[TensorType, Tensor :return: The predicted mean and variance of the observations at the specified ``query_points`` for each member of the ensemble. """ - ensemble_distributions = self.ensemble_distributions(query_points) + ensemble_distributions = self.ensemble_distributions(self.encode(query_points)) predicted_means = tf.convert_to_tensor([dist.mean() for dist in ensemble_distributions]) predicted_vars = tf.convert_to_tensor([dist.variance() for dist in ensemble_distributions]) return predicted_means, predicted_vars - @inherit_check_shapes - def sample(self, query_points: TensorType, num_samples: int) -> TensorType: + def sample_encoded(self, query_points: TensorType, num_samples: int) -> TensorType: """ Return ``num_samples`` samples at ``query_points``. We use the mixture approximation in :meth:`predict` for ``query_points`` and sample ``num_samples`` times from a Gaussian @@ -327,7 +331,7 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType: [..., S, N] + E, where S is the number of samples. """ - predicted_means, predicted_vars = self.predict(query_points) + predicted_means, predicted_vars = self.predict_encoded(query_points) normal = tfp.distributions.Normal(predicted_means, tf.sqrt(predicted_vars)) samples = normal.sample(num_samples) @@ -345,7 +349,7 @@ def sample_ensemble(self, query_points: TensorType, num_samples: int) -> TensorT :return: The samples. For a predictive distribution with event shape E, this has shape [..., S, N] + E, where S is the number of samples. """ - ensemble_distributions = self.ensemble_distributions(query_points) + ensemble_distributions = self.ensemble_distributions(self.encode(query_points)) network_indices = sample_model_index(self.ensemble_size, num_samples) stacked_samples = [] @@ -365,7 +369,7 @@ def trajectory_sampler(self) -> TrajectorySampler[DeepEnsemble]: """ return DeepEnsembleTrajectorySampler(self, self._diversify) - def update(self, dataset: Dataset) -> None: + def update_encoded(self, dataset: Dataset) -> None: """ Neural networks are parametric models and do not need to update data. `TrainableProbabilisticModel` interface, however, requires an update method, so @@ -373,7 +377,7 @@ def update(self, dataset: Dataset) -> None: """ return - def optimize(self, dataset: Dataset) -> keras.callbacks.History: + def optimize_encoded(self, dataset: Dataset) -> keras.callbacks.History: """ Optimize the underlying Keras ensemble model with the specified ``dataset``. diff --git a/trieste/objectives/multi_objectives.py b/trieste/objectives/multi_objectives.py index f8708831a8..4a4a9bf04b 100644 --- a/trieste/objectives/multi_objectives.py +++ b/trieste/objectives/multi_objectives.py @@ -24,7 +24,7 @@ from check_shapes import check_shape, check_shapes from typing_extensions import Protocol -from ..space import Box +from ..space import Box, SearchSpaceType from ..types import TensorType from .single_objectives import ObjectiveTestProblem @@ -44,7 +44,7 @@ def __call__(self, n: int, seed: int | None = None) -> TensorType: @dataclass(frozen=True) -class MultiObjectiveTestProblem(ObjectiveTestProblem): +class MultiObjectiveTestProblem(ObjectiveTestProblem[SearchSpaceType]): """ Convenience container class for synthetic multi-objective test functions, containing a generator for the pareto optimal points, which can be used as a reference of performance @@ -73,7 +73,7 @@ def vlmop2(x: TensorType, d: int) -> TensorType: return tf.stack([y1, y2], axis=-1) -def VLMOP2(input_dim: int) -> MultiObjectiveTestProblem: +def VLMOP2(input_dim: int) -> MultiObjectiveTestProblem[Box]: """ The VLMOP2 problem, typically evaluated over :math:`[-2, 2]^d`. The idea pareto fronts lies on -1/sqrt(d) - 1/sqrt(d) and x1=...=xdim. @@ -152,7 +152,7 @@ def g(xM: TensorType) -> TensorType: ) -def DTLZ1(input_dim: int, num_objective: int) -> MultiObjectiveTestProblem: +def DTLZ1(input_dim: int, num_objective: int) -> MultiObjectiveTestProblem[Box]: """ The DTLZ1 problem, the idea pareto fronts lie on a linear hyper-plane. See :cite:`deb2002scalable` for details. @@ -212,7 +212,7 @@ def g(xM: TensorType) -> TensorType: ) -def DTLZ2(input_dim: int, num_objective: int) -> MultiObjectiveTestProblem: +def DTLZ2(input_dim: int, num_objective: int) -> MultiObjectiveTestProblem[Box]: """ The DTLZ2 problem, the idea pareto fronts lie on (part of) a unit hyper sphere. See :cite:`deb2002scalable` for details. diff --git a/trieste/objectives/multifidelity_objectives.py b/trieste/objectives/multifidelity_objectives.py index 4435f57d83..aa5c37942e 100644 --- a/trieste/objectives/multifidelity_objectives.py +++ b/trieste/objectives/multifidelity_objectives.py @@ -19,13 +19,13 @@ import numpy as np import tensorflow as tf -from ..space import Box, DiscreteSearchSpace, SearchSpace, TaggedProductSearchSpace +from ..space import Box, DiscreteSearchSpace, SearchSpace, SearchSpaceType, TaggedProductSearchSpace from ..types import TensorType from .single_objectives import SingleObjectiveTestProblem @dataclass(frozen=True) -class SingleObjectiveMultifidelityTestProblem(SingleObjectiveTestProblem): +class SingleObjectiveMultifidelityTestProblem(SingleObjectiveTestProblem[SearchSpaceType]): num_fidelities: int """The number of fidelities of test function""" diff --git a/trieste/objectives/single_objectives.py b/trieste/objectives/single_objectives.py index c0fd377403..a15ab70808 100644 --- a/trieste/objectives/single_objectives.py +++ b/trieste/objectives/single_objectives.py @@ -23,12 +23,12 @@ import math from dataclasses import dataclass from math import pi -from typing import Callable, Sequence +from typing import Callable, Generic, Sequence import tensorflow as tf from check_shapes import check_shapes -from ..space import Box, Constraint, LinearConstraint, NonlinearConstraint +from ..space import Box, Constraint, LinearConstraint, NonlinearConstraint, SearchSpaceType from ..types import TensorType ObjectiveTestFunction = Callable[[TensorType], TensorType] @@ -36,7 +36,7 @@ @dataclass(frozen=True) -class ObjectiveTestProblem: +class ObjectiveTestProblem(Generic[SearchSpaceType]): """ Convenience container class for synthetic objective test functions. """ @@ -47,7 +47,7 @@ class ObjectiveTestProblem: objective: ObjectiveTestFunction """The synthetic test function""" - search_space: Box + search_space: SearchSpaceType """The (continuous) search space of the test function""" @property @@ -62,7 +62,7 @@ def bounds(self) -> list[list[float]]: @dataclass(frozen=True) -class SingleObjectiveTestProblem(ObjectiveTestProblem): +class SingleObjectiveTestProblem(ObjectiveTestProblem[SearchSpaceType]): """ Convenience container class for synthetic single-objective test functions, including the global minimizers and minimum. diff --git a/trieste/space.py b/trieste/space.py index 326cc052b5..4a228460cf 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -518,6 +518,20 @@ def one_hot_encoder(space: SearchSpace) -> EncoderFunction: return space.one_hot_encoder if isinstance(space, HasOneHotEncoder) else lambda x: x +def one_hot_encoded_space(space: SearchSpace) -> SearchSpace: + "A bounded search space corresponding to the one-hot encoding of the given space." + + if isinstance(space, GeneralDiscreteSearchSpace) and isinstance(space, HasOneHotEncoder): + return DiscreteSearchSpace(space.one_hot_encoder(space.points)) + elif isinstance(space, TaggedProductSearchSpace): + spaces = [one_hot_encoded_space(space.get_subspace(tag)) for tag in space.subspace_tags] + return TaggedProductSearchSpace(spaces=spaces, tags=space.subspace_tags) + elif isinstance(space, HasOneHotEncoder): + raise NotImplementedError(f"Unsupported one-hot-encoded space {type(space)}") + else: + return space + + class CategoricalSearchSpace(GeneralDiscreteSearchSpace, HasOneHotEncoder): r""" A categorical :class:`SearchSpace` representing a finite set :math:`\mathcal{C}` of categories, @@ -586,6 +600,7 @@ def __init__( tags = [tuple(ts) for ts in category_names] self._tags = tags + self._dtype = dtype ranges = [tf.range(len(ts), dtype=dtype) for ts in tags] meshgrid = tf.meshgrid(*ranges, indexing="ij") @@ -633,7 +648,11 @@ def encoder(x: TensorType) -> TensorType: for ts in self.tags ] encoded = tf.concat( - [encoder(column) for encoder, column in zip(encoders, columns)], axis=1 + [ + tf.cast(encoder(column), dtype=self._dtype) + for encoder, column in zip(encoders, columns) + ], + axis=1, ) return unflatten(encoded)