Trajectory sampling additive kernels #882

Open · wants to merge 5 commits into base branch khurram/traj_sampling_add_kernels

This diff makes the random-Fourier-feature trajectory samplers respect active_dims on the sub-kernels of additive (Sum) kernels: feature weights and biases may now be lists of variables, one per sub-kernel, and the GPflow model builders gain an optional kernel argument for supplying such kernels.
setup.py: 2 changes (1 addition & 1 deletion)
@@ -41,7 +41,7 @@
"absl-py",
"dill<0.3.6",
"gpflow>=2.9.2",
"gpflux@ git+https://github.com/secondmind-labs/GPflux.git@khurram/rff_additive_kernels",
"gpflux@ git+https://github.com/secondmind-labs/GPflux.git@uri/rff_additive_kernels",
"numpy",
"tensorflow>=2.5,<2.17; platform_system!='Darwin' or platform_machine!='arm64'",
"tensorflow-macos>=2.5,<2.17; platform_system=='Darwin' and platform_machine=='arm64'",
tests/unit/models/gpflow/test_sampler.py: 45 changes (45 additions & 0 deletions)
@@ -1062,3 +1062,48 @@ def test_qmc_samples_shapes__invalid_values(
) -> None:
with pytest.raises(expected_error_type):
qmc_normal_samples(num_samples=num_samples, n_sample_dim=n_sample_dim, skip=skip)


+@pytest.mark.parametrize(
+    "sampler_type", [RandomFourierFeatureTrajectorySampler, DecoupledTrajectorySampler]
+)
+@random_seed
+def test_trajectory_sampler_respects_active_dims_for_additive_kernels(
+    sampler_type: Type[RandomFourierFeatureTrajectorySampler],
+) -> None:
+    # Test that the trajectory sampler respects the active_dims settings for an additive kernel.
+    num_points = 10
+    query_points = tf.random.uniform((num_points, 2), dtype=tf.float64)
+    dataset = Dataset(query_points, quadratic(query_points))
+
+    model = GaussianProcessRegression(gpr_model(dataset.query_points, dataset.observations))
+    model.model.kernel = gpflow.kernels.Sum(
+        [
+            # one subkernel varies a lot, the other's almost constant
+            gpflow.kernels.Matern52(variance=10000, active_dims=[0]),
+            gpflow.kernels.Matern52(lengthscales=10000, active_dims=[1]),
+        ]
+    )
+
+    trajectory_sampler = sampler_type(model)
+    trajectory = trajectory_sampler.get_trajectory()
+
+    batch_size = 2
+
+    def with_batching(x_test: tf.Tensor) -> tf.Tensor:
+        x_test_with_batching = tf.expand_dims(x_test, -2)
+        return tf.tile(x_test_with_batching, [1, batch_size, 1])  # [N, B, D]
+
+    # The output should be constant when we only vary the second dimension
+    x_rnd = tf.random.uniform((num_points, 2), dtype=tf.float64)
+    x_fix = tf.constant(0.5, shape=(num_points, 2), dtype=tf.float64)
+    x_test = tf.where([False, True], x_rnd, x_fix)
+    model_eval = trajectory(with_batching(x_test))
+    assert model_eval.shape == (num_points, batch_size, 1)
+    assert tf.math.reduce_max(tf.math.reduce_std(model_eval, axis=0)) < 2e-4
+
+    # But not so when we only vary the first
+    x_test = tf.where([True, False], x_rnd, x_fix)
+    model_eval = trajectory(with_batching(x_test))
+    assert model_eval.shape == (num_points, batch_size, 1)
+    assert tf.math.reduce_max(tf.math.reduce_std(model_eval, axis=0)) > 1
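
The masking idiom in this test is compact enough to deserve a gloss. A minimal standalone sketch (shapes chosen only for illustration): the [2]-shaped boolean mask broadcasts against the [N, 2] inputs, so each input dimension (column) is taken wholesale from either the random or the constant tensor.

import tensorflow as tf

x_rnd = tf.random.uniform((4, 2), dtype=tf.float64)  # varies row to row
x_fix = tf.constant(0.5, shape=(4, 2), dtype=tf.float64)  # constant everywhere

# Mask [False, True] broadcasts over rows: column 0 comes from x_fix,
# column 1 from x_rnd, i.e. only the second input dimension varies.
x_test = tf.where([False, True], x_rnd, x_fix)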
tests/unit/models/gpflux/test_sampler.py: 23 changes (14 additions & 9 deletions)
@@ -24,7 +24,7 @@

from __future__ import annotations

-from typing import Callable, Tuple
+from typing import Callable, Sequence, Tuple
from unittest.mock import patch

import gpflow.kernels
@@ -508,8 +508,15 @@ def test_dgp_decoupled_layer_update_updates(

evals_1 = decoupled_layer(xs)

-    original_W = decoupled_layer._feature_functions.W.value().numpy()
-    original_b = decoupled_layer._feature_functions.b.value().numpy()
+    def get_values(x: tf.Variable | Sequence[tf.Variable]) -> Sequence[tf.Tensor]:
+        # weights and biases are either a single variable or a list of variables
+        if isinstance(x, tf.Variable):
+            x = [x]
+        return [v.value().numpy() for v in x]
+
+    original_W = get_values(decoupled_layer._feature_functions.W)
+    original_b = get_values(decoupled_layer._feature_functions.b)
+
for _ in range(5):
x_train = tf.random.uniform([20, 2], minval=-10.0, maxval=10.0, dtype=tf.float64)
y_train = tf.random.normal([20, 1], dtype=tf.float64)
@@ -522,9 +529,7 @@ def test_dgp_decoupled_layer_update_updates(
npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(evals_1 - evals_new)))

# Check that RFF weights change
-    npt.assert_array_less(
-        1e-2, tf.reduce_sum(tf.abs(original_b - decoupled_layer._feature_functions.b))
-    )
-    npt.assert_array_less(
-        1e-2, tf.reduce_sum(tf.abs(original_W - decoupled_layer._feature_functions.W))
-    )
+    for old_b, new_b in zip(original_b, get_values(decoupled_layer._feature_functions.b)):
+        npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(old_b - new_b)))
+    for old_W, new_W in zip(original_W, get_values(decoupled_layer._feature_functions.W)):
+        npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(old_W - new_W)))
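
For orientation, the reason original_W and original_b may now be lists: with an additive (Sum) kernel, the feature functions appear to keep one set of Fourier weights per sub-kernel, each sized to that sub-kernel's active_dims. A hedged sketch of that layout (illustrative shapes and distributions, not the library's exact internals):

import math

import gpflow
import tensorflow as tf

kernel = gpflow.kernels.Sum(
    [
        gpflow.kernels.Matern52(active_dims=[0]),
        gpflow.kernels.Matern52(active_dims=[1]),
    ]
)
n_components = 100
# one [F, d_k] weight matrix per sub-kernel, d_k = number of active dims
W = [
    tf.Variable(tf.random.normal([n_components, len(k.active_dims)], dtype=tf.float64))
    for k in kernel.kernels
]
# one [F] bias vector per sub-kernel, uniform on [0, 2*pi] as usual for RFF
b = [
    tf.Variable(tf.random.uniform([n_components], 0.0, 2 * math.pi, dtype=tf.float64))
    for _ in kernel.kernels
]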
trieste/models/gpflow/builders.py: 25 changes (21 additions & 4 deletions)
@@ -163,6 +163,7 @@ def build_sgpr(
trainable_likelihood: bool = False,
num_inducing_points: Optional[int] = None,
trainable_inducing_points: bool = False,
+    kernel: Optional[gpflow.kernels.Kernel] = None,
) -> SGPR:
"""
Build a :class:`~gpflow.models.SGPR` model with sensible initial parameters and
@@ -205,11 +206,14 @@
``MAX_NUM_INDUCING_POINTS``, whichever is smaller.
:param trainable_inducing_points: If set to `True` inducing points will be set to
be trainable. This option should be used with caution. By default set to `False`.
+    :param kernel: The kernel to use in the model, defaults to letting the function set up a
+        :class:`~gpflow.kernels.Matern52` kernel.
:return: An :class:`~gpflow.models.SGPR` model.
"""
empirical_mean, empirical_variance, _ = _get_data_stats(data)

-    kernel = _get_kernel(empirical_variance, search_space, kernel_priors, kernel_priors)
+    if kernel is None:
+        kernel = _get_kernel(empirical_variance, search_space, kernel_priors, kernel_priors)
mean = _get_mean_function(empirical_mean)

inducing_points = gpflow.inducing_variables.InducingPoints(
@@ -228,10 +232,11 @@

def build_vgp_classifier(
data: Dataset,
-    search_space: SearchSpace,
+    search_space: Optional[SearchSpace] = None,
kernel_priors: bool = True,
noise_free: bool = False,
kernel_variance: Optional[float] = None,
+    kernel: Optional[gpflow.kernels.Kernel] = None,
) -> VGP:
"""
Build a :class:`~gpflow.models.VGP` binary classification model with sensible initial
@@ -264,6 +269,8 @@
certain value. If left unspecified (default), the kernel variance is set to
``CLASSIFICATION_KERNEL_VARIANCE_NOISE_FREE`` in the ``noise_free`` case and to
``CLASSIFICATION_KERNEL_VARIANCE`` otherwise.
+    :param kernel: The kernel to use in the model, defaults to letting the function set up a
+        :class:`~gpflow.kernels.Matern52` kernel.
:return: A :class:`~gpflow.models.VGP` model.
"""
if kernel_variance is not None:
@@ -281,7 +288,13 @@
add_prior_to_variance = kernel_priors

model_likelihood = gpflow.likelihoods.Bernoulli()
-    kernel = _get_kernel(variance, search_space, kernel_priors, add_prior_to_variance)
+    if kernel is None:
+        if search_space is None:
+            raise ValueError(
+                "'build_vgp_classifier' function requires one of 'search_space' or 'kernel'"
+                " arguments, but got neither"
+            )
+        kernel = _get_kernel(variance, search_space, kernel_priors, add_prior_to_variance)
mean = _get_mean_function(tf.constant(0.0, dtype=gpflow.default_float()))

model = VGP(data.astuple(), kernel, model_likelihood, mean_function=mean)
@@ -300,6 +313,7 @@
trainable_likelihood: bool = False,
num_inducing_points: Optional[int] = None,
trainable_inducing_points: bool = False,
+    kernel: Optional[gpflow.kernels.Kernel] = None,
) -> SVGP:
"""
Build a :class:`~gpflow.models.SVGP` model with sensible initial parameters and
@@ -348,6 +362,8 @@
``MAX_NUM_INDUCING_POINTS``, whichever is smaller.
:param trainable_inducing_points: If set to `True` inducing points will be set to
be trainable. This option should be used with caution. By default set to `False`.
+    :param kernel: The kernel to use in the model, defaults to letting the function set up a
+        :class:`~gpflow.kernels.Matern52` kernel.
:return: An :class:`~gpflow.models.SVGP` model.
"""
empirical_mean, empirical_variance, num_data_points = _get_data_stats(data)
@@ -359,7 +375,8 @@
else:
model_likelihood = gpflow.likelihoods.Gaussian()

-    kernel = _get_kernel(empirical_variance, search_space, kernel_priors, kernel_priors)
+    if kernel is None:
+        kernel = _get_kernel(empirical_variance, search_space, kernel_priors, kernel_priors)
mean = _get_mean_function(empirical_mean)

inducing_points = _get_inducing_points(search_space, num_inducing_points)
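
A usage sketch for the new kernel argument (here data and search_space stand for a trieste Dataset and SearchSpace built as elsewhere in this repo, and the kernel choice is illustrative):

import gpflow

from trieste.models.gpflow.builders import build_svgp, build_vgp_classifier

additive_kernel = gpflow.kernels.Sum(
    [
        gpflow.kernels.Matern52(active_dims=[0]),
        gpflow.kernels.Matern52(active_dims=[1]),
    ]
)

# Overrides the default Matern52 that the builder would otherwise construct.
model = build_svgp(data, search_space, kernel=additive_kernel)

# build_vgp_classifier may now omit the search space when a kernel is given;
# passing neither raises a ValueError.
classifier = build_vgp_classifier(data, kernel=additive_kernel)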
trieste/models/gpflow/sampler.py: 22 changes (13 additions & 9 deletions)
@@ -19,6 +19,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
+from itertools import cycle
from typing import Callable, Optional, Tuple, TypeVar, Union, cast

import tensorflow as tf
@@ -793,7 +794,6 @@ def __init__(
dummy_X = model.get_inducing_variables()[0][0:1, :]
else:
dummy_X = model.get_internal_data().query_points[0:1, :]
-        dummy_X = self.kernel.slice(dummy_X, None)[0]  # Keep only the active dims from the kernel.

# Always build the weights and biases. This is important for saving the trajectory (using
# tf.saved_model.save) before it has been used.
@@ -803,15 +803,19 @@ def resample(self) -> None:
"""
Resample weights and biases
"""
-        self.b.assign(self._bias_init(tf.shape(self.b), dtype=self._dtype))
-        self.W.assign(self._weights_init(tf.shape(self.W), dtype=self._dtype))
+        if isinstance(self.b, tf.Variable):
+            self.b.assign(self._bias_init(tf.shape(self.b), dtype=self._dtype))
+        else:
+            tf.debugging.Assert(isinstance(self.b, list), [])
+            for b in self.b:
+                b.assign(self._bias_init(tf.shape(b), dtype=self._dtype))

-    def call(self, inputs: TensorType) -> TensorType:  # [N, D] -> [N, F] or [L, N, F]
-        """
-        Evaluate the basis functions at ``inputs``
-        """
-        inputs = self.kernel.slice(inputs, None)[0]  # Keep only active dims from the kernel
-        return super().call(inputs)  # [N, F] or [L, N, F]
+        if isinstance(self.W, tf.Variable):
+            self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), self._dtype))
+        else:
+            tf.debugging.Assert(isinstance(self.W, list), [])
+            for W, k in zip(self.W, cycle(self.sub_kernels)):
+                W.assign(self._weights_init(k)(tf.shape(W), self._dtype))


class ResampleableDecoupledFeatureFunctions(ResampleableRandomFourierFeatureFunctions):
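
The zip-with-cycle pairing in resample above is the heart of the additive-kernel handling: each weight matrix is redrawn from its own sub-kernel's distribution. A toy sketch of just the pairing (stand-in names, not the class's real state); cycle() repeats the sub-kernel list, which presumably keeps the pairing aligned whenever the weight list is a whole multiple of the number of sub-kernels:

from itertools import cycle

sub_kernels = ["matern_dim0", "matern_dim1"]  # stand-ins for gpflow kernels
weights = ["W0", "W1", "W2", "W3"]  # e.g. per-output copies of per-kernel weights

for w, k in zip(weights, cycle(sub_kernels)):
    print(w, "<-", k)  # pairs as W0<-matern_dim0, W1<-matern_dim1, W2<-matern_dim0, W3<-matern_dim1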
trieste/models/gpflux/sampler.py: 20 changes (13 additions & 7 deletions)
@@ -15,6 +15,7 @@
from __future__ import annotations

from abc import ABC
+from itertools import cycle
from typing import Callable, cast

import gpflow.kernels
@@ -439,19 +440,24 @@ def __init__(self, layer: GPLayer, n_components: int):
dummy_X = inducing_points[0:1, :]

self.__call__(dummy_X)
-        self.b: TensorType = tf.Variable(self.b)
-        self.W: TensorType = tf.Variable(self.W)

def resample(self) -> None:
"""
Resample weights and biases.
"""
-        if not hasattr(self, "_bias_init"):
-            self.b.assign(self._sample_bias(tf.shape(self.b), dtype=self._dtype))
-            self.W.assign(self._sample_weights(tf.shape(self.W), dtype=self._dtype))
-        else:
+        if isinstance(self.b, tf.Variable):
            self.b.assign(self._bias_init(tf.shape(self.b), dtype=self._dtype))
-            self.W.assign(self._weights_init(tf.shape(self.W), dtype=self._dtype))
+        else:
+            tf.debugging.Assert(isinstance(self.b, list), [])
+            for b in self.b:
+                b.assign(self._bias_init(tf.shape(b), dtype=self._dtype))
+
+        if isinstance(self.W, tf.Variable):
+            self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), self._dtype))
+        else:
+            tf.debugging.Assert(isinstance(self.W, list), [])
+            for W, k in zip(self.W, cycle(self.sub_kernels)):
+                W.assign(self._weights_init(k)(tf.shape(W), self._dtype))

def __call__(self, x: TensorType) -> TensorType: # [N, D] -> [N, L + M] or [P, N, L + M]
"""
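
As background for why resample redraws both tensors: in random Fourier features (Rahimi and Recht, 2007) the feature map is phi(x) = sqrt(2 * variance / F) * cos(x W^T + b), with the phases b drawn uniformly from [0, 2*pi] and the rows of W from the kernel's spectral density, so a fresh function sample needs both redrawn. A NumPy sketch of the standard construction (not the library's exact code):

import numpy as np

rng = np.random.default_rng(0)
n_components, input_dim = 100, 2

b = rng.uniform(0.0, 2.0 * np.pi, size=n_components)  # random phases
W = rng.standard_normal((n_components, input_dim))  # RBF spectral density, unit lengthscale

x = rng.standard_normal((5, input_dim))
phi = np.sqrt(2.0 / n_components) * np.cos(x @ W.T + b)  # [5, F] feature map
# phi @ phi.T approximates the kernel matrix for a unit-variance RBF kernel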