diff --git a/README.md b/README.md
index 079d1229..d3590abd 100644
--- a/README.md
+++ b/README.md
@@ -4,17 +4,17 @@
[![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace)
-The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer.
+The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
The library documentation is available at [https://aleximmer.github.io/Laplace](https://aleximmer.github.io/Laplace).
There is also a corresponding paper, [*Laplace Redux — Effortless Bayesian Deep Learning*](https://arxiv.org/abs/2106.14806), which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:
```bibtex
-@article{daxberger2021laplace,
- title={Laplace Redux--Effortless Bayesian Deep Learning},
- author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
- and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
- journal={arXiv preprint arXiv:2106.14806},
+@inproceedings{laplace2021,
+ title={Laplace Redux--Effortless {B}ayesian Deep Learning},
+ author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
+ and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
+ booktitle={{N}eur{IPS}},
year={2021}
}
```
@@ -39,18 +39,24 @@ pytest tests/
## Structure
The laplace package consists of two main components:
-1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, and `'diag'`). This results in six currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, and the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace`, which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
+1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'`, and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'`, and `'diag'`). This results in _eight_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports a `'full'` Hessian approximation), and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function (see the sketch after this list).
2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.
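
A minimal sketch of how the `subset_of_weights` and `hessian_structure` arguments of `laplace.Laplace` map onto these classes (assuming a pre-trained `model`; the classification likelihood is used purely for illustration):

```python
from laplace import Laplace

# 'all' weights + 'full' Hessian -> laplace.FullLaplace
la_full = Laplace(model, 'classification',
                  subset_of_weights='all', hessian_structure='full')

# 'last_layer' weights + 'kron' Hessian -> laplace.KronLLLaplace (the library default)
la_kron_ll = Laplace(model, 'classification',
                     subset_of_weights='last_layer', hessian_structure='kron')

# 'all' weights + 'lowrank' Hessian -> laplace.LowRankLaplace
la_lowrank = Laplace(model, 'classification',
                     subset_of_weights='all', hessian_structure='lowrank')
```
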
Additionally, the package provides utilities for
-decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/feature_extractor.py))
+decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.utils.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py))
and
-effectively dealing with Kronecker factors ([`laplace.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/matrix.py)).
+effectively dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)).
+
+Finally, the package implements several options to select/specify a subnetwork for `SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)).
+Automatic subnetwork selection strategies include: uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
+In addition, subnetworks can also be specified manually by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over.
## Extendability
To extend the laplace package, new `BaseLaplace` subclasses can be designed, for example,
-a block-diagonal structure or subset-of-weights Laplace.
+Laplace with a block-diagonal Hessian structure.
+One can also implement custom subnetwork selection strategies as new subclasses of `SubnetMask`.
+
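A hedged sketch of such a custom strategy (assuming `SubnetMask` is importable from `laplace.utils` and that the base class stores the wrapped network as `self.model`); it only needs to return a binary mask over the vectorized model parameters from `get_subnet_mask`:

```python
from torch.nn.utils import parameters_to_vector
from laplace.utils import SubnetMask  # assumed import path

class LargeWeightSubnetMask(SubnetMask):
    """Illustrative mask: select all parameters whose magnitude exceeds a threshold."""
    def __init__(self, model, threshold=0.1):
        super().__init__(model)
        self.threshold = threshold

    def get_subnet_mask(self, train_loader):
        # Binary vector of size (n_params); 1s mark the subnetwork parameters
        # within torch.nn.utils.parameters_to_vector(model.parameters()).
        param_vector = parameters_to_vector(self.model.parameters()).detach()
        return (param_vector.abs() > self.threshold).float()
```
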
Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) allows to provide different Hessian
approximations to the Laplace approximations.
For example, currently the [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py) based on [BackPACK](https://github.com/f-dangel/backpack/) and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py) based on [ASDL](https://github.com/kazukiosawa/asdfghjkl) are available.
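As a sketch of how a backend is selected (the `backend` keyword and the concrete class name `AsdlGGN` are assumptions on top of the interfaces listed above):

```python
from laplace import Laplace
from laplace.curvature import AsdlGGN  # assumed concrete GGN backend built on ASDL

la = Laplace(model, 'classification',
             subset_of_weights='last_layer', hessian_structure='kron',
             backend=AsdlGGN)
la.fit(train_loader)
```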
@@ -60,10 +66,11 @@ for a regression (MSELoss) loss function.
## Example usage
-### *Post-hoc* prior precision tuning of last-layer LA
+### *Post-hoc* prior precision tuning of diagonal LA
In the following example, a pre-trained model is loaded,
-then the Laplace approximation is fit to the training data,
+then the Laplace approximation is fit to the training data
+(using a diagonal Hessian approximation over all parameters),
and the prior precision is optimized with cross-validation `'CV'`.
After that, the resulting LA is used for prediction with
the `'probit'` predictive for classification.
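A self-contained sketch of the full flow just described (the prior-precision tuning call and its keyword names are assumptions):

```python
from laplace import Laplace

model = load_map_model()  # pre-trained model, as in the snippet below

la = Laplace(model, 'classification',
             subset_of_weights='all', hessian_structure='diag')
la.fit(train_loader)
la.optimize_prior_precision(method='CV', val_loader=val_loader)  # assumed API

pred = la(x, link_approx='probit')
```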
@@ -71,7 +78,7 @@ the `'probit'` predictive for classification.
```python
from laplace import Laplace
-# pre-trained model
+# Pre-trained model
model = load_map_model()
# User-specified LA flavor
@@ -87,7 +94,7 @@ pred = la(x, link_approx='probit')
### Differentiating the log marginal likelihood w.r.t. hyperparameters
-The marginal likelihood can be used for model selection and is differentiable
+The marginal likelihood can be used for model selection [10] and is differentiable
for continuous hyperparameters like the prior precision or observation noise.
Here, we fit the library default, KFAC last-layer LA and differentiate
the log marginal likelihood.
@@ -107,6 +114,45 @@ ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()
```
+### Applying the LA over only a subset of the model parameters
+
+This example shows how to fit the Laplace approximation over only
+a subnetwork within a neural network (while keeping all other parameters
+fixed at their MAP estimates), as proposed in [11]. It also exemplifies
+different ways to specify the subnetwork to perform inference over.
+
+```python
+from laplace import Laplace
+
+# Pre-trained model
+model = load_model()
+
+# Examples of different ways to specify the subnetwork
+# via indices of the vectorized model parameters
+#
+# Example 1: select the 128 parameters with the largest magnitude
+from laplace.utils import LargestMagnitudeSubnetMask
+subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
+subnetwork_indices = subnetwork_mask.select()
+
+# Example 2: specify the layers that define the subnetwork
+from laplace.utils import ModuleNameSubnetMask
+subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3'])
+subnetwork_mask.select()
+subnetwork_indices = subnetwork_mask.indices
+
+# Example 3: manually define the subnetwork via custom subnetwork indices
+import torch
+subnetwork_indices = torch.tensor([0, 4, 11, 42, 123, 2021])
+
+# Define and fit subnetwork LA using the specified subnetwork indices
+la = Laplace(model, 'classification',
+ subset_of_weights='subnetwork',
+ hessian_structure='full',
+ subnetwork_indices=subnetwork_indices)
+la.fit(train_loader)
+```
+
## Documentation
The documentation is available [here](https://aleximmer.github.io/Laplace) or can be generated and/or viewed locally:
@@ -122,7 +168,7 @@ pdoc --http 0.0.0.0:8080 laplace --template-dir template
## References
-This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].
+This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. Please consider citing the respective papers if you use any of their proposed methods via our laplace library.
- [1] MacKay, DJC. [*A Practical Bayesian Framework for Backpropagation Networks*](https://authors.library.caltech.edu/13793/). Neural Computation 1992.
- [2] Gibbs, M. N. [*Bayesian Gaussian Processes for Regression and Classification*](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.147.1130&rep=rep1&type=pdf). PhD Thesis 1997.
@@ -132,4 +178,6 @@ This package relies on various improvements to the Laplace approximation for neu
- [6] Khan, M. E., Immer, A., Abedi, E., Korzepa, M. [*Approximate Inference Turns Deep Networks into Gaussian Processes*](https://arxiv.org/abs/1906.01930). NeurIPS 2019.
- [7] Kristiadi, A., Hein, M., Hennig, P. [*Being Bayesian, Even Just a Bit, Fixes Overconfidence in ReLU Networks*](https://arxiv.org/abs/2002.10118). ICML 2020.
- [8] Immer, A., Korzepa, M., Bauer, M. [*Improving predictions of Bayesian neural nets via local linearization*](https://arxiv.org/abs/2008.08400). AISTATS 2021.
-- [9] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). ICML 2021.
+- [9] Sharma, A., Azizan, N., Pavone, M. [*Sketching Curvature for Efficient Out-of-Distribution Detection for Deep Neural Networks*](https://arxiv.org/abs/2102.12567). UAI 2021.
+- [10] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). ICML 2021.
+- [11] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. [*Bayesian Deep Learning via Subnetwork Inference*](https://arxiv.org/abs/2010.14689). ICML 2021.
\ No newline at end of file
diff --git a/docs/baselaplace.html b/docs/baselaplace.html
index fbaa0a07..ea13c62b 100644
--- a/docs/baselaplace.html
+++ b/docs/baselaplace.html
@@ -172,6 +172,253 @@
Subclasses need to specify how the Hessian approximation is initialized,
+how to add up curvature over training data, how to sample from the
+Laplace approximation, and how to compute the functional variance.
+
A Laplace approximation is represented by a MAP which is given by the
+model parameter and a posterior precision or covariance specifying
+a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
+The goal of this class is to compute the posterior precision P
+which sums as
+
+P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
+\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
+
+Every subclass implements different approximations to the log likelihood Hessians,
+for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
+a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
+In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in
+all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
Computes the scatter, a term of the log marginal likelihood that
+corresponds to L-2 regularization:
+scatter = (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) .
+
Returns
+
scatter : torch.Tensor
+
+
var log_det_prior_precision
+
+
Compute log determinant of the prior precision
+\log \det P_0
+
Returns
+
+
log_det : torch.Tensor
+
+
+
+
var log_det_posterior_precision
+
+
Compute log determinant of the posterior precision
+\log \det P which depends on the subclasses structure
+used for the Hessian approximation.
+
Returns
+
+
log_det : torch.Tensor
+
+
+
+
var log_det_ratio
+
+
Compute the log determinant ratio, a part of the log marginal likelihood.
+
+\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
+
+
Returns
+
+
log_det_ratio : torch.Tensor
+
+
+
+
var posterior_precision
+
+
Compute or return the posterior precision P.
+
Returns
+
+
posterior_prec : torch.Tensor
+
+
+
+
+
Methods
+
+
+def fit(self, train_loader, override=True)
+
+
+
Fit the local Laplace approximation at the parameters of the model.
+
Parameters
+
+
train_loader : torch.data.utils.DataLoader
+
each iterate is a training batch (X, y);
+train_loader.dataset needs to be set to access N, size of the data set
+
override : bool, default=True
+
whether to initialize H, loss, and n_data again; setting to False is useful for
+online learning settings to accumulate a sequential posterior approximation.
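For instance, a minimal sketch of the online setting described here (the two data loaders and the fitted object la are assumptions):

```python
# Accumulate a sequential posterior approximation over two data chunks
la.fit(train_loader_part1)                  # initializes H, loss, and n_data
la.fit(train_loader_part2, override=False)  # adds curvature instead of re-initializing
```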
+
+
+
+def square_norm(self, value)
+
+
+
Compute the square norm under the posterior precision with value-self.mean as 𝛥:
+
+\Delta^\top P \Delta
+
+Returns
+
+
+
square_form
+
+
+
+
+def log_prob(self, value, normalized=True)
+
+
+
Compute the log probability under the (current) Laplace approximation.
+
Parameters
+
+
normalized : bool, default=True
+
whether to return log of a properly normalized Gaussian or just the
+terms that depend on value.
Compute the Laplace approximation to the log marginal likelihood subject
+to specific Hessian approximations that subclasses implement.
+Requires that the Laplace approximation has been fit before.
+The resulting torch.Tensor is differentiable in prior_precision and
+sigma_noise if these have gradients enabled.
+By passing prior_precision or sigma_noise, the current value is
+overwritten. This is useful for iterating on the log marginal likelihood.
+
Parameters
+
+
prior_precision : torch.Tensor, optional
+
prior precision, if it should be changed from the current prior_precision value
+
sigma_noise : torch.Tensor or float, optional
+
observation noise standard deviation, if it should be changed
Sample from the posterior predictive on input data x.
+Can be used, for example, for Thompson sampling.
+
Parameters
+
+
x : torch.Tensor
+
input data (batch_size, input_shape)
+
pred_type : {'glm', 'nn'}, default='glm'
+
type of posterior predictive, linearized GLM predictive or neural
+network sampling predictive. The GLM predictive is consistent with
+the curvature approximations used here.
+
n_samples : int
+
number of samples
+
+
Returns
+
+
samples : torch.Tensor
+
samples (n_samples, batch_size, output_shape)
+
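A hedged usage sketch (the method name predictive_samples and the input shape are assumptions; la is a fitted Laplace object):

```python
import torch

x = torch.randn(8, 20)  # a batch of 8 inputs with 20 features (shapes assumed)
samples = la.predictive_samples(x, pred_type='glm', n_samples=100)
# samples has shape (n_samples, batch_size, output_shape), here (100, 8, n_outputs)
```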
+
+
+def functional_variance(self, Jacs)
+
+
+
Compute functional variance for the 'glm' predictive:
+f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T, which is a output x output
+predictive covariance matrix.
+Mathematically, we have for a single Jacobian
+\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}
+the output covariance matrix
+ \mathcal{J} P^{-1} \mathcal{J}^T .
+
Parameters
+
+
Jacs : torch.Tensor
+
Jacobians of model output wrt parameters
+(batch, outputs, parameters)
+
+
Returns
+
+
f_var : torch.Tensor
+
output covariance (batch, outputs, outputs)
+
+
+
+def sample(self, n_samples=100)
+
+
+
Sample from the Laplace posterior approximation, i.e.,
+ \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).
Mathematically, we have for each parameter group, e.g., torch.nn.Module,
that \P\approx Q \otimes H.
See BaseLaplace for the full interface and see
-Kron and KronDecomposed for the structure of
+Kron and KronDecomposed for the structure of
the Kronecker factors. Kron is used to aggregate factors by summing up and
KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
@@ -273,7 +523,7 @@
Subclasses need to specify how the Hessian approximation is initialized,
-how to add up curvature over training data, how to sample from the
-Laplace approximation, and how to compute the functional variance.
-
A Laplace approximation is represented by a MAP which is given by the
-model parameter and a posterior precision or covariance specifying
-a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
-The goal of this class is to compute the posterior precision P
-which sums as
-
-P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
-\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
-
-Every subclass implements different approximations to the log likelihood Hessians,
-for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
-a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
-In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in
-all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
+
Laplace approximation with low-rank log likelihood Hessian (approximation).
+The low-rank matrix is represented by an eigendecomposition (vecs, values).
+Based on the chosen backend, either a true Hessian or, for example, GGN
+approximation could be used.
+The posterior precision is computed as
+ P = V diag(l) V^T + P_0.
+To sample, compute the functional variance, and compute the log determinant, algebraic tricks
+are used to reduce the cost of inversion to that of a K \times K matrix
+if we have a rank of K.
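A short sketch of constructing this variant via the factory function, consistent with the README's option list (model and train_loader are assumptions):

```python
from laplace import Laplace

la = Laplace(model, 'classification',
             subset_of_weights='all',       # LowRankLaplace only supports 'all' weights
             hessian_structure='lowrank')
la.fit(train_loader)
```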
Computes the scatter, a term of the log marginal likelihood that
-corresponds to L-2 regularization:
-scatter = (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) .
-
Returns
-
[type]
-[description]
-
-
var log_det_prior_precision
-
-
Compute log determinant of the prior precision
-\log \det P_0
-
Returns
-
-
log_det : torch.Tensor
-
-
-
-
var log_det_posterior_precision
+
var V
-
Compute log determinant of the posterior precision
-\log \det P which depends on the subclasses structure
-used for the Hessian approximation.
-
Returns
-
-
log_det : torch.Tensor
-
-
-
-
var log_det_ratio
-
-
Compute the log determinant ratio, a part of the log marginal likelihood.
-
-\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
-
-
Returns
-
-
log_det_ratio : torch.Tensor
-
-
-
-
var posterior_precision
-
-
Compute or return the posterior precision P.
-
Returns
-
-
posterior_prec : torch.Tensor
-
-
-
-
-
Methods
-
-
-def fit(self, train_loader)
-
-
-
Fit the local Laplace approximation at the parameters of the model.
-
Parameters
-
-
train_loader : torch.data.utils.DataLoader
-
each iterate is a training batch (X, y);
-train_loader.dataset needs to be set to access N, size of the data set
Compute the Laplace approximation to the log marginal likelihood subject
-to specific Hessian approximations that subclasses implement.
-Requires that the Laplace approximation has been fit before.
-The resulting torch.Tensor is differentiable in prior_precision and
-sigma_noise if these have gradients enabled.
-By passing prior_precision or sigma_noise, the current value is
-overwritten. This is useful for iterating on the log marginal likelihood.
-
Parameters
-
-
prior_precision : torch.Tensor, optional
-
prior precision if should be changed from current prior_precision value
-
sigma_noise : [type], optional
-
observation noise standard deviation if should be changed
Sample from the posterior predictive on input data x.
-Can be used, for example, for Thompson sampling.
-
Parameters
-
-
x : torch.Tensor
-
input data (batch_size, input_shape)
-
pred_type : {'glm', 'nn'}, default='glm'
-
type of posterior predictive, linearized GLM predictive or neural
-network sampling predictive. The GLM predictive is consistent with
-the curvature approximations used here.
-
n_samples : int
-
number of samples
-
-
Returns
-
-
samples : torch.Tensor
-
samples (n_samples, batch_size, output_shape)
-
+
-
-def functional_variance(self, Jacs)
-
+
var posterior_precision
-
Compute functional variance for the 'glm' predictive:
-f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T, which is a output x output
-predictive covariance matrix.
-Mathematically, we have for a single Jacobian
-\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}
-the output covariance matrix
- \mathcal{J} P^{-1} \mathcal{J}^T .
-
Parameters
-
-
Jacs : torch.Tensor
-
Jacobians of model output wrt parameters
-(batch, outputs, parameters)
-
+
Return correctly scaled posterior precision that would be constructed
+as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.
Returns
-
f_var : torch.Tensor
-
output covariance (batch, outputs, outputs)
-
-
-
-def sample(self, n_samples=100)
-
-
-
Sample from the Laplace posterior approximation, i.e.,
- \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).
-
Parameters
-
-
n_samples : int, default=100
-
number of samples
+
H : tuple(eigenvectors, eigenvalues)
+
scaled self.H with temperature and loss factors.
+
prior_precision_diag : torch.Tensor
+
diagonal prior precision shape parameters to be added to H.
The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer.
+
The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
The library documentation is available at https://aleximmer.github.io/Laplace.
There is also a corresponding paper, Laplace Redux — Effortless Bayesian Deep Learning, which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:
-
@article{daxberger2021laplace,
- title={Laplace Redux--Effortless Bayesian Deep Learning},
- author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
- and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
- journal={arXiv preprint arXiv:2106.14806},
+
@inproceedings{laplace2021,
+ title={Laplace Redux--Effortless {B}ayesian Deep Learning},
+ author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
+ and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
+ booktitle={{N}eur{IPS}},
year={2021}
}
@@ -56,34 +56,39 @@
Setup
Structure
The laplace package consists of two main components:
-
The subclasses of laplace.BaseLaplace that implement different sparsity structures: different subsets of weights ('all' and 'last_layer') and different structures of the Hessian approximation ('full', 'kron', and 'diag'). This results in six currently available options: FullLaplace, KronLaplace, DiagLaplace, and the corresponding last-layer variations FullLLLaplace, KronLLLaplace,
-and DiagLLLaplace, which are all subclasses of laplace.LLLaplace. All of these can be conveniently accessed via the laplace.Laplace function.
+
The subclasses of laplace.BaseLaplace that implement different sparsity structures: different subsets of weights ('all', 'subnetwork' and 'last_layer') and different structures of the Hessian approximation ('full', 'kron', 'lowrank' and 'diag'). This results in eight currently available options: FullLaplace, KronLaplace, DiagLaplace, the corresponding last-layer variations FullLLLaplace, KronLLLaplace,
+and DiagLLLaplace (which are all subclasses of laplace.LLLaplace), laplace.SubnetLaplace (which only supports a 'full' Hessian approximation) and LowRankLaplace (which only supports inference over 'all' weights). All of these can be conveniently accessed via the laplace.Laplace function.
The backends in laplace.curvature which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.
Additionally, the package provides utilities for
-decomposing a neural network into feature extractor and last layer for LLLaplace subclasses (laplace.feature_extractor)
+decomposing a neural network into feature extractor and last layer for LLLaplace subclasses (laplace.utils.feature_extractor)
and
-effectively dealing with Kronecker factors (laplace.matrix).
Finally, the package implements several options to select/specify a subnetwork for SubnetLaplace (as subclasses of laplace.utils.subnetmask.SubnetMask).
+Automatic subnetwork selection strategies include: uniformly at random (RandomSubnetMask), by largest parameter magnitudes (LargestMagnitudeSubnetMask), and by largest marginal parameter variances (LargestVarianceDiagLaplaceSubnetMask and LargestVarianceSWAGSubnetMask).
+In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (ParamNameSubnetMask) or modules (ModuleNameSubnetMask) to perform Laplace inference over.
Extendability
To extend the laplace package, new BaseLaplace subclasses can be designed, for example,
-a block-diagonal structure or subset-of-weights Laplace.
-Alternatively, extending or integrating backends (subclasses of curvature.curvature) allows to provide different Hessian
+Laplace with a block-diagonal Hessian structure.
+One can also implement custom subnetwork selection strategies as new subclasses of SubnetMask.
+
Alternatively, extending or integrating backends (subclasses of curvature.curvature) allows to provide different Hessian
approximations to the Laplace approximations.
For example, currently the curvature.BackPackInterface based on BackPACK and curvature.AsdlInterface based on ASDL are available.
The AsdlInterface provides a Kronecker factored empirical Fisher while the BackPackInterface
does not, and only the BackPackInterface provides access to Hessian approximations
for a regression (MSELoss) loss function.
Example usage
-
Post-hoc prior precision tuning of last-layer LA
+
Post-hoc prior precision tuning of diagonal LA
In the following example, a pre-trained model is loaded,
-then the Laplace approximation is fit to the training data,
+then the Laplace approximation is fit to the training data
+(using a diagonal Hessian approximation over all parameters),
and the prior precision is optimized with cross-validation 'CV'.
After that, the resulting LA is used for prediction with
the 'probit' predictive for classification.
from laplace import Laplace
-# pre-trained model
+# Pre-trained model
model = load_map_model()
# User-specified LA flavor
@@ -97,7 +102,7 @@
Post-hoc prio
pred = la(x, link_approx='probit')
Differentiating the log marginal likelihood w.r.t. hyperparameters
-
The marginal likelihood can be used for model selection and is differentiable
+
The marginal likelihood can be used for model selection [10] and is differentiable
for continuous hyperparameters like the prior precision or observation noise.
Here, we fit the library default, KFAC last-layer LA and differentiate
the log marginal likelihood.
@@ -114,6 +119,41 @@
Differe
ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()
+
Applying the LA over only a subset of the model parameters
+
This example shows how to fit the Laplace approximation over only
+a subnetwork within a neural network (while keeping all other parameters
+fixed at their MAP estimates), as proposed in [11]. It also exemplifies
+different ways to specify the subnetwork to perform inference over.
+
from laplace import Laplace
+
+# Pre-trained model
+model = load_model()
+
+# Examples of different ways to specify the subnetwork
+# via indices of the vectorized model parameters
+#
+# Example 1: select the 128 parameters with the largest magnitude
+from laplace.utils import LargestMagnitudeSubnetMask
+subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
+subnetwork_indices = subnetwork_mask.select()
+
+# Example 2: specify the layers that define the subnetwork
+from laplace.utils import ModuleNameSubnetMask
+subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3'])
+subnetwork_mask.select()
+subnetwork_indices = subnetwork_mask.indices
+
+# Example 3: manually define the subnetwork via custom subnetwork indices
+import torch
+subnetwork_indices = torch.tensor([0, 4, 11, 42, 123, 2021])
+
+# Define and fit subnetwork LA using the specified subnetwork indices
+la = Laplace(model, 'classification',
+ subset_of_weights='subnetwork',
+ hessian_structure='full',
+ subnetwork_indices=subnetwork_indices)
+la.fit(train_loader)
+
Documentation
The documentation is available here or can be generated and/or viewed locally:
# assuming the repository was cloned
@@ -124,7 +164,7 @@
This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].
+
This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. Please consider citing the respective papers if you use any of their proposed methods via our laplace library.
Fit the local Laplace approximation at the parameters of the model.
@@ -706,6 +745,45 @@
Parameters
train_loader : torch.data.utils.DataLoader
each iterate is a training batch (X, y);
train_loader.dataset needs to be set to access N, size of the data set
+
override : bool, default=True
+
whether to initialize H, loss, and n_data again; setting to False is useful for
+online learning settings to accumulate a sequential posterior approximation.
+
+
+
+def square_norm(self, value)
+
+
+
Compute the square norm under the posterior precision with value-self.mean as 𝛥:
+
+\Delta^\top P \Delta
+
+Returns
+
+
+
square_form
+
+
+
+
+def log_prob(self, value, normalized=True)
+
+
+
Compute the log probability under the (current) Laplace approximation.
+
Parameters
+
+
normalized : bool, default=True
+
whether to return log of a properly normalized Gaussian or just the
+terms that depend on value.
Mathematically, we have for each parameter group, e.g., torch.nn.Module,
that \P\approx Q \otimes H.
See BaseLaplace for the full interface and see
-Kron and KronDecomposed for the structure of
+Kron and KronDecomposed for the structure of
the Kronecker factors. Kron is used to aggregate factors by summing up and
KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
@@ -909,7 +990,7 @@
Laplace approximation with low-rank log likelihood Hessian (approximation).
+The low-rank matrix is represented by an eigendecomposition (vecs, values).
+Based on the chosen backend, either a true Hessian or, for example, GGN
+approximation could be used.
+The posterior precision is computed as
+ P = V diag(l) V^T + P_0.
+To sample, compute the functional variance, and compute the log determinant, algebraic tricks
+are used to reduce the cost of inversion to that of a K \times K matrix
+if we have a rank of K.
Mathematically, we have for the last parameter group, i.e., torch.nn.Linear,
that \P\approx Q \otimes H.
See KronLaplace, LLLaplace, and BaseLaplace for the full interface and see
-Kron and KronDecomposed for the structure of
+Kron and KronDecomposed for the structure of
the Kronecker factors. Kron is used to aggregate factors by summing up and
KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
Use of damping is possible by initializing or setting damping=True.
Class for subnetwork Laplace, which computes the Laplace approximation over
+just a subset of the model parameters (i.e. a subnetwork within the neural network),
+as proposed in [1]. Subnetwork Laplace only supports a full Hessian approximation; other
+approximations could be used in theory, but would not make as much sense conceptually.
+
A Laplace approximation is represented by a MAP which is given by the
+model parameter and a posterior precision or covariance specifying
+a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
+Here, only a subset of the model parameters (i.e. a subnetwork of the
+neural network) are treated probabilistically.
+The goal of this class is to compute the posterior precision P
+which sums as
+
+P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
+\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
+
+The prior is assumed to be Gaussian and therefore we have a simple form for
+\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
+In particular, we assume a scalar or diagonal prior precision so that in
+all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
+
The subnetwork Laplace approximation only supports a full, i.e., dense, log likelihood
+Hessian approximation and hence posterior precision.
+Based on the chosen backend
+parameter, the full approximation can be, for example, a generalized Gauss-Newton
+matrix.
+Mathematically, we have P \in \mathbb{R}^{P \times P}.
+See FullLaplace and BaseLaplace for the full interface.
determines the log likelihood Hessian approximation
+
subnetwork_indices : torch.LongTensor
+
indices of the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
+that define the subnetwork to apply the Laplace approximation over
+
sigma_noise : torch.Tensor or float, default=1
+
observation noise for the regression setting; must be 1 for classification
+
prior_precision : torch.Tensor or float, default=1
+
prior precision of a Gaussian prior (= weight decay);
+can be scalar, per-layer, or diagonal in the most general case
+
prior_mean : torch.Tensor or float, default=0
+
prior mean of a Gaussian prior, useful for continual learning
+
temperature : float, default=1
+
temperature of the likelihood; lower temperature leads to more
+concentrated posterior and vice versa.
Baseclass for all last-layer Laplace approximations in this library.
+Subclasses specify the structure of the Hessian approximation.
+See BaseLaplace for the full interface.
+
A Laplace approximation is represented by a MAP which is given by the
+model parameter and a posterior precision or covariance specifying
+a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
+Here, only the parameters of the last layer of the neural network
+are treated probabilistically.
+The goal of this class is to compute the posterior precision P
+which sums as
+
+P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
+\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
+
+Every subclass implements different approximations to the log likelihood Hessians,
+for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
+a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
+In particular, we assume a scalar or diagonal prior precision so that in
+all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
and hence posterior precision. Based on the chosen backend parameter, the full
approximation can be, for example, a generalized Gauss-Newton matrix.
Mathematically, we have P \in \mathbb{R}^{P \times P}.
-See FullLaplace, LLLaplace, and BaseLaplace for the full interface.
+See FullLaplace, LLLaplace, and BaseLaplace for the full interface.
and hence posterior precision.
Mathematically, we have for the last parameter group, i.e., torch.nn.Linear,
that \P\approx Q \otimes H.
-See KronLaplace, LLLaplace, and BaseLaplace for the full interface and see
-Kron and KronDecomposed for the structure of
+See KronLaplace, LLLaplace, and BaseLaplace for the full interface and see
+Kron and KronDecomposed for the structure of
the Kronecker factors. Kron is used to aggregate factors by summing up and
KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
Use of damping is possible by initializing or setting damping=True.
Last-layer Laplace approximation with diagonal log likelihood Hessian approximation
and hence posterior precision.
Mathematically, we have P \approx \textrm{diag}(P).
-See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.
+See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.
Class for subnetwork Laplace, which computes the Laplace approximation over
+just a subset of the model parameters (i.e. a subnetwork within the neural network),
+as proposed in [1]. Subnetwork Laplace only supports a full Hessian approximation; other
+approximations could be used in theory, but would not make as much sense conceptually.
+
A Laplace approximation is represented by a MAP which is given by the
+model parameter and a posterior precision or covariance specifying
+a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
+Here, only a subset of the model parameters (i.e. a subnetwork of the
+neural network) are treated probabilistically.
+The goal of this class is to compute the posterior precision P
+which sums as
+
+P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
+\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
+
+The prior is assumed to be Gaussian and therefore we have a simple form for
+\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
+In particular, we assume a scalar or diagonal prior precision so that in
+all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
+
The subnetwork Laplace approximation only supports a full, i.e., dense, log likelihood
+Hessian approximation and hence posterior precision.
+Based on the chosen backend
+parameter, the full approximation can be, for example, a generalized Gauss-Newton
+matrix.
+Mathematically, we have P \in \mathbb{R}^{P \times P}.
+See FullLaplace and BaseLaplace for the full interface.
determines the log likelihood Hessian approximation
+
subnetwork_indices : torch.LongTensor
+
indices of the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
+that define the subnetwork to apply the Laplace approximation over
+
sigma_noise : torch.Tensor or float, default=1
+
observation noise for the regression setting; must be 1 for classification
+
prior_precision : torch.Tensor or float, default=1
+
prior precision of a Gaussian prior (= weight decay);
+can be scalar, per-layer, or diagonal in the most general case
+
prior_mean : torch.Tensor or float, default=0
+
prior mean of a Gaussian prior, useful for continual learning
+
temperature : float, default=1
+
temperature of the likelihood; lower temperature leads to more
+concentrated posterior and vice versa.
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/feature_extractor.html b/docs/utils/feature_extractor.html
similarity index 84%
rename from docs/feature_extractor.html
rename to docs/utils/feature_extractor.html
index 1c7ef071..9599128b 100644
--- a/docs/feature_extractor.html
+++ b/docs/utils/feature_extractor.html
@@ -4,7 +4,7 @@
-laplace.feature_extractor API documentation
+laplace.utils.feature_extractor API documentation
@@ -20,7 +20,7 @@
-
Module laplace.feature_extractor
+
Module laplace.utils.feature_extractor
@@ -33,7 +33,7 @@
Module laplace.feature_extractor
Classes
-
+
class FeatureExtractor(model: torch.nn.modules.module.Module, last_layer_name: Optional[str] = None)
Fit diagonal SWAG [1], which estimates marginal variances of model parameters by
+computing the first and second moment of SGD iterates with a large learning rate.
Feature extractor for a PyTorch neural network.
+A wrapper which can return the output of the penultimate layer in addition to
+the output of the last layer for each forward pass. If the name of the last
+layer is not known, it can determine it automatically. It assumes that the
+last layer is linear and that for every forward pass the last layer is the same.
+If the name of the last layer is known, it can be passed as a parameter at
+initialization; this is the safest way to use this class.
+Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.
+
Parameters
+
+
model : torch.nn.Module
+
PyTorch model
+
last_layer_name : str, default=None
+
if the name of the last layer is already known, otherwise it will
+be determined automatically.
+
+
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Forward pass which returns the output of the penultimate layer along
+with the output of the last layer. If the last layer is not known yet,
+it will be determined when this function is called for the first time.
+
Parameters
+
+
x : torch.Tensor
+
one batch of data to use as input for the forward pass
Automatically determines the last layer of the model with one
+forward pass. It assumes that the last layer is the same for every
+forward pass and that it is an instance of torch.nn.Linear.
+Might not work with every architecture, but is tested with all PyTorch
+torchvision classification models (besides SqueezeNet, which has no
+linear last layer).
+
Parameters
+
+
x : torch.Tensor
+
one batch of data to use as input for the forward pass
+
+
+
+
+
+class Kron
+(kfacs)
+
+
+
Kronecker factored approximate curvature representation for a corresponding
+neural network.
+Each element in kfacs is either a tuple or single matrix.
+A tuple represents two Kronecker factors Q, and H and a single element
+is just a full block Hessian approximation.
+
Parameters
+
+
kfacs : list[Tuple]
+
each element in the list is a Tuple of two Kronecker factors Q, H
+or a single matrix approximating the Hessian (in case of bias, for example)
+
+
Static methods
+
+
+def init_from_model(model, device)
+
+
+
Initialize Kronecker factors based on a model's architecture.
Batched matrix multiplication with the Kronecker factors.
+If Kron is H, we compute H @ W.
+This is useful for computing the predictive or a regularization
+based on Kronecker factors as in continual learning.
+
Parameters
+
+
W : torch.Tensor
+
matrix (batch, classes, params)
+
exponent : float, default=1
+
only can be 1 for Kron, requires KronDecomposed for other
+exponent values of the Kronecker factors.
+
+
Returns
+
+
SW : torch.Tensor
+
result (batch, classes, params)
+
+
+
+def logdet(self) ‑> torch.Tensor
+
+
+
Compute log determinant of the Kronecker factors and sums them up.
+This corresponds to the log determinant of the entire Hessian approximation.
+
Returns
+
+
logdet : torch.Tensor
+
+
+
+
+def diag(self) ‑> torch.Tensor
+
+
+
Extract diagonal of the entire Kronecker factorization.
+
Returns
+
+
diag : torch.Tensor
+
+
+
+
+def to_matrix(self) ‑> torch.Tensor
+
+
+
Make the Kronecker factorization dense by computing the kronecker product.
+Warning: this should only be used for testing purposes as it will allocate
+large amounts of memory for big architectures.
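A small illustrative sketch of the container itself, with hand-built factors (the import path and the exact kfacs layout are assumptions):

```python
import torch
from laplace.utils import Kron  # assumed re-export after the utils refactor

# One parameter group with a (Q, H) Kronecker pair and one plain block (e.g. for a bias)
Q = 2.0 * torch.eye(3)
H = 0.5 * torch.eye(4)
bias_block = torch.eye(4)
kron = Kron([[Q, H], [bias_block]])

dense = kron.to_matrix()  # (3*4 + 4) x (3*4 + 4) block-diagonal matrix
logdet = kron.logdet()    # sum of per-block log-determinants
diag = kron.diag()        # diagonal of the full factorization
```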
Decomposed Kronecker factored approximate curvature representation
+for a corresponding neural network.
+Each matrix in Kron is decomposed to obtain KronDecomposed.
+Front-loading decomposition allows cheap repeated computation
+of inverses and log determinants.
+In contrast to Kron, we can add scalar or layerwise scalars but
+we cannot add other Kron or KronDecomposed anymore.
+
Parameters
+
+
eigenvectors : list[Tuple[torch.Tensor]]
+
eigenvectors corresponding to matrices in a corresponding Kron
+
eigenvalues : list[Tuple[torch.Tensor]]
+
eigenvalues corresponding to matrices in a corresponding Kron
+
deltas : torch.Tensor
+
addend for each group of Kronecker factors representing, for example,
+a prior precision
+
dampen : bool, default=False
+
use dampen approximation mixing prior and Kron partially multiplicatively
+
+
Methods
+
+
+def detach(self)
+
+
+
+
+
+def logdet(self) ‑> torch.Tensor
+
+
+
Compute log determinant of the Kronecker factors and sums them up.
+This corresponds to the log determinant of the entire Hessian approximation.
+In contrast to Kron.logdet(), additive deltas corresponding to prior
+precisions are added.
Batched matrix multiplication with the decomposed Kronecker factors.
+This is useful for computing the predictive or a regularization loss.
+Compared to Kron.bmm(), a prior can be added here in form of deltas
+and the exponent can be other than just 1.
+Computes H^{exponent} W.
Make the Kronecker factorization dense by computing the kronecker product.
+Warning: this should only be used for testing purposes as it will allocate
+large amounts of memory for big architectures.
+
Returns
+
+
block_diag : torch.Tensor
+
+
+
+
+
+
+class SubnetMask
+(model)
+
+
+
Baseclass for all subnetwork masks in this library (for subnetwork Laplace).
Converts a subnetwork mask into subnetwork indices.
+
Parameters
+
+
subnet_mask : torch.Tensor
+
a binary vector of size (n_params) where 1s locate the subnetwork parameters
+within the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
+
+
Returns
+
+
subnet_mask_indices : torch.LongTensor
+
a vector of indices of the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
+that define the subnetwork
each iterate is a training batch (X, y);
+train_loader.dataset needs to be set to access N, size of the data set
+
+
Returns
+
+
subnet_mask_indices : torch.LongTensor
+
a vector of indices of the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
+that define the subnetwork
+
+
+
+def get_subnet_mask(self, train_loader)
+
+
+
Get the subnetwork mask.
+
Parameters
+
+
train_loader : torch.data.utils.DataLoader
+
each iterate is a training batch (X, y);
+train_loader.dataset needs to be set to access N, size of the data set
+
+
Returns
+
+
subnet_mask : torch.Tensor
+
a binary vector of size (n_params) where 1s locate the subnetwork parameters
+within the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
Subnetwork mask identifying the parameters with the largest marginal variances
+(estimated using a diagonal Laplace approximation over all model parameters).
+
Parameters
+
+
model : torch.nn.Module
+
+
n_params_subnet : int
+
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/matrix.html b/docs/utils/matrix.html
similarity index 75%
rename from docs/matrix.html
rename to docs/utils/matrix.html
index 6323f7cf..caa3e688 100644
--- a/docs/matrix.html
+++ b/docs/utils/matrix.html
@@ -4,7 +4,7 @@
-laplace.matrix API documentation
+laplace.utils.matrix API documentation
@@ -20,7 +20,7 @@
-
only can be 1 for Kron, requires KronDecomposed for other
+
only can be 1 for Kron, requires KronDecomposed for other
exponent values of the Kronecker factors.
Returns
@@ -111,7 +111,7 @@
Returns
result (batch, classes, params)
-
+
def logdet(self) ‑> torch.Tensor
@@ -123,7 +123,7 @@
Returns
-
+
def diag(self) ‑> torch.Tensor
@@ -134,7 +134,7 @@
Returns
-
+
def to_matrix(self) ‑> torch.Tensor
@@ -149,24 +149,24 @@
Returns
-
+
class KronDecomposed(eigenvectors, eigenvalues, deltas=None, damping=False)
Decomposed Kronecker factored approximate curvature representation
for a corresponding neural network.
-Each matrix in Kron is decomposed to obtain KronDecomposed.
+Each matrix in Kron is decomposed to obtain KronDecomposed.
Front-loading decomposition allows cheap repeated computation
of inverses and log determinants.
-In contrast to Kron, we can add scalar or layerwise scalars but
-we cannot add other Kron or KronDecomposed anymore.
+In contrast to Kron, we can add scalar or layerwise scalars but
+we cannot add other Kron or KronDecomposed anymore.
Parameters
eigenvectors : list[Tuple[torch.Tensor]]
-
eigenvectors corresponding to matrices in a corresponding Kron
+
eigenvectors corresponding to matrices in a corresponding Kron
eigenvalues : list[Tuple[torch.Tensor]]
-
eigenvalues corresponding to matrices in a corresponding Kron
+
eigenvalues corresponding to matrices in a corresponding Kron
deltas : torch.Tensor
addend for each group of Kronecker factors representing, for example,
a prior precision
@@ -175,19 +175,19 @@
Parameters
Methods
-
+
def detach(self)
-
+
def logdet(self) ‑> torch.Tensor
Compute log determinant of the Kronecker factors and sums them up.
This corresponds to the log determinant of the entire Hessian approximation.
-In contrast to Kron.logdet(), additive deltas corresponding to prior
+In contrast to Kron.logdet(), additive deltas corresponding to prior
precisions are added.
Batched matrix multiplication with the decomposed Kronecker factors.
This is useful for computing the predictive or a regularization loss.
-Compared to Kron.bmm(), a prior can be added here in form of deltas
+Compared to Kron.bmm(), a prior can be added here in form of deltas
and the exponent can be other than just 1.
Computes H^{exponent} W.
Converts a subnetwork mask into subnetwork indices.
+
Parameters
+
+
subnet_mask : torch.Tensor
+
a binary vector of size (n_params) where 1s locate the subnetwork parameters
+within the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
+
+
Returns
+
+
subnet_mask_indices : torch.LongTensor
+
a vector of indices of the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
+that define the subnetwork
each iterate is a training batch (X, y);
+train_loader.dataset needs to be set to access N, size of the data set
+
+
Returns
+
+
subnet_mask_indices : torch.LongTensor
+
a vector of indices of the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
+that define the subnetwork
+
+
+
+def get_subnet_mask(self, train_loader)
+
+
+
Get the subnetwork mask.
+
Parameters
+
+
train_loader : torch.data.utils.DataLoader
+
each iterate is a training batch (X, y);
+train_loader.dataset needs to be set to access N, size of the data set
+
+
Returns
+
+
subnet_mask : torch.Tensor
+
a binary vector of size (n_params) where 1s locate the subnetwork parameters
+within the vectorized model parameters
+(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
Subnetwork mask identifying the parameters with the largest marginal variances
+(estimated using a diagonal Laplace approximation over all model parameters).
+
Parameters
+
+
model : torch.nn.Module
+
+
n_params_subnet : int
+
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
Fit diagonal SWAG [1], which estimates marginal variances of model parameters by
+computing the first and second moment of SGD iterates with a large learning rate.