Start adding typing annotations to ExactGP #2436

Open · wants to merge 2 commits into base: main
34 changes: 23 additions & 11 deletions gpytorch/models/exact_gp.py
@@ -1,9 +1,14 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import warnings
+
+from collections.abc import Sequence
 from copy import deepcopy
 
 import torch
+from torch import Tensor
 
 from .. import settings
 from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
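
The new `from __future__ import annotations` line is what lets the PEP 604 `X | Y` unions in the signatures below import on Python versions before 3.10: under PEP 563, every annotation is stored as an unevaluated string. A minimal standalone sketch (the function `f` is hypothetical, for illustration only):

from __future__ import annotations


def f(x: int | None) -> None:
    # Without the __future__ import, Python 3.9 would evaluate `int | None`
    # at definition time and raise a TypeError; with it, the annotation
    # stays an unevaluated string.
    pass


print(f.__annotations__["x"])  # prints the string "int | None"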
@@ -52,15 +57,20 @@ class ExactGP(GP):
         >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
     """
 
-    def __init__(self, train_inputs, train_targets, likelihood):
+    def __init__(
+        self,
+        train_inputs: Tensor | Sequence[Tensor] | None,
+        train_targets: Tensor | None,
+        likelihood: _GaussianLikelihoodBase,
+    ):
         if train_inputs is not None and torch.is_tensor(train_inputs):
             train_inputs = (train_inputs,)
         if train_inputs is not None and not all(torch.is_tensor(train_input) for train_input in train_inputs):
             raise RuntimeError("Train inputs must be a tensor, or a list/tuple of tensors")
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("ExactGP can only handle Gaussian likelihoods")
 
-        super(ExactGP, self).__init__()
+        super().__init__()
         if train_inputs is not None:
             self.train_inputs = tuple(tri.unsqueeze(-1) if tri.ndimension() == 1 else tri for tri in train_inputs)
             self.train_targets = train_targets
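
For context, a minimal subclass that exercises the annotated constructor, mirroring the usual GPyTorch pattern; `TinyGP`, `train_x`, and `train_y` are illustrative names, not part of this diff:

import torch
import gpytorch


class TinyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        # train_inputs may be a Tensor, a sequence of Tensors, or None
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


train_x = torch.linspace(0, 1, 10)  # 1-d inputs are unsqueezed to (10, 1) by ExactGP.__init__
train_y = torch.sin(2 * torch.pi * train_x)
model = TinyGP(train_x, train_y, gpytorch.likelihoods.GaussianLikelihood())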
@@ -72,20 +82,20 @@ def __init__(self, train_inputs, train_targets, likelihood):
         self.prediction_strategy = None
 
     @property
-    def train_targets(self):
+    def train_targets(self) -> Tensor | None:
         return self._train_targets
 
     @train_targets.setter
-    def train_targets(self, value):
+    def train_targets(self, value: Tensor | None) -> None:
         object.__setattr__(self, "_train_targets", value)
 
     def _apply(self, fn):
         if self.train_inputs is not None:
             self.train_inputs = tuple(fn(train_input) for train_input in self.train_inputs)
             self.train_targets = fn(self.train_targets)
-        return super(ExactGP, self)._apply(fn)
+        return super()._apply(fn)
 
-    def _clear_cache(self):
+    def _clear_cache(self) -> None:
         # The precomputed caches from test time live in prediction_strategy
         self.prediction_strategy = None
 
@@ -99,7 +109,9 @@ def local_load_samples(self, samples_dict, memo, prefix):
         self.train_targets = self.train_targets.unsqueeze(0).expand(num_samples, *self.train_targets.shape)
         super().local_load_samples(samples_dict, memo, prefix)
 
-    def set_train_data(self, inputs=None, targets=None, strict=True):
+    def set_train_data(
+        self, inputs: Tensor | Sequence[Tensor] | None = None, targets: Tensor | None = None, strict: bool = True
+    ) -> None:
         """
         Set training data (does not re-fit model hyper-parameters).
 
@@ -218,7 +230,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
         except KeyError:
             fantasy_kwargs = {}
 
-        full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
+        full_output = super().__call__(*full_inputs, **kwargs)
 
         # Copy model without copying training data or prediction strategy (since we'll overwrite those)
         old_pred_strat = self.prediction_strategy
@@ -257,7 +269,7 @@ def __call__(self, *args, **kwargs):
         if self.training:
             if self.train_inputs is None:
                 raise RuntimeError(
-                    "train_inputs, train_targets cannot be None in training mode. "
+                    "train_inputs cannot be None in training mode. "
                     "Call .eval() for prior predictions, or call .set_train_data() to add training data."
                 )
             if settings.debug.on():
@@ -271,7 +283,7 @@ def __call__(self, *args, **kwargs):
         # Prior mode
         elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
             full_inputs = args
-            full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
+            full_output = super().__call__(*full_inputs, **kwargs)
             if settings.debug().on():
                 if not isinstance(full_output, MultivariateNormal):
                     raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
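
Both branches above in action, again with the hypothetical `TinyGP`: a model built without data raises in training mode but returns the prior in eval mode:

prior_model = TinyGP(None, None, gpytorch.likelihoods.GaussianLikelihood())

prior_model.train()
# prior_model(torch.rand(4, 1))  # RuntimeError: train_inputs cannot be None in training mode. ...

prior_model.eval()
prior = prior_model(torch.rand(4, 1))  # prior MultivariateNormal, no conditioning on data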
@@ -313,7 +325,7 @@ def __call__(self, *args, **kwargs):
                 full_inputs.append(torch.cat([train_input, input], dim=-2))
 
             # Get the joint distribution for training/test data
-            full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
+            full_output = super().__call__(*full_inputs, **kwargs)
             if settings.debug().on():
                 if not isinstance(full_output, MultivariateNormal):
                     raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
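
Finally, the eval-time path this last hunk belongs to, end to end with the hypothetical `model` from the earlier sketches: the joint train/test distribution assembled by the `torch.cat` call above is conditioned on the targets to produce the predictive posterior:

model.eval()
model.likelihood.eval()

test_x = torch.linspace(0, 1, 5)
with torch.no_grad():
    posterior = model.likelihood(model(test_x))  # predictive posterior at test_x
    mean, variance = posterior.mean, posterior.variance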