Question about LMCVariationalStrategy #1552
Unanswered
sjiang2018
asked this question in
Q&A
Replies: 1 comment
-
# I think the easiest thing to do is to set:
num_tasks = 5
# 3 latent GPs for the independent tasks y1-y3, plus 1 latent GP shared by y4-y5.
# NOTE(review): with a single shared latent, y4 and y5 are scalar multiples of
# the same latent function; use num_latents = 5 to give them a full 2x2 block.
num_latents = 4


class MultitaskGPModel(gpytorch.models.ApproximateGP):
    """Approximate multitask GP whose tasks are linear mixtures of latent GPs.

    Each of the ``num_latents`` latent GPs has its own inducing points,
    variational distribution, mean and kernel hyperparameters (all batched
    over a leading ``num_latents`` dimension). ``LMCVariationalStrategy``
    mixes them into ``num_tasks`` outputs.
    """

    def __init__(self):
        # Let's use a different set of inducing points for each latent function
        inducing_points = torch.rand(num_latents, 16, 1)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each latent function
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1,
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters per latent function
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents]),
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with
        # each latent function in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


# And then in your training loop, keep the LMC mixing matrix block-structured:
def constrain_lmc_coefficients(model, num_independent=3):
    """Force ``model``'s LMC coefficients into a block structure, in place.

    The first ``num_independent`` tasks each get their own latent GP (an
    identity block) and are decoupled from the remaining tasks. The remaining
    tasks are driven only by the remaining latent GP(s).

    Assumes ``lmc_coefficients`` is laid out as ``(..., num_latents, num_tasks)``,
    as the indexing below does. Requires ``num_latents`` and ``num_tasks`` to be
    at least ``num_independent``.
    """
    lmc = model.variational_strategy.lmc_coefficients
    k = num_independent
    # Edit the parameter in place without recording it in the autograd graph.
    # NOTE: ``lmc[..., :k, :k].data = ...`` does NOT work. The slice is a
    # temporary view, and rebinding its ``.data`` leaves ``lmc`` unchanged.
    with torch.no_grad():
        # Tasks 1-k get their own latent GP
        lmc[..., :k, :k].copy_(torch.eye(k, dtype=lmc.dtype, device=lmc.device))
        # Tasks 1-k shouldn't be influenced by the GP(s) for the remaining tasks
        lmc[..., k:, :k].zero_()
        # The remaining tasks shouldn't be influenced by the GPs for tasks 1-k
        lmc[..., :k, k:].zero_()


constrain_lmc_coefficients(model)
for _ in range(num_iter):
    # ... zero grads, forward pass, loss.backward() ...
    # optimizer.step()
    # Re-impose the structure after every update, including the last one, so
    # the trained model never keeps optimizer-perturbed coefficients.
    constrain_lmc_coefficients(model)
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
If the output is multi-dimensional, [y1, y2, y3, y4, y5], and only y4 and y5 are correlated, is there a way to specify the correlated output dimensions in LMCVariationalStrategy instead of setting num_latents to 4?
Beta Was this translation helpful? Give feedback.
All reactions