-
Notifications
You must be signed in to change notification settings - Fork 561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ensure PSD-safe factorization in constructor of MultivariateNormal
#2297
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,27 +42,36 @@ class MultivariateNormal(TMultivariateNormal, Distribution): | |
:ivar torch.Tensor variance: The variance. | ||
""" | ||
|
||
def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False): | ||
self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator) | ||
if self._islazy: | ||
if validate_args: | ||
ms = mean.size(-1) | ||
cs1 = covariance_matrix.size(-1) | ||
cs2 = covariance_matrix.size(-2) | ||
if not (ms == cs1 and ms == cs2): | ||
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}") | ||
self.loc = mean | ||
self._covar = covariance_matrix | ||
self.__unbroadcasted_scale_tril = None | ||
self._validate_args = validate_args | ||
batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2]) | ||
|
||
event_shape = self.loc.shape[-1:] | ||
|
||
# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic | ||
super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False) | ||
else: | ||
super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args) | ||
def __init__( | ||
self, | ||
mean: Union[Tensor, LinearOperator], | ||
covariance_matrix: Union[Tensor, LinearOperator], | ||
validate_args: bool = False, | ||
): | ||
self._islazy = True | ||
# casting Tensor to DenseLinearOperator because the super constructor calls cholesky, which | ||
# will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up | ||
# calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal. | ||
if isinstance(covariance_matrix, Tensor): | ||
self._islazy = False # to allow _unbroadcasted_scale_tril setter | ||
covariance_matrix = to_linear_operator(covariance_matrix) | ||
|
||
if validate_args: | ||
ms = mean.size(-1) | ||
cs1 = covariance_matrix.size(-1) | ||
cs2 = covariance_matrix.size(-2) | ||
if not (ms == cs1 and ms == cs2): | ||
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}") | ||
self.loc = mean | ||
self._covar = covariance_matrix | ||
self.__unbroadcasted_scale_tril = None | ||
self._validate_args = validate_args | ||
batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2]) | ||
|
||
event_shape = self.loc.shape[-1:] | ||
|
||
# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean changing the torch code to validate LinearOperator inputs? That might be somewhat challenging to do if we want to use LinearOperators there explicitly. What would work is to make changes in pure torch that would make it easier to use LinearOperator objects by means of the |
||
super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False) | ||
|
||
def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size: | ||
""" | ||
|
@@ -81,16 +90,16 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size | |
def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str: | ||
return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})" | ||
|
||
@property | ||
@property # not using lazy_property here, because it does not allow for setter below | ||
def _unbroadcasted_scale_tril(self) -> Tensor: | ||
if self.islazy and self.__unbroadcasted_scale_tril is None: | ||
if self.__unbroadcasted_scale_tril is None: | ||
# cache root decoposition | ||
ust = to_dense(self.lazy_covariance_matrix.cholesky()) | ||
self.__unbroadcasted_scale_tril = ust | ||
return self.__unbroadcasted_scale_tril | ||
|
||
@_unbroadcasted_scale_tril.setter | ||
def _unbroadcasted_scale_tril(self, ust: Tensor): | ||
def _unbroadcasted_scale_tril(self, ust: Tensor) -> None: | ||
if self.islazy: | ||
raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy MVN distributions") | ||
else: | ||
|
@@ -114,10 +123,7 @@ def base_sample_shape(self) -> torch.Size: | |
|
||
@lazy_property | ||
def covariance_matrix(self) -> Tensor: | ||
if self.islazy: | ||
return self._covar.to_dense() | ||
else: | ||
return super().covariance_matrix | ||
return self._covar.to_dense() | ||
|
||
def confidence_region(self) -> Tuple[Tensor, Tensor]: | ||
""" | ||
|
@@ -157,10 +163,7 @@ def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor: | |
|
||
@lazy_property | ||
def lazy_covariance_matrix(self) -> LinearOperator: | ||
if self.islazy: | ||
return self._covar | ||
else: | ||
return to_linear_operator(super().covariance_matrix) | ||
return self._covar | ||
|
||
def log_prob(self, value: Tensor) -> Tensor: | ||
r""" | ||
|
@@ -304,13 +307,10 @@ def to_data_independent_dist(self) -> torch.distributions.Normal: | |
|
||
@property | ||
def variance(self) -> Tensor: | ||
if self.islazy: | ||
# overwrite this since torch MVN uses unbroadcasted_scale_tril for this | ||
diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2) | ||
diag = diag.view(diag.shape[:-1] + self._event_shape) | ||
variance = diag.expand(self._batch_shape + self._event_shape) | ||
else: | ||
variance = super().variance | ||
# overwrite this since torch MVN uses unbroadcasted_scale_tril for this | ||
diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2) | ||
diag = diag.view(diag.shape[:-1] + self._event_shape) | ||
variance = diag.expand(self._batch_shape + self._event_shape) | ||
|
||
# Check to make sure that variance isn't lower than minimum allowed value (default 1e-6). | ||
# This ensures that all variances are positive | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems odd to have
_islazy
set toTrue
if the covariance matrix is indeed aLinearOperator
. I guess the "lazy" nomenclature is a bit outdated anyway with the move toLinearOperator
.