Skip to content

Commit

Permalink
fixed the double noise prior bug and added a test (#2355)
Browse files Browse the repository at this point in the history
  • Loading branch information
LuisAugenstein authored May 27, 2023
1 parent 163600b commit e3457b2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gpytorch/mlls/exact_marginal_log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _add_other_terms(self, res, params):

# Add log probs of priors on the (functions of) parameters
res_ndim = res.ndim
for name, module, prior, closure, _ in self.named_priors():
for name, module, prior, closure, _ in self.model.named_priors():
prior_term = prior.log_prob(closure(module))
res.add_(prior_term.view(*prior_term.shape[:res_ndim], -1).sum(dim=-1))

Expand Down
22 changes: 22 additions & 0 deletions test/mlls/test_exact_marginal_log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,25 @@ def test_batched_eval(self):
self.assertEqual(non_batch_mll_eval.shape, torch.Size())
self.assertEqual(batch_mll_eval.shape, torch.Size([10]))
self.assertTrue(torch.allclose(non_batch_mll_eval.expand(10), batch_mll_eval))

def test_mll_computation(self):
train_x, train_y = (torch.rand(10, 2), torch.rand(10))
model = ExactGPModel(train_x, train_y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
output = model(train_x)
marginal_log_likelihood = mll(output, train_y)

marginal_likelihood = model.likelihood(output)
noise_prior = next(model.likelihood.named_priors())[2]
outputscale_prior = next(model.covar_module.named_priors())[2]
lengthscale_prior = next(model.covar_module.base_kernel.named_priors())[2]

log_probs = [
marginal_likelihood.log_prob(train_y),
noise_prior.log_prob(model.likelihood.noise),
outputscale_prior.log_prob(model.covar_module.outputscale),
lengthscale_prior.log_prob(model.covar_module.base_kernel.lengthscale).sum(),
]
marginal_log_likelihood_by_hand = sum(log_probs) / train_y.shape[0]

self.assertTrue(torch.allclose(marginal_log_likelihood, marginal_log_likelihood_by_hand))

0 comments on commit e3457b2

Please sign in to comment.