Remove _psd_safe_cholesky that uses torch.linalg.cholesky (#1850)
* remove `_psd_safe_cholesky` that uses `torch.linalg.cholesky`

* black
valtron authored Dec 3, 2021
1 parent bf13e7a commit 8d11dd5
Showing 4 changed files with 91 additions and 107 deletions.
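For orientation, a minimal standalone sketch (not part of the commit) of the behavioural difference this change leans on; it assumes PyTorch >= 1.9, where `torch.linalg.cholesky_ex` is available:

```python
import torch

A = torch.eye(3)
A[0, 0] = -1.0  # deliberately not positive definite

# cholesky_ex does not raise on failure; it reports it through `info`,
# which is what lets the jitter-retry loop in cholesky.py avoid try/except.
L, info = torch.linalg.cholesky_ex(A)
print(info)  # nonzero means the factorization failed

# the removed fallback path had to rely on exceptions instead
try:
    torch.linalg.cholesky(A)
except RuntimeError as err:
    print("cholesky raised:", err)
```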
45 changes: 31 additions & 14 deletions gpytorch/test/lazy_tensor_test_case.py
@@ -10,7 +10,6 @@

import gpytorch
from gpytorch.settings import linalg_dtypes
from gpytorch.utils.cholesky import CHOLESKY_METHOD
from gpytorch.utils.errors import CachingError
from gpytorch.utils.memoize import get_from_cache

@@ -211,15 +210,28 @@ def test_getitem_tensor_index(self):
# Batch case
else:
for batch_index in product(
[torch.tensor([0, 1, 1, 0]), slice(None, None, None)], repeat=(lazy_tensor.dim() - 2)
[torch.tensor([0, 1, 1, 0]), slice(None, None, None)],
repeat=(lazy_tensor.dim() - 2),
):
index = (*batch_index, torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 0, 1]))
index = (
*batch_index,
torch.tensor([0, 1, 0, 2]),
torch.tensor([1, 2, 0, 1]),
)
res, actual = lazy_tensor[index], evaluated[index]
self.assertAllClose(res, actual)
index = (*batch_index, torch.tensor([0, 1, 0, 2]), slice(None, None, None))
index = (
*batch_index,
torch.tensor([0, 1, 0, 2]),
slice(None, None, None),
)
res, actual = gpytorch.delazify(lazy_tensor[index]), evaluated[index]
self.assertAllClose(res, actual)
index = (*batch_index, slice(None, None, None), torch.tensor([0, 1, 2, 1]))
index = (
*batch_index,
slice(None, None, None),
torch.tensor([0, 1, 2, 1]),
)
res, actual = gpytorch.delazify(lazy_tensor[index]), evaluated[index]
self.assertAllClose(res, actual)
index = (*batch_index, slice(None, None, None), slice(None, None, None))
@@ -298,7 +310,10 @@ class LazyTensorTestCase(RectangularLazyTensorTestCase):
"root_inv_decomposition": {"rtol": 0.05, "atol": 0.02},
"sample": {"rtol": 0.3, "atol": 0.3},
"sqrt_inv_matmul": {"rtol": 1e-2, "atol": 1e-3},
"symeig": {"double": {"rtol": 1e-4, "atol": 1e-3}, "float": {"rtol": 1e-3, "atol": 1e-2}},
"symeig": {
"double": {"rtol": 1e-4, "atol": 1e-3},
"float": {"rtol": 1e-3, "atol": 1e-2},
},
"svd": {"rtol": 1e-4, "atol": 1e-3},
}

@@ -650,15 +665,13 @@ def _test_triangular_lazy_tensor_inv_quad_logdet(self):
chol = lazy_tensor.root_decomposition().root.clone()
gpytorch.utils.memoize.clear_cache_hook(lazy_tensor)
gpytorch.utils.memoize.add_to_cache(
lazy_tensor, "root_decomposition", gpytorch.lazy.RootLazyTensor(chol)
lazy_tensor,
"root_decomposition",
gpytorch.lazy.RootLazyTensor(chol),
)

_wrapped_cholesky = MagicMock(
wraps=torch.linalg.cholesky
if CHOLESKY_METHOD == "torch.linalg.cholesky"
else torch.linalg.cholesky_ex
)
with patch(CHOLESKY_METHOD, new=_wrapped_cholesky) as cholesky_mock:
_wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
with patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky) as cholesky_mock:
self._test_inv_quad_logdet(reduce_inv_quad=True, cholesky=True, lazy_tensor=lazy_tensor)
self.assertFalse(cholesky_mock.called)
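As an aside, a self-contained sketch of the wrap-and-patch pattern used here, with a toy matrix that is not from the test suite:

```python
import torch
from unittest.mock import MagicMock, patch

# Wrap the real cholesky_ex so behaviour is unchanged but every call is recorded.
_wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)

with patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky) as cholesky_mock:
    A = 2.0 * torch.eye(4)  # toy SPD matrix, just to trigger one call
    L, info = torch.linalg.cholesky_ex(A)  # dispatched through the wrapper

print(cholesky_mock.called)      # True
print(cholesky_mock.call_count)  # 1
```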

@@ -778,7 +791,11 @@ def test_symeig(self):

# since LazyTensor.symeig does not sort evals, we do this here for the check
evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
evecs = torch.gather(evecs_unsorted, dim=-1, index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape))
evecs = torch.gather(
evecs_unsorted,
dim=-1,
index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape),
)

evals_actual, evecs_actual = torch.linalg.eigh(evaluated.type(dtype))
evals_actual = evals_actual.to(dtype=evaluated.dtype)
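A toy illustration (not from the repository) of the sort-then-gather reordering used in `test_symeig` above: the eigenvector columns are permuted with the same index order as the sorted eigenvalues.

```python
import torch

A = torch.tensor([[2.0, 0.3, 0.0],
                  [0.3, 1.0, 0.2],
                  [0.0, 0.2, 3.0]])
evals_unsorted, evecs_unsorted = torch.linalg.eigh(A)

# eigh already returns ascending eigenvalues, so the sort is a no-op here,
# but the pattern is the one the test applies to LazyTensor.symeig output.
evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
evecs = torch.gather(
    evecs_unsorted, dim=-1, index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape)
)

# every reordered column is still an eigenvector of its sorted eigenvalue
assert torch.allclose(A @ evecs, evecs * evals.unsqueeze(-2), atol=1e-5)
```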
39 changes: 26 additions & 13 deletions gpytorch/test/variational_test_case.py
@@ -6,7 +6,6 @@
import torch

import gpytorch
from gpytorch.utils.cholesky import CHOLESKY_METHOD

from .base_test_case import BaseTestCase

@@ -25,7 +24,10 @@ class _SVGPRegressionModel(gpytorch.models.ApproximateGP):
def __init__(self, inducing_points):
variational_distribution = distribution_cls(num_inducing, batch_shape=batch_shape)
variational_strategy = strategy_cls(
self, inducing_points, variational_distribution, learn_inducing_locations=True
self,
inducing_points,
variational_distribution,
learn_inducing_locations=True,
)
super().__init__(variational_strategy)
if constant_mean:
@@ -45,7 +47,12 @@ def forward(self, x):
return _SVGPRegressionModel(inducing_points), self.likelihood_cls()

def _training_iter(
self, model, likelihood, batch_shape=torch.Size([]), mll_cls=gpytorch.mlls.VariationalELBO, cuda=False
self,
model,
likelihood,
batch_shape=torch.Size([]),
mll_cls=gpytorch.mlls.VariationalELBO,
cuda=False,
):
train_x = torch.randn(*batch_shape, 32, 2).clamp(-2.5, 2.5)
train_y = torch.linspace(-1, 1, self.event_shape[0])
@@ -132,12 +139,10 @@ def test_eval_iteration(
eval_data_batch_shape = eval_data_batch_shape if eval_data_batch_shape is not None else self.batch_shape

# Mocks
_wrapped_cholesky = MagicMock(
wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
)
_wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
_wrapped_cg = MagicMock(wraps=gpytorch.utils.linear_cg)
_wrapped_ciq = MagicMock(wraps=gpytorch.utils.contour_integral_quad)
_cholesky_mock = patch(CHOLESKY_METHOD, new=_wrapped_cholesky)
_cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
_cg_mock = patch("gpytorch.utils.linear_cg", new=_wrapped_cg)
_ciq_mock = patch("gpytorch.utils.contour_integral_quad", new=_wrapped_ciq)

@@ -194,12 +199,10 @@ def test_training_iteration(
expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape

# Mocks
_wrapped_cholesky = MagicMock(
wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
)
_wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
_wrapped_cg = MagicMock(wraps=gpytorch.utils.linear_cg)
_wrapped_ciq = MagicMock(wraps=gpytorch.utils.contour_integral_quad)
_cholesky_mock = patch(CHOLESKY_METHOD, new=_wrapped_cholesky)
_cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
_cg_mock = patch("gpytorch.utils.linear_cg", new=_wrapped_cg)
_ciq_mock = patch("gpytorch.utils.contour_integral_quad", new=_wrapped_ciq)

@@ -216,11 +219,21 @@
with _cholesky_mock as cholesky_mock, _cg_mock as cg_mock, _ciq_mock as ciq_mock:
# Iter 1
self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 0)
self._training_iter(model, likelihood, data_batch_shape, mll_cls=self.mll_cls, cuda=self.cuda)
self._training_iter(
model,
likelihood,
data_batch_shape,
mll_cls=self.mll_cls,
cuda=self.cuda,
)
self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 1)
# Iter 2
output, loss = self._training_iter(
model, likelihood, data_batch_shape, mll_cls=self.mll_cls, cuda=self.cuda
model,
likelihood,
data_batch_shape,
mll_cls=self.mll_cls,
cuda=self.cuda,
)
self.assertEqual(output.batch_shape, expected_batch_shape)
self.assertEqual(output.event_shape, self.event_shape)
109 changes: 32 additions & 77 deletions gpytorch/utils/cholesky.py
@@ -8,86 +8,41 @@
from .errors import NanError, NotPSDError
from .warnings import NumericalWarning

try:
from torch.linalg import cholesky_ex # noqa: F401

CHOLESKY_METHOD = "torch.linalg.cholesky_ex" # used for counting mock calls

def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
# Maybe log
if settings.verbose_linalg.on():
settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")

if out is not None:
out = (out, torch.empty(A.shape[:-2], dtype=torch.int32, device=out.device))

L, info = torch.linalg.cholesky_ex(A, out=out)
def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
# Maybe log
if settings.verbose_linalg.on():
settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")

if out is not None:
out = (out, torch.empty(A.shape[:-2], dtype=torch.int32, device=out.device))

L, info = torch.linalg.cholesky_ex(A, out=out)
if not torch.any(info):
return L

isnan = torch.isnan(A)
if isnan.any():
raise NanError(f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN.")

if jitter is None:
jitter = settings.cholesky_jitter.value(A.dtype)
Aprime = A.clone()
jitter_prev = 0
for i in range(max_tries):
jitter_new = jitter * (10 ** i)
# add jitter only where needed
diag_add = ((info > 0) * (jitter_new - jitter_prev)).unsqueeze(-1).expand(*Aprime.shape[:-1])
Aprime.diagonal(dim1=-1, dim2=-2).add_(diag_add)
jitter_prev = jitter_new
warnings.warn(
f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal",
NumericalWarning,
)
L, info = torch.linalg.cholesky_ex(Aprime, out=out)
if not torch.any(info):
return L

isnan = torch.isnan(A)
if isnan.any():
raise NanError(
f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
)

if jitter is None:
jitter = settings.cholesky_jitter.value(A.dtype)
Aprime = A.clone()
jitter_prev = 0
for i in range(max_tries):
jitter_new = jitter * (10 ** i)
# add jitter only where needed
diag_add = ((info > 0) * (jitter_new - jitter_prev)).unsqueeze(-1).expand(*Aprime.shape[:-1])
Aprime.diagonal(dim1=-1, dim2=-2).add_(diag_add)
jitter_prev = jitter_new
warnings.warn(f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", NumericalWarning)
L, info = torch.linalg.cholesky_ex(Aprime, out=out)
if not torch.any(info):
return L
raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.")


except ImportError:

# Fall back to torch.linalg.cholesky - this can be more than 3 orders of magnitude slower!
# TODO: Remove once PyTorch req. is >= 1.9

CHOLESKY_METHOD = "torch.linalg.cholesky" # used for counting mock calls

def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
# Maybe log
if settings.verbose_linalg.on():
settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")

try:
L = torch.linalg.cholesky(A, out=out)
return L
except RuntimeError as e:
isnan = torch.isnan(A)
if isnan.any():
raise NanError(
f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
)

if jitter is None:
jitter = settings.cholesky_jitter.value(A.dtype)
Aprime = A.clone()
jitter_prev = 0
for i in range(max_tries):
jitter_new = jitter * (10 ** i)
Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
jitter_prev = jitter_new
try:
L = torch.linalg.cholesky(Aprime, out=out)
warnings.warn(f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", NumericalWarning)
return L
except RuntimeError:
continue
raise NotPSDError(
f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}. "
f"Original error on first attempt: {e}"
)
raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.")


def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=3):
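A brief usage sketch of the public wrapper `psd_safe_cholesky`, whose signature appears at the end of the hunk above; the singular matrix below is hypothetical, and the expectation is that the failed first attempt walks the `jitter * (10 ** i)` schedule and emits a `NumericalWarning` rather than raising.

```python
import warnings

import torch
from gpytorch.utils.cholesky import psd_safe_cholesky

A = torch.eye(5)
A[0, 0] = 0.0  # positive semi-definite only: the first factorization should fail

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    L = psd_safe_cholesky(A, jitter=1e-6)

print(L.shape)                           # torch.Size([5, 5])
print([str(w.message) for w in caught])  # e.g. "A not p.d., added jitter of 1.0e-06 to the diagonal"
```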
5 changes: 2 additions & 3 deletions test/examples/test_sgpr_regression.py
@@ -15,7 +15,6 @@
from gpytorch.means import ConstantMean
from gpytorch.priors import SmoothedBoxPrior
from gpytorch.test.utils import least_used_cuda_device
from gpytorch.utils.cholesky import CHOLESKY_METHOD
from gpytorch.utils.warnings import NumericalWarning
from torch import optim

@@ -82,9 +81,9 @@ def test_sgpr_mean_abs_error(self, cuda=False):

# Mock cholesky
_wrapped_cholesky = MagicMock(
wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
wraps=torch.linalg.cholesky_ex
)
with patch(CHOLESKY_METHOD, new=_wrapped_cholesky) as cholesky_mock:
with patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky) as cholesky_mock:

# Optimize the model
gp_model.train()
