diff --git a/gpytorch/test/lazy_tensor_test_case.py b/gpytorch/test/lazy_tensor_test_case.py
index f51dbc477..e50976b9b 100644
--- a/gpytorch/test/lazy_tensor_test_case.py
+++ b/gpytorch/test/lazy_tensor_test_case.py
@@ -10,7 +10,6 @@
 
 import gpytorch
 from gpytorch.settings import linalg_dtypes
-from gpytorch.utils.cholesky import CHOLESKY_METHOD
 from gpytorch.utils.errors import CachingError
 from gpytorch.utils.memoize import get_from_cache
 
@@ -211,15 +210,28 @@ def test_getitem_tensor_index(self):
         # Batch case
         else:
             for batch_index in product(
-                [torch.tensor([0, 1, 1, 0]), slice(None, None, None)], repeat=(lazy_tensor.dim() - 2)
+                [torch.tensor([0, 1, 1, 0]), slice(None, None, None)],
+                repeat=(lazy_tensor.dim() - 2),
             ):
-                index = (*batch_index, torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 0, 1]))
+                index = (
+                    *batch_index,
+                    torch.tensor([0, 1, 0, 2]),
+                    torch.tensor([1, 2, 0, 1]),
+                )
                 res, actual = lazy_tensor[index], evaluated[index]
                 self.assertAllClose(res, actual)
-                index = (*batch_index, torch.tensor([0, 1, 0, 2]), slice(None, None, None))
+                index = (
+                    *batch_index,
+                    torch.tensor([0, 1, 0, 2]),
+                    slice(None, None, None),
+                )
                 res, actual = gpytorch.delazify(lazy_tensor[index]), evaluated[index]
                 self.assertAllClose(res, actual)
-                index = (*batch_index, slice(None, None, None), torch.tensor([0, 1, 2, 1]))
+                index = (
+                    *batch_index,
+                    slice(None, None, None),
+                    torch.tensor([0, 1, 2, 1]),
+                )
                 res, actual = gpytorch.delazify(lazy_tensor[index]), evaluated[index]
                 self.assertAllClose(res, actual)
                 index = (*batch_index, slice(None, None, None), slice(None, None, None))
@@ -298,7 +310,10 @@ class LazyTensorTestCase(RectangularLazyTensorTestCase):
         "root_inv_decomposition": {"rtol": 0.05, "atol": 0.02},
         "sample": {"rtol": 0.3, "atol": 0.3},
         "sqrt_inv_matmul": {"rtol": 1e-2, "atol": 1e-3},
-        "symeig": {"double": {"rtol": 1e-4, "atol": 1e-3}, "float": {"rtol": 1e-3, "atol": 1e-2}},
+        "symeig": {
+            "double": {"rtol": 1e-4, "atol": 1e-3},
+            "float": {"rtol": 1e-3, "atol": 1e-2},
+        },
         "svd": {"rtol": 1e-4, "atol": 1e-3},
     }
 
@@ -650,15 +665,13 @@ def _test_triangular_lazy_tensor_inv_quad_logdet(self):
         chol = lazy_tensor.root_decomposition().root.clone()
         gpytorch.utils.memoize.clear_cache_hook(lazy_tensor)
         gpytorch.utils.memoize.add_to_cache(
-            lazy_tensor, "root_decomposition", gpytorch.lazy.RootLazyTensor(chol)
+            lazy_tensor,
+            "root_decomposition",
+            gpytorch.lazy.RootLazyTensor(chol),
         )
 
-        _wrapped_cholesky = MagicMock(
-            wraps=torch.linalg.cholesky
-            if CHOLESKY_METHOD == "torch.linalg.cholesky"
-            else torch.linalg.cholesky_ex
-        )
-        with patch(CHOLESKY_METHOD, new=_wrapped_cholesky) as cholesky_mock:
+        _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
+        with patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky) as cholesky_mock:
             self._test_inv_quad_logdet(reduce_inv_quad=True, cholesky=True, lazy_tensor=lazy_tensor)
             self.assertFalse(cholesky_mock.called)
 
@@ -778,7 +791,11 @@ def test_symeig(self):
 
             # since LazyTensor.symeig does not sort evals, we do this here for the check
             evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
-            evecs = torch.gather(evecs_unsorted, dim=-1, index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape))
+            evecs = torch.gather(
+                evecs_unsorted,
+                dim=-1,
+                index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape),
+            )
 
         evals_actual, evecs_actual = torch.linalg.eigh(evaluated.type(dtype))
         evals_actual = evals_actual.to(dtype=evaluated.dtype)
diff --git a/gpytorch/test/variational_test_case.py b/gpytorch/test/variational_test_case.py
index 85271c5da..57943a323 100644
--- a/gpytorch/test/variational_test_case.py
+++ b/gpytorch/test/variational_test_case.py
@@ -6,7 +6,6 @@
 import torch
 
 import gpytorch
-from gpytorch.utils.cholesky import CHOLESKY_METHOD
 
 from .base_test_case import BaseTestCase
 
@@ -25,7 +24,10 @@ class _SVGPRegressionModel(gpytorch.models.ApproximateGP):
             def __init__(self, inducing_points):
                 variational_distribution = distribution_cls(num_inducing, batch_shape=batch_shape)
                 variational_strategy = strategy_cls(
-                    self, inducing_points, variational_distribution, learn_inducing_locations=True
+                    self,
+                    inducing_points,
+                    variational_distribution,
+                    learn_inducing_locations=True,
                 )
                 super().__init__(variational_strategy)
                 if constant_mean:
@@ -45,7 +47,12 @@ def forward(self, x):
         return _SVGPRegressionModel(inducing_points), self.likelihood_cls()
 
     def _training_iter(
-        self, model, likelihood, batch_shape=torch.Size([]), mll_cls=gpytorch.mlls.VariationalELBO, cuda=False
+        self,
+        model,
+        likelihood,
+        batch_shape=torch.Size([]),
+        mll_cls=gpytorch.mlls.VariationalELBO,
+        cuda=False
     ):
         train_x = torch.randn(*batch_shape, 32, 2).clamp(-2.5, 2.5)
         train_y = torch.linspace(-1, 1, self.event_shape[0])
@@ -132,12 +139,10 @@ def test_eval_iteration(
         eval_data_batch_shape = eval_data_batch_shape if eval_data_batch_shape is not None else self.batch_shape
 
         # Mocks
-        _wrapped_cholesky = MagicMock(
-            wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
-        )
+        _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
         _wrapped_cg = MagicMock(wraps=gpytorch.utils.linear_cg)
         _wrapped_ciq = MagicMock(wraps=gpytorch.utils.contour_integral_quad)
-        _cholesky_mock = patch(CHOLESKY_METHOD, new=_wrapped_cholesky)
+        _cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
         _cg_mock = patch("gpytorch.utils.linear_cg", new=_wrapped_cg)
         _ciq_mock = patch("gpytorch.utils.contour_integral_quad", new=_wrapped_ciq)
 
@@ -194,12 +199,10 @@ def test_training_iteration(
         expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape
 
         # Mocks
-        _wrapped_cholesky = MagicMock(
-            wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
-        )
+        _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
         _wrapped_cg = MagicMock(wraps=gpytorch.utils.linear_cg)
         _wrapped_ciq = MagicMock(wraps=gpytorch.utils.contour_integral_quad)
-        _cholesky_mock = patch(CHOLESKY_METHOD, new=_wrapped_cholesky)
+        _cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
         _cg_mock = patch("gpytorch.utils.linear_cg", new=_wrapped_cg)
         _ciq_mock = patch("gpytorch.utils.contour_integral_quad", new=_wrapped_ciq)
 
@@ -216,11 +219,21 @@ def test_training_iteration(
         with _cholesky_mock as cholesky_mock, _cg_mock as cg_mock, _ciq_mock as ciq_mock:
             # Iter 1
             self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 0)
-            self._training_iter(model, likelihood, data_batch_shape, mll_cls=self.mll_cls, cuda=self.cuda)
+            self._training_iter(
+                model,
+                likelihood,
+                data_batch_shape,
+                mll_cls=self.mll_cls,
+                cuda=self.cuda,
+            )
             self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 1)
             # Iter 2
             output, loss = self._training_iter(
-                model, likelihood, data_batch_shape, mll_cls=self.mll_cls, cuda=self.cuda
+                model,
+                likelihood,
+                data_batch_shape,
+                mll_cls=self.mll_cls,
+                cuda=self.cuda,
             )
             self.assertEqual(output.batch_shape, expected_batch_shape)
             self.assertEqual(output.event_shape, self.event_shape)
diff --git a/gpytorch/utils/cholesky.py b/gpytorch/utils/cholesky.py
index e07901e30..97d2e0bdf 100644
--- a/gpytorch/utils/cholesky.py
+++ b/gpytorch/utils/cholesky.py
@@ -8,86 +8,41 @@
 from .errors import NanError, NotPSDError
 from .warnings import NumericalWarning
 
 
-try:
-    from torch.linalg import cholesky_ex  # noqa: F401
-    CHOLESKY_METHOD = "torch.linalg.cholesky_ex"  # used for counting mock calls
-
-    def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
-        # Maybe log
-        if settings.verbose_linalg.on():
-            settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")
-
-        if out is not None:
-            out = (out, torch.empty(A.shape[:-2], dtype=torch.int32, device=out.device))
-
-        L, info = torch.linalg.cholesky_ex(A, out=out)
+def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
+    # Maybe log
+    if settings.verbose_linalg.on():
+        settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")
+
+    if out is not None:
+        out = (out, torch.empty(A.shape[:-2], dtype=torch.int32, device=out.device))
+
+    L, info = torch.linalg.cholesky_ex(A, out=out)
+    if not torch.any(info):
+        return L
+
+    isnan = torch.isnan(A)
+    if isnan.any():
+        raise NanError(f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN.")
+
+    if jitter is None:
+        jitter = settings.cholesky_jitter.value(A.dtype)
+    Aprime = A.clone()
+    jitter_prev = 0
+    for i in range(max_tries):
+        jitter_new = jitter * (10 ** i)
+        # add jitter only where needed
+        diag_add = ((info > 0) * (jitter_new - jitter_prev)).unsqueeze(-1).expand(*Aprime.shape[:-1])
+        Aprime.diagonal(dim1=-1, dim2=-2).add_(diag_add)
+        jitter_prev = jitter_new
+        warnings.warn(
+            f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal",
+            NumericalWarning,
+        )
+        L, info = torch.linalg.cholesky_ex(Aprime, out=out)
         if not torch.any(info):
             return L
-
-        isnan = torch.isnan(A)
-        if isnan.any():
-            raise NanError(
-                f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
-            )
-
-        if jitter is None:
-            jitter = settings.cholesky_jitter.value(A.dtype)
-        Aprime = A.clone()
-        jitter_prev = 0
-        for i in range(max_tries):
-            jitter_new = jitter * (10 ** i)
-            # add jitter only where needed
-            diag_add = ((info > 0) * (jitter_new - jitter_prev)).unsqueeze(-1).expand(*Aprime.shape[:-1])
-            Aprime.diagonal(dim1=-1, dim2=-2).add_(diag_add)
-            jitter_prev = jitter_new
-            warnings.warn(f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", NumericalWarning)
-            L, info = torch.linalg.cholesky_ex(Aprime, out=out)
-            if not torch.any(info):
-                return L
-        raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.")
-
-
-except ImportError:
-
-    # Fall back to torch.linalg.cholesky - this can be more than 3 orders of magnitude slower!
-    # TODO: Remove once PyTorch req. is >= 1.9
-
-    CHOLESKY_METHOD = "torch.linalg.cholesky"  # used for counting mock calls
-
-    def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
-        # Maybe log
-        if settings.verbose_linalg.on():
-            settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")
-
-        try:
-            L = torch.linalg.cholesky(A, out=out)
-            return L
-        except RuntimeError as e:
-            isnan = torch.isnan(A)
-            if isnan.any():
-                raise NanError(
-                    f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
-                )
-
-            if jitter is None:
-                jitter = settings.cholesky_jitter.value(A.dtype)
-            Aprime = A.clone()
-            jitter_prev = 0
-            for i in range(max_tries):
-                jitter_new = jitter * (10 ** i)
-                Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
-                jitter_prev = jitter_new
-                try:
-                    L = torch.linalg.cholesky(Aprime, out=out)
-                    warnings.warn(f"A not p.d., added jitter of {jitter_new:.1e} to the diagonal", NumericalWarning)
-                    return L
-                except RuntimeError:
-                    continue
-            raise NotPSDError(
-                f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}. "
-                f"Original error on first attempt: {e}"
-            )
+    raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.")
 
 
 def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=3):
diff --git a/test/examples/test_sgpr_regression.py b/test/examples/test_sgpr_regression.py
index 7a2cebbcc..e657354d8 100644
--- a/test/examples/test_sgpr_regression.py
+++ b/test/examples/test_sgpr_regression.py
@@ -15,7 +15,6 @@
 from gpytorch.means import ConstantMean
 from gpytorch.priors import SmoothedBoxPrior
 from gpytorch.test.utils import least_used_cuda_device
-from gpytorch.utils.cholesky import CHOLESKY_METHOD
 from gpytorch.utils.warnings import NumericalWarning
 from torch import optim
 
@@ -82,9 +81,9 @@ def test_sgpr_mean_abs_error(self, cuda=False):
 
         # Mock cholesky
         _wrapped_cholesky = MagicMock(
-            wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
+            wraps=torch.linalg.cholesky_ex
         )
-        with patch(CHOLESKY_METHOD, new=_wrapped_cholesky) as cholesky_mock:
+        with patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky) as cholesky_mock:
             # Optimize the model
             gp_model.train()
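
Note (illustrative, not part of the patch): with the torch.linalg.cholesky fallback removed,
torch.linalg.cholesky_ex is the only factorization routine psd_safe_cholesky calls, which is why the
tests above can patch it directly. The sketch below exercises the jitter-retry path of the rewritten
helper; the test matrix and the explicit jitter value are arbitrary choices for demonstration.

    import warnings

    import torch

    from gpytorch.utils.cholesky import psd_safe_cholesky

    # A slightly indefinite matrix: its last diagonal entry is barely negative,
    # so the first torch.linalg.cholesky_ex call reports a failure (info > 0).
    A = torch.eye(4, dtype=torch.float64)
    A[-1, -1] = -1e-10

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        L = psd_safe_cholesky(A, jitter=1e-6, max_tries=3)

    # One NumericalWarning per retry; here a single jitter of 1.0e-06 suffices,
    # and L @ L.T reproduces A up to the added diagonal jitter.
    print([str(w.message) for w in caught])
    print(torch.allclose(L @ L.transpose(-1, -2), A, atol=1e-5))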