Skip to content

Commit

Permalink
Merge branch 'main' into fix_fantasy_model
Browse files Browse the repository at this point in the history
  • Loading branch information
SaiAakash authored Jan 25, 2025
2 parents 9d7857f + 4156bf4 commit 3437156
Show file tree
Hide file tree
Showing 20 changed files with 352 additions and 468 deletions.
72 changes: 35 additions & 37 deletions examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,57 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import os\n",
"\n",
"import gpytorch\n",
"from gpytorch.priors import UniformPrior\n",
"import matplotlib.pyplot as plt\n",
"import pyro\n",
"from pyro.infer.mcmc import NUTS, MCMC, HMC\n",
"from matplotlib import pyplot as plt\n",
"from pyro.infer.mcmc import NUTS, MCMC\n",
"import torch\n",
"\n",
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
"# this is for running the notebook in our testing framework\n",
"smoke_test = ('CI' in os.environ)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Training data is 11 points in [0,1] inclusive regularly spaced\n",
"# Training data is 4 points in [0,1] inclusive regularly spaced\n",
"train_x = torch.linspace(0, 1, 4)\n",
"# True function is sin(2*pi*x) with Gaussian noise\n",
"train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# We will use the simplest form of GP model, exact inference\n",
"class ExactGPModel(gpytorch.models.ExactGP):\n",
" def __init__(self, train_x, train_y, likelihood):\n",
" super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
" super().__init__(train_x, train_y, likelihood)\n",
" self.mean_module = gpytorch.means.ConstantMean()\n",
" self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
" \n",
"\n",
" def forward(self, x):\n",
" mean_x = self.mean_module(x)\n",
" covar_x = self.covar_module(x)\n",
Expand All @@ -78,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -90,25 +103,17 @@
}
],
"source": [
"# this is for running the notebook in our testing framework\n",
"import os\n",
"smoke_test = ('CI' in os.environ)\n",
"num_samples = 2 if smoke_test else 100\n",
"warmup_steps = 2 if smoke_test else 100\n",
"\n",
"\n",
"from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior\n",
"# Use a positive constraint instead of usual GreaterThan(1e-4) so that LogNormal has support over full range.\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
"model = ExactGPModel(train_x, train_y, likelihood)\n",
"\n",
"model.mean_module.register_prior(\"mean_prior\", UniformPrior(-1, 1), \"constant\")\n",
"model.covar_module.base_kernel.register_prior(\"lengthscale_prior\", UniformPrior(0.01, 0.5), \"lengthscale\")\n",
"model.covar_module.register_prior(\"outputscale_prior\", UniformPrior(1, 2), \"outputscale\")\n",
"likelihood.register_prior(\"noise_prior\", UniformPrior(0.01, 0.5), \"noise\")\n",
"\n",
"mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
"\n",
"def pyro_model(x, y):\n",
" with gpytorch.settings.fast_computations(False, False, False):\n",
" sampled_model = model.pyro_sample_from_prior()\n",
Expand All @@ -132,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -141,7 +146,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -158,12 +163,12 @@
"source": [
"## Plot Mean Functions\n",
"\n",
"In the next cell, we plot the first 25 mean functions on the samep lot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
"In the next cell, we plot the first 25 mean functions on the same plot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
]
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 8,
"metadata": {
"scrolled": false
},
Expand All @@ -185,14 +190,14 @@
"with torch.no_grad():\n",
" # Initialize plot\n",
" f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
" \n",
"\n",
" # Plot training data as black stars\n",
" ax.plot(train_x.numpy(), train_y.numpy(), 'k*', zorder=10)\n",
" \n",
"\n",
" for i in range(min(num_samples, 25)):\n",
" # Plot predictive means as blue line\n",
" ax.plot(test_x.numpy(), output.mean[i].detach().numpy(), 'b', linewidth=0.3)\n",
" \n",
"\n",
" # Shade between the lower and upper confidence bounds\n",
" # ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)\n",
" ax.set_ylim([-3, 3])\n",
Expand All @@ -212,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -221,7 +226,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -235,13 +240,6 @@
"\n",
"model.load_state_dict(state_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
"source": [
"## The PyroGP model\n",
"\n",
"In order to use Pyro with GPyTorch, your model must inherit from `gpytorch.models.PyroGP` (rather than `gpytorch.modelks.ApproximateGP`). The `PyroGP` extends the `ApproximateGP` class and differs in a few key ways:\n",
"In order to use Pyro with GPyTorch, your model must inherit from `gpytorch.models.PyroGP` (rather than `gpytorch.models.ApproximateGP`). The `PyroGP` extends the `ApproximateGP` class and differs in a few key ways:\n",
"\n",
"- It adds the `model` and `guide` functions which are used by Pyro's inference engine.\n",
"- Its constructor requires two additional arguments beyond the variational strategy:\n",
Expand Down
25 changes: 16 additions & 9 deletions gpytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
variational,
)
from .functions import inv_matmul, log_normal_cdf, logdet, matmul # Deprecated
from .lazy import cat, delazify, lazify
from .mlls import ExactMarginalLogLikelihood
from .module import Module

Expand Down Expand Up @@ -116,7 +115,10 @@ def inv_quad(input: Anysor, inv_quad_rhs: Tensor, reduce_inv_quad: bool = True)


def inv_quad_logdet(
input: Anysor, inv_quad_rhs: Optional[Tensor] = None, logdet: bool = False, reduce_inv_quad: bool = True
input: Anysor,
inv_quad_rhs: Optional[Tensor] = None,
logdet: bool = False,
reduce_inv_quad: bool = True,
) -> Tuple[Tensor, Tensor]:
r"""
Calls both :func:`inv_quad` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`.
Expand All @@ -133,12 +135,18 @@ def inv_quad_logdet(
If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M).
"""
return linear_operator.inv_quad_logdet(
input=input, inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad
input=input,
inv_quad_rhs=inv_quad_rhs,
logdet=logdet,
reduce_inv_quad=reduce_inv_quad,
)


def pivoted_cholesky(
input: Anysor, rank: int, error_tol: Optional[float] = None, return_pivots: bool = False
input: Anysor,
rank: int,
error_tol: Optional[float] = None,
return_pivots: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
r"""
Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices).
Expand Down Expand Up @@ -201,7 +209,10 @@ def root_inv_decomposition(
:return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`.
"""
return linear_operator.root_inv_decomposition(
input=input, initial_vectors=initial_vectors, test_vectors=test_vectors, method=method
input=input,
initial_vectors=initial_vectors,
test_vectors=test_vectors,
method=method,
)


Expand Down Expand Up @@ -307,11 +318,7 @@ def sqrt_inv_matmul(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) ->
# Other
"__version__",
# Deprecated
"add_diag",
"cat",
"delazify",
"inv_matmul",
"lazify",
"logdet",
"log_normal_cdf",
"matmul",
Expand Down
3 changes: 2 additions & 1 deletion gpytorch/distributions/multitask_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def from_independent_mvns(cls, mvns):
if any(isinstance(mvn, MultitaskMultivariateNormal) for mvn in mvns):
raise ValueError("Cannot accept MultitaskMultivariateNormals")
if not all(m.batch_shape == mvns[0].batch_shape for m in mvns[1:]):
raise ValueError("All MultivariateNormals must have the same batch shape")
batch_shape = torch.broadcast_shapes(*(m.batch_shape for m in mvns))
mvns = [mvn.expand(batch_shape) for mvn in mvns]
if not all(m.event_shape == mvns[0].event_shape for m in mvns[1:]):
raise ValueError("All MultivariateNormals must have the same event shape")
mean = torch.stack([mvn.mean for mvn in mvns], -1)
Expand Down
65 changes: 61 additions & 4 deletions gpytorch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,67 @@ def expand(self, batch_size: torch.Size) -> MultivariateNormal:
See :py:meth:`torch.distributions.Distribution.expand
<torch.distributions.distribution.Distribution.expand>`.
"""
new_loc = self.loc.expand(torch.Size(batch_size) + self.loc.shape[-1:])
new_covar = self._covar.expand(torch.Size(batch_size) + self._covar.shape[-2:])
res = self.__class__(new_loc, new_covar)
return res
# NOTE: Pyro may call this method with list[int] instead of torch.Size.
batch_size = torch.Size(batch_size)
new_loc = self.loc.expand(batch_size + self.loc.shape[-1:])
if self.islazy:
new_covar = self._covar.expand(batch_size + self._covar.shape[-2:])
new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
if self.__unbroadcasted_scale_tril is not None:
# Reuse the scale tril if available.
new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.expand(
batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
)
else:
# Non-lazy MVN is represented using scale_tril in PyTorch.
# Constructing it from scale_tril will avoid unnecessary computation.
# Initialize using __new__, so that we can skip __init__ and use scale_tril.
new = self.__new__(type(self))
new._islazy = False
new_scale_tril = self.__unbroadcasted_scale_tril.expand(
batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
)
super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
# Set the covar matrix, since it is always available for GPyTorch MVN.
new.covariance_matrix = self.covariance_matrix.expand(batch_size + self.covariance_matrix.shape[-2:])
return new

def unsqueeze(self, dim: int) -> MultivariateNormal:
    r"""
    Constructs a new MultivariateNormal with the batch shape unsqueezed
    by the given dimension.
    For example, if `self.batch_shape = torch.Size([2, 3])` and `dim = 0`, then
    the returned MultivariateNormal will have `batch_shape = torch.Size([1, 2, 3])`.
    If `dim = -1`, then the returned MultivariateNormal will have
    `batch_shape = torch.Size([2, 3, 1])`.

    :param dim: Position within the batch shape at which the new singleton
        dimension is inserted. Accepted range is
        `[-len(batch_shape) - 1, len(batch_shape)]`; negative values are
        counted from the end of the batch shape.
    :return: A new distribution built from the unsqueezed mean and covariance
        (and the unsqueezed cached scale tril, when one is available).
    :raises IndexError: If `dim` falls outside the accepted range.
    """
    # Validate against the batch shape only; event dimensions are never a valid target.
    if dim > len(self.batch_shape) or dim < -len(self.batch_shape) - 1:
        raise IndexError(
            "Dimension out of range (expected to be in range of "
            f"[{-len(self.batch_shape) - 1}, {len(self.batch_shape)}], but got {dim})."
        )
    if dim < 0:
        # If dim is negative, get the positive equivalent.
        # A non-negative index counted from the front means the same thing for the
        # mean, the covariance and the scale tril, so one value can be reused below.
        dim = len(self.batch_shape) + dim + 1

    new_loc = self.loc.unsqueeze(dim)
    if self.islazy:
        # Lazy case: go through the regular constructor with the unsqueezed covariance.
        new_covar = self._covar.unsqueeze(dim)
        new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
        if self.__unbroadcasted_scale_tril is not None:
            # Reuse the scale tril if available.
            new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
    else:
        # Non-lazy MVN is represented using scale_tril in PyTorch.
        # Constructing it from scale_tril will avoid unnecessary computation.
        # Initialize using __new__, so that we can skip __init__ and use scale_tril.
        new = self.__new__(type(self))
        new._islazy = False
        new_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
        # Call the parent-class initializer directly on the bare instance.
        super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
        # Set the covar matrix, since it is always available for GPyTorch MVN.
        new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
    return new

def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
r"""
Expand Down
2 changes: 1 addition & 1 deletion gpytorch/kernels/cosine_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def period_length(self):

@period_length.setter
def period_length(self, value):
return self._set_period_length(value)
self._set_period_length(value)

def _set_period_length(self, value):
if not torch.is_tensor(value):
Expand Down
Loading

0 comments on commit 3437156

Please sign in to comment.