LMC multitask-SVGP models can output a single task per input. (#1769)
* LMC multitask-SVGP models can output a single task per input.

If one defines an ApproximateGP model with an LMCVariationalStrategy,
there are now two different options for return types:

1. Calling `model(x)` will return a `... x N x num_tasks`
MultitaskMultivariateNormal distribution
2. Calling `model(x, task_indices=i)` will return a `... x N`
MultivariateNormal distribution, where `i` corresponds to the selected
task index for each input (see the usage sketch below).

[Closes #1285, #1433]
[Addresses #1743, #1765]
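
For context, here is a minimal end-to-end sketch of the two call modes. The model class, kernel choices, and sizes are illustrative (not part of this commit) and assume the standard GPyTorch variational API: a batched `CholeskyVariationalDistribution` and `VariationalStrategy` wrapped by an `LMCVariationalStrategy`.

```python
import torch
import gpytorch


class LMCModel(gpytorch.models.ApproximateGP):
    """Minimal LMC multitask SVGP: 3 latent GPs mixed into 4 output tasks."""

    def __init__(self, num_latents=3, num_tasks=4, num_inducing=16):
        # One set of inducing points per latent GP
        inducing_points = torch.rand(num_latents, num_inducing, 1)
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing, batch_shape=torch.Size([num_latents])
        )
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1,
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents]),
        )

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = LMCModel()
x = torch.randn(10, 1)

# Mode 1: every input is evaluated on every task -> MultitaskMultivariateNormal
all_tasks = model(x)
print(all_tasks.mean.shape)  # torch.Size([10, 4])

# Mode 2 (new in this commit): each input is evaluated only on its assigned task -> MultivariateNormal
task_indices = torch.randint(0, 4, (10,))
single_task = model(x, task_indices=task_indices)
print(single_task.mean.shape)  # torch.Size([10])
```

In single-task mode the model pairs naturally with a plain `GaussianLikelihood` (one observation per input), whereas all-tasks mode pairs with a `MultitaskGaussianLikelihood`, as the new test below illustrates.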
gpleiss authored Oct 1, 2021
1 parent f06004e commit fc2053b
Showing 7 changed files with 527 additions and 80 deletions.
14 changes: 9 additions & 5 deletions docs/source/variational.rst
@@ -98,19 +98,23 @@ These are special :obj:`~gpytorch.variational._VariationalStrategy` objects that
:obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distributions. Each of these objects
acts on a batch of approximate GPs.


:hidden:`IndependentMultitaskVariationalStrategy`
:hidden:`LMCVariationalStrategy`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: IndependentMultitaskVariationalStrategy
.. autoclass:: LMCVariationalStrategy
:members:

:hidden:`LMCVariationalStrategy`
.. automethod:: __call__


:hidden:`IndependentMultitaskVariationalStrategy`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LMCVariationalStrategy
.. autoclass:: IndependentMultitaskVariationalStrategy
:members:

.. automethod:: __call__


Variational Distributions
-----------------------------

Large diffs are not rendered by default.

63 changes: 48 additions & 15 deletions gpytorch/variational/independent_multitask_variational_strategy.py
@@ -2,17 +2,23 @@

import warnings

from ..distributions import MultitaskMultivariateNormal
import torch

from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
from ..lazy import RootLazyTensor
from ..module import Module
from ._variational_strategy import _VariationalStrategy


class IndependentMultitaskVariationalStrategy(_VariationalStrategy):
"""
IndependentMultitaskVariationalStrategy wraps an existing
:obj:`~gpytorch.variational.VariationalStrategy`
to produce a :obj:`~gpytorch.variational.MultitaskMultivariateNormal` distribution.
All outputs will be independent of one another.
:obj:`~gpytorch.variational.VariationalStrategy` to produce vector-valued (multi-task)
output distributions. Each task will be independent of one another.
The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
(if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
(if we wish to evaluate a single task for each input).
The base variational strategy is assumed to operate on a batch of GPs. One of the batch
dimensions corresponds to the multiple tasks.
@@ -43,19 +49,46 @@ def variational_params_initialized(self):
def kl_divergence(self):
return super().kl_divergence().sum(dim=-1)

def __call__(self, x, prior=False, **kwargs):
def __call__(self, x, task_indices=None, prior=False, **kwargs):
r"""
See :class:`LMCVariationalStrategy`.
"""
function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
if (
self.task_dim > 0
and self.task_dim > len(function_dist.batch_shape)
or self.task_dim < 0
and self.task_dim + len(function_dist.batch_shape) < 0
):
return MultitaskMultivariateNormal.from_repeated_mvn(function_dist, num_tasks=self.num_tasks)

if task_indices is None:
# Every data point will get an output for each task
if (
self.task_dim > 0
and self.task_dim > len(function_dist.batch_shape)
or self.task_dim < 0
and self.task_dim + len(function_dist.batch_shape) < 0
):
return MultitaskMultivariateNormal.from_repeated_mvn(function_dist, num_tasks=self.num_tasks)
else:
function_dist = MultitaskMultivariateNormal.from_batch_mvn(function_dist, task_dim=self.task_dim)
assert function_dist.event_shape[-1] == self.num_tasks
return function_dist

else:
function_dist = MultitaskMultivariateNormal.from_batch_mvn(function_dist, task_dim=self.task_dim)
assert function_dist.event_shape[-1] == self.num_tasks
return function_dist
# Each data point will get a single output corresponding to a single task

if self.task_dim > 0:
raise RuntimeError(f"task_dim must be a negative indexed batch dimension: got {self.task_dim}.")
num_batch = len(function_dist.batch_shape)
task_dim = num_batch + self.task_dim

# Reshape task_indices so that it broadcasts against the batch of per-task GPs
shape = list(function_dist.batch_shape + function_dist.event_shape)
shape[task_dim] = 1
task_indices = task_indices.expand(shape).squeeze(task_dim)

# Create a mask to choose specific task assignment
task_mask = torch.nn.functional.one_hot(task_indices, num_classes=self.num_tasks)
task_mask = task_mask.permute(*range(0, task_dim), *range(task_dim + 1, num_batch + 1), task_dim)

mean = (function_dist.mean * task_mask).sum(task_dim)
covar = (function_dist.lazy_covariance_matrix * RootLazyTensor(task_mask[..., None])).sum(task_dim)
return MultivariateNormal(mean, covar)
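
Conceptually, the single-task branch above selects, for each data point, the mean and covariance entries of its assigned task from the batch of independent per-task GPs; since tasks are independent, cross-covariances between points assigned to different tasks are zero. Below is a dense-tensor sketch of the mask trick (plain torch, no LazyTensors; shapes and values are illustrative only):

```python
import torch

torch.manual_seed(0)
num_tasks, N = 3, 5

# A batch of independent per-task GPs evaluated at the same N inputs:
# means[t] is the mean of task t, covars[t] its N x N covariance.
means = torch.randn(num_tasks, N)
A = torch.randn(num_tasks, N, N)
covars = A @ A.transpose(-1, -2) + 1e-3 * torch.eye(N)

task_indices = torch.randint(0, num_tasks, (N,))  # one task per data point

# task_mask[t, n] == 1 iff data point n is assigned to task t
task_mask = torch.nn.functional.one_hot(task_indices, num_classes=num_tasks).T.to(means.dtype)

# mean[n] = means[i_n, n]
mean = (means * task_mask).sum(0)
# covar[n, m] = covars[i_n, n, m] if i_n == i_m, else 0
covar = (covars * (task_mask.unsqueeze(-1) @ task_mask.unsqueeze(-2))).sum(0)

# Sanity check for the mean against direct indexing
assert torch.allclose(mean, means[task_indices, torch.arange(N)])
```

The `RootLazyTensor(task_mask[..., None])` in the real implementation plays the role of the outer product `task_mask.unsqueeze(-1) @ task_mask.unsqueeze(-2)` above, but keeps the covariance lazy.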


class MultitaskVariationalStrategy(IndependentMultitaskVariationalStrategy):
152 changes: 116 additions & 36 deletions gpytorch/variational/lmc_variational_strategy.py
@@ -2,12 +2,34 @@

import torch

from ..distributions import MultitaskMultivariateNormal
from ..lazy import KroneckerProductLazyTensor, MatmulLazyTensor
from .. import settings
from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
from ..lazy import KroneckerProductLazyTensor, RootLazyTensor
from ..module import Module
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.interpolation import left_interp
from ._variational_strategy import _VariationalStrategy


def _select_lmc_coefficients(lmc_coefficients: torch.Tensor, indices: torch.LongTensor) -> torch.Tensor:
"""
Given a list of indices for ... x N datapoints,
select the row of lmc_coefficients that corresponds to each datapoint.
lmc_coefficients: torch.Tensor ... x num_latents x ... x num_tasks
indices: torch.Tensor ... x N
"""
batch_shape = _mul_broadcast_shape(lmc_coefficients.shape[:-1], indices.shape[:-1])

# We will use the left_interp helper to do the indexing
lmc_coefficients = lmc_coefficients.expand(*batch_shape, lmc_coefficients.shape[-1])[..., None]
indices = indices.expand(*batch_shape, indices.shape[-1])[..., None]
res = left_interp(
indices, torch.ones(indices.shape, dtype=torch.long, device=indices.device), lmc_coefficients,
).squeeze(-1)
return res
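
A simplified illustration of what this helper computes (one latent batch dimension, plain `torch.gather` instead of `left_interp`; a sketch, not a drop-in replacement for the broadcasting the real helper performs):

```python
import torch

torch.manual_seed(0)
num_latents, num_tasks, N = 3, 4, 6

lmc_coefficients = torch.randn(num_latents, num_tasks)  # one row of task weights per latent GP
indices = torch.randint(0, num_tasks, (N,))             # task index for each of the N inputs

# What _select_lmc_coefficients returns: res[q, n] = lmc_coefficients[q, indices[n]]
expected = lmc_coefficients[:, indices]                 # shape: num_latents x N

# The same selection written with torch.gather
gathered = torch.gather(lmc_coefficients, dim=-1, index=indices.expand(num_latents, N))
assert torch.equal(expected, gathered)
```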


class LMCVariationalStrategy(_VariationalStrategy):
r"""
LMCVariationalStrategy is an implementation of the "Linear Model of Coregionalization"
@@ -20,8 +42,11 @@ class LMCVariationalStrategy(_VariationalStrategy):
f_{\text{task } i}( \mathbf x) = \sum_{q=1}^Q a_i^{(q)} g^{(q)} ( \mathbf x )
LMCVariationalStrategy wraps an existing :obj:`~gpytorch.variational.VariationalStrategy`
to produce a :obj:`~gpytorch.variational.MultitaskMultivariateNormal` distribution.
LMCVariationalStrategy wraps an existing :obj:`~gpytorch.variational.VariationalStrategy`.
The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
(if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
(if we wish to evaluate a single task for each input).
The base variational strategy is assumed to operate on a multi-batch of GPs, where one
of the batch dimensions corresponds to the latent function dimension.
@@ -35,13 +60,6 @@ class LMCVariationalStrategy(_VariationalStrategy):
batch shape. This would correspond to each of the latent functions having different kernels
or the same kernel, respectively.
:param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
:param int num_tasks: The total number of tasks (output functions)
:param int num_latents: The total number of latent functions in each group
:param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
**Must be negative indexed**
:type latent_dim: `int` < 0
Example:
>>> class LMCMultitaskGP(gpytorch.models.ApproximateGP):
>>> '''
@@ -74,7 +92,13 @@ class LMCVariationalStrategy(_VariationalStrategy):
>>> batch_shape=torch.Size([3]),
>>> )
>>>
>>> # Model output: n x 5
:param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
:param int num_tasks: The total number of tasks (output functions)
:param int num_latents: The total number of latent functions in each group
:param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
**Must be negative indexed**
:type latent_dim: `int` < 0
"""

def __init__(
@@ -120,28 +144,84 @@ def variational_params_initialized(self):
def kl_divergence(self):
return super().kl_divergence().sum(dim=self.latent_dim)

def __call__(self, x, prior=False, **kwargs):
function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
lmc_coefficients = self.lmc_coefficients.expand(*function_dist.batch_shape, self.lmc_coefficients.size(-1))
num_batch = len(function_dist.batch_shape)
num_dim = num_batch + len(function_dist.event_shape)
latent_dim = num_batch + self.latent_dim if self.latent_dim is not None else None

# Mean
mean = function_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
mean = mean @ lmc_coefficients.permute(
*range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
)

# Covar
covar = function_dist.lazy_covariance_matrix
lmc_factor = MatmulLazyTensor(lmc_coefficients.unsqueeze(-1), lmc_coefficients.unsqueeze(-2))
covar = KroneckerProductLazyTensor(covar, lmc_factor)
covar = covar.sum(latent_dim)

# Add a bit of jitter to make the covar PD
covar = covar.add_jitter(1e-6)

# Done!
function_dist = MultitaskMultivariateNormal(mean, covar)
def __call__(self, x, task_indices=None, prior=False, **kwargs):
r"""
Computes the variational (or prior) distribution
:math:`q( \mathbf f \mid \mathbf X)` (or :math:`p( \mathbf f \mid \mathbf X)`).
There are two modes:
1. Compute **all tasks** for all inputs.
If this is the case, the :attr:`task_indices` attribute should be None.
The return type will be a (... x N x num_tasks)
:class:`~gpytorch.distributions.MultitaskMultivariateNormal`.
2. Compute **one task** per input.
If this is the case, the (... x N) :attr:`task_indices` tensor should contain
the indices of each input's assigned task.
The return type will be a (... x N)
:class:`~gpytorch.distributions.MultivariateNormal`.
:param x: Input locations to evaluate variational strategy
:type x: torch.Tensor (... x N x D)
:param task_indices: (Default: None) Task index associated with each input.
If this **is not** provided, then the returned distribution evaluates every input on every task
(returns :class:`~gpytorch.distributions.MultitaskMultivariateNormal`).
If this **is** provided, then the returned distribution evaluates each input only on its assigned task.
(returns :class:`~gpytorch.distributions.MultivariateNormal`).
:type task_indices: torch.Tensor (... x N), optional
:param prior: (Default: False) If False, returns the variational distribution
:math:`q( \mathbf f \mid \mathbf X)`.
If True, returns the prior distribution
:math:`p( \mathbf f \mid \mathbf X)`.
:type prior: bool
:return: :math:`q( \mathbf f \mid \mathbf X)` (or the prior),
either for all tasks (if `task_indices == None`)
or for a specific task (if `task_indices != None`).
:rtype: ~gpytorch.distributions.MultitaskMultivariateNormal (... x N x num_tasks)
or ~gpytorch.distributions.MultivariateNormal (... x N)
"""
latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
num_batch = len(latent_dist.batch_shape)
latent_dim = num_batch + self.latent_dim

if task_indices is None:
num_dim = num_batch + len(latent_dist.event_shape)

# Every data point will get an output for each task
# Therefore, we will set up the lmc_coefficients shape for a matmul
lmc_coefficients = self.lmc_coefficients.expand(*latent_dist.batch_shape, self.lmc_coefficients.size(-1))

# Mean: ... x N x num_tasks
latent_mean = latent_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
mean = latent_mean @ lmc_coefficients.permute(
*range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
)

# Covar: ... x (N x num_tasks) x (N x num_tasks)
latent_covar = latent_dist.lazy_covariance_matrix
lmc_factor = RootLazyTensor(lmc_coefficients.unsqueeze(-1))
covar = KroneckerProductLazyTensor(latent_covar, lmc_factor).sum(latent_dim)
# Add a bit of jitter to make the covar PD
covar = covar.add_jitter(settings.cholesky_jitter.value(dtype=mean.dtype))

# Done!
function_dist = MultitaskMultivariateNormal(mean, covar)

else:
# Each data point will get a single output corresponding to a single task
# Therefore, we will select the appropriate lmc coefficients for each task
lmc_coefficients = _select_lmc_coefficients(self.lmc_coefficients, task_indices)

# Mean: ... x N
mean = (latent_dist.mean * lmc_coefficients).sum(latent_dim)

# Covar: ... x N x N
latent_covar = latent_dist.lazy_covariance_matrix
lmc_factor = RootLazyTensor(lmc_coefficients.unsqueeze(-1))
covar = (latent_covar * lmc_factor).sum(latent_dim)
# Add a bit of jitter to make the covar PD
covar = covar.add_jitter(settings.cholesky_jitter.value(dtype=mean.dtype))

# Done!
function_dist = MultivariateNormal(mean, covar)

return function_dist
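
To make the two branches of `__call__` concrete, here is a dense-tensor sketch (plain torch, no LazyTensors, no jitter; sizes are illustrative) of the LMC mixing in all-tasks mode, the per-input selection in single-task mode, and how the two relate:

```python
import torch

torch.manual_seed(0)
num_latents, num_tasks, N = 2, 3, 4

# Densified per-latent means and covariances from the base variational strategy
latent_mean = torch.randn(num_latents, N)
A = torch.randn(num_latents, N, N)
latent_covar = A @ A.transpose(-1, -2) + 1e-3 * torch.eye(N)

lmc = torch.randn(num_latents, num_tasks)          # mixing coefficients a_i^{(q)}
task_indices = torch.randint(0, num_tasks, (N,))   # one task per input

# All-tasks mode: f_task_i(x) = sum_q a_i^{(q)} g^{(q)}(x)
mean_all = torch.einsum("qn,qt->nt", latent_mean, lmc)               # N x num_tasks
covar_all = torch.einsum("qnm,qs,qt->nsmt", latent_covar, lmc, lmc)  # sum_q K_q kron (a_q a_q^T)
covar_all = covar_all.reshape(N * num_tasks, N * num_tasks)          # interleaved task layout

# Single-task mode: keep only each input's assigned task
coeffs = lmc[:, task_indices]                                        # num_latents x N
mean_single = (latent_mean * coeffs).sum(0)                          # N
covar_single = (latent_covar * (coeffs.unsqueeze(-1) @ coeffs.unsqueeze(-2))).sum(0)  # N x N

# The single-task quantities are slices of the all-tasks distribution
rows = torch.arange(N) * num_tasks + task_indices
assert torch.allclose(mean_single, mean_all[torch.arange(N), task_indices])
assert torch.allclose(covar_single, covar_all[rows][:, rows], atol=1e-5)
```

The single-task path therefore avoids forming the `(N * num_tasks) x (N * num_tasks)` multitask covariance when only one task per input is needed.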
54 changes: 52 additions & 2 deletions test/examples/test_lmc_svgp_regression.py
@@ -7,7 +7,7 @@

import gpytorch
import torch
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.likelihoods import GaussianLikelihood, MultitaskGaussianLikelihood


# Batch training test: Let's learn hyperparameters on a sine dataset, but test on a sine dataset and a cosine dataset
@@ -75,7 +75,6 @@ def tearDown(self):
torch.set_rng_state(self.rng_state)

def test_train_and_eval(self):
# We're manually going to set the hyperparameters to something they shouldn't be
likelihood = MultitaskGaussianLikelihood(num_tasks=4)
model = LMCModel()

@@ -132,6 +131,57 @@ def test_train_and_eval(self):
self.assertEqual(lower.shape, train_y.shape)
self.assertEqual(upper.shape, train_y.shape)

def test_indexed_train_and_eval(self):
likelihood = GaussianLikelihood()
model = LMCModel()

# Find optimal model hyperparameters
model.train()
likelihood.train()
optimizer = torch.optim.Adam([
{'params': model.parameters()},
{'params': likelihood.parameters()},
], lr=0.01)

# Our loss object. We're using the VariationalELBO, which essentially just computes the ELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))

# Create some task indices
arange = torch.arange(train_x.size(0))
train_i = torch.rand(train_x.size(0)).mul(4).floor().long()

# We use more CG iterations here because the preconditioner introduced in the NeurIPS paper seems to be less
# effective for VI.
for i in range(400):
# Within each iteration, we will go over each minibatch of data
optimizer.zero_grad()
output = model(train_x, task_indices=train_i)
loss = -mll(output, train_y[arange, train_i])
loss.backward()
optimizer.step()

for param in model.parameters():
self.assertTrue(param.grad is not None)
self.assertGreater(param.grad.norm().item(), 0)
for param in likelihood.parameters():
self.assertTrue(param.grad is not None)
self.assertGreater(param.grad.norm().item(), 0)

# Test the model
model.eval()
likelihood.eval()

# Make predictions for both sets of test points, and check MAEs.
with torch.no_grad(), gpytorch.settings.max_eager_kernel_size(1):
predictions = likelihood(model(train_x, task_indices=train_i))
mean_abs_error = torch.mean(torch.abs(train_y[arange, train_i] - predictions.mean))
self.assertLess(mean_abs_error.squeeze().item(), 0.15)

# Smoke test for getting predictive uncertainties
lower, upper = predictions.confidence_region()
self.assertEqual(lower.shape, train_i.shape)
self.assertEqual(upper.shape, train_i.shape)


if __name__ == "__main__":
unittest.main()
