diff --git a/README.md b/README.md index 079d1229..d3590abd 100644 --- a/README.md +++ b/README.md @@ -4,17 +4,17 @@ [![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace) -The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer. +The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer. The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations. The library documentation is available at [https://aleximmer.github.io/Laplace](https://aleximmer.github.io/Laplace). There is also a corresponding paper, [*Laplace Redux — Effortless Bayesian Deep Learning*](https://arxiv.org/abs/2106.14806), which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library: ```bibtex -@article{daxberger2021laplace, - title={Laplace Redux--Effortless Bayesian Deep Learning}, - author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander - and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp}, - journal={arXiv preprint arXiv:2106.14806}, +@inproceedings{laplace2021, + title={Laplace Redux--Effortless {B}ayesian Deep Learning}, + author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer + and Runa Eschenhagen and Matthias Bauer and Philipp Hennig}, + booktitle={{N}eur{IPS}}, year={2021} } ``` @@ -39,18 +39,24 @@ pytest tests/ ## Structure The laplace package consists of two main components: -1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, and `'diag'`). This results in six currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, and the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace`, which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function. +1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). 
This results in _eight_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports a `'full'` Hessian approximation) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function. 2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of the corresponding sparsity structures, for example, the diagonal GGN. Additionally, the package provides utilities for -decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/feature_extractor.py)) +decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.utils.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py)) and -effectively dealing with Kronecker factors ([`laplace.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/matrix.py)). +effectively dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)). + +Finally, the package implements several options to select/specify a subnetwork for `SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)). +Automatic subnetwork selection strategies include: uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`). +In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over. ## Extendability To extend the laplace package, new `BaseLaplace` subclasses can be designed, for example, -a block-diagonal structure or subset-of-weights Laplace. +Laplace with a block-diagonal Hessian structure. +One can also implement custom subnetwork selection strategies as new subclasses of `SubnetMask`. + Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) allows to provide different Hessian approximations to the Laplace approximations. For example, currently the [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py) based on [BackPACK](https://github.com/f-dangel/backpack/) and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py) based on [ASDL](https://github.com/kazukiosawa/asdfghjkl) are available. @@ -60,10 +66,11 @@ for a regression (MSELoss) loss function. 
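As a rough illustration of the subnetwork-selection extension point above, the following is a minimal sketch (not part of the library) of a custom `SubnetMask`. It assumes that `SubnetMask` can be imported from `laplace.utils` like the built-in masks and that subclasses override a `get_subnet_mask` method returning a boolean mask over the vectorized model parameters; consult `laplace/utils/subnetmask.py` for the authoritative interface.

```python
import torch
from torch.nn.utils import parameters_to_vector
from laplace.utils import SubnetMask  # assumed import path, mirroring the built-in masks

class FirstModuleSubnetMask(SubnetMask):
    """Hypothetical mask that treats only the first child module probabilistically."""

    def get_subnet_mask(self, train_loader):
        # Boolean mask over the vectorized parameters (True = include in subnetwork);
        # assumes the first child's parameters come first in parameter order.
        n_params = len(parameters_to_vector(self.model.parameters()))
        first_module = next(iter(self.model.children()))
        n_first = sum(p.numel() for p in first_module.parameters())
        mask = torch.zeros(n_params, dtype=torch.bool)
        mask[:n_first] = True
        return mask
```

Such a mask would then be used like the built-in ones, i.e. calling `select()` to obtain `subnetwork_indices` for `Laplace(..., subset_of_weights='subnetwork', hessian_structure='full')`.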
## Example usage -### *Post-hoc* prior precision tuning of last-layer LA +### *Post-hoc* prior precision tuning of diagonal LA In the following example, a pre-trained model is loaded, -then the Laplace approximation is fit to the training data, +then the Laplace approximation is fit to the training data +(using a diagonal Hessian approximation over all parameters), and the prior precision is optimized with cross-validation `'CV'`. After that, the resulting LA is used for prediction with the `'probit'` predictive for classification. @@ -71,7 +78,7 @@ the `'probit'` predictive for classification. ```python from laplace import Laplace -# pre-trained model +# Pre-trained model model = load_map_model() # User-specified LA flavor @@ -87,7 +94,7 @@ pred = la(x, link_approx='probit') ### Differentiating the log marginal likelihood w.r.t. hyperparameters -The marginal likelihood can be used for model selection and is differentiable +The marginal likelihood can be used for model selection [10] and is differentiable for continuous hyperparameters like the prior precision or observation noise. Here, we fit the library default, KFAC last-layer LA and differentiate the log marginal likelihood. @@ -107,6 +114,45 @@ ml = la.log_marginal_likelihood(prior_prec, obs_noise) ml.backward() ``` +### Applying the LA over only a subset of the model parameters + +This example shows how to fit the Laplace approximation over only +a subnetwork within a neural network (while keeping all other parameters +fixed at their MAP estimates), as proposed in [11]. It also exemplifies +different ways to specify the subnetwork to perform inference over. + +```python +from laplace import Laplace + +# Pre-trained model +model = load_model() + +# Examples of different ways to specify the subnetwork +# via indices of the vectorized model parameters +# +# Example 1: select the 128 parameters with the largest magnitude +from laplace.utils import LargestMagnitudeSubnetMask +subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128) +subnetwork_indices = subnetwork_mask.select() + +# Example 2: specify the layers that define the subnetwork +from laplace.utils import ModuleNameSubnetMask +subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3']) +subnetwork_mask.select() +subnetwork_indices = subnetwork_mask.indices + +# Example 3: manually define the subnetwork via custom subnetwork indices +import torch +subnetwork_indices = torch.tensor([0, 4, 11, 42, 123, 2021]) + +# Define and fit subnetwork LA using the specified subnetwork indices +la = Laplace(model, 'classification', + subset_of_weights='subnetwork', + hessian_structure='full', + subnetwork_indices=subnetwork_indices) +la.fit(train_loader) +``` + ## Documentation The documentation is available [here](https://aleximmer.github.io/Laplace) or can be generated and/or viewed locally: @@ -122,7 +168,7 @@ pdoc --http 0.0.0.0:8080 laplace --template-dir template ## References -This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. +This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. Please consider citing the respective papers if you use any of their proposed methods via our laplace library. - [1] MacKay, DJC. [*A Practical Bayesian Framework for Backpropagation Networks*](https://authors.library.caltech.edu/13793/). Neural Computation 1992. - [2] Gibbs, M. N. 
[*Bayesian Gaussian Processes for Regression and Classification*](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.147.1130&rep=rep1&type=pdf). PhD Thesis 1997. @@ -132,4 +178,6 @@ This package relies on various improvements to the Laplace approximation for neu - [6] Khan, M. E., Immer, A., Abedi, E., Korzepa, M. [*Approximate Inference Turns Deep Networks into Gaussian Processes*](https://arxiv.org/abs/1906.01930). NeurIPS 2019. - [7] Kristiadi, A., Hein, M., Hennig, P. [*Being Bayesian, Even Just a Bit, Fixes Overconfidence in ReLU Networks*](https://arxiv.org/abs/2002.10118). ICML 2020. - [8] Immer, A., Korzepa, M., Bauer, M. [*Improving predictions of Bayesian neural nets via local linearization*](https://arxiv.org/abs/2008.08400). AISTATS 2021. -- [9] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). ICML 2021. +- [9] Sharma, A., Azizan, N., Pavone, M. [*Sketching Curvature for Efficient Out-of-Distribution Detection for Deep Neural Networks*](https://arxiv.org/abs/2102.12567). UAI 2021. +- [10] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). ICML 2021. +- [11] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. [*Bayesian Deep Learning via Subnetwork Inference*](https://arxiv.org/abs/2010.14689). ICML 2021. \ No newline at end of file diff --git a/docs/baselaplace.html b/docs/baselaplace.html index fbaa0a07..ea13c62b 100644 --- a/docs/baselaplace.html +++ b/docs/baselaplace.html @@ -172,6 +172,253 @@

Parameters

+
+class ParametricLaplace +(model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None) +
+
+

Parametric Laplace class.

+

Subclasses need to specify how the Hessian approximation is initialized, +how to add up curvature over training data, how to sample from the +Laplace approximation, and how to compute the functional variance.

+

A Laplace approximation is represented by a MAP which is given by the +model parameter and a posterior precision or covariance specifying +a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). +The goal of this class is to compute the posterior precision P +which sums as + +P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) +\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. + +Every subclass implements different approximations to the log likelihood Hessians, +for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have +a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 . +In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in +all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
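The instance variables documented below are the building blocks of this approximation. As a sketch of how they relate (not a verbatim excerpt of the implementation), the Laplace approximation of the log marginal likelihood combines them as

```latex
\log p(\mathcal{D}) \approx \log p(\mathcal{D} \mid \theta_{MAP})
  - \frac{1}{2}\Big[\underbrace{(\theta_{MAP}-\mu_0)^{\top} P_0\,(\theta_{MAP}-\mu_0)}_{\texttt{scatter}}
  + \underbrace{\log\det P - \log\det P_0}_{\texttt{log\_det\_ratio}}\Big].
```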

+

Ancestors

+ +

Subclasses

+ +

Instance variables

+
+
var scatter
+
+

Computes the scatter, a term of the log marginal likelihood that +corresponds to L-2 regularization: +scatter = (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) .

+

Returns

+

scatter : torch.Tensor
scatter : torch.Tensor

+
+
var log_det_prior_precision
+
+

Compute log determinant of the prior precision +\log \det P_0

+

Returns

+
+
log_det : torch.Tensor
+
 
+
+
+
var log_det_posterior_precision
+
+

Compute log determinant of the posterior precision +\log \det P which depends on the subclasses structure +used for the Hessian approximation.

+

Returns

+
+
log_det : torch.Tensor
+
 
+
+
+
var log_det_ratio
+
+

Compute the log determinant ratio, a part of the log marginal likelihood. + +\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 +

+

Returns

+
+
log_det_ratio : torch.Tensor
+
 
+
+
+
var posterior_precision
+
+

Compute or return the posterior precision P.

+

Returns

+
+
posterior_prec : torch.Tensor
+
 
+
+
+
+

Methods

+
+
+def fit(self, train_loader, override=True) +
+
+

Fit the local Laplace approximation at the parameters of the model.

+

Parameters

+
+
train_loader : torch.data.utils.DataLoader
+
each iterate is a training batch (X, y); +train_loader.dataset needs to be set to access N, size of the data set
+
override : bool, default=True
+
whether to initialize H, loss, and n_data again; setting to False is useful for +online learning settings to accumulate a sequential posterior approximation.
+
+
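To illustrate the `override` flag described above, a brief usage sketch (hypothetical `chunk_loaders`; not taken from the library docs):

```python
# Hypothetical sketch: accumulate a sequential posterior over several data chunks.
la.fit(chunk_loaders[0])            # first call initializes H, loss, and n_data
for loader in chunk_loaders[1:]:
    la.fit(loader, override=False)  # later calls add curvature instead of resetting it
```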
+
+def square_norm(self, value) +
+
+

Compute the square norm under the posterior precision, with Δ = value - self.mean: \Delta^\top P \Delta

Returns

+
+
+
square_form
+
 
+
+
+
+def log_prob(self, value, normalized=True) +
+
+

Compute the log probability under the (current) Laplace approximation.

+

Parameters

+
+
normalized : bool, default=True
+
whether to return log of a properly normalized Gaussian or just the +terms that depend on value.
+
+

Returns

+
+
log_prob : torch.Tensor
+
 
+
+
+
+def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None) +
+
+

Compute the Laplace approximation to the log marginal likelihood subject +to specific Hessian approximations that subclasses implement. +Requires that the Laplace approximation has been fit before. +The resulting torch.Tensor is differentiable in prior_precision and +sigma_noise if these have gradients enabled. +By passing prior_precision or sigma_noise, the current value is +overwritten. This is useful for iterating on the log marginal likelihood.

+

Parameters

+
+
prior_precision : torch.Tensor, optional
+
prior precision if should be changed from current prior_precision value
+
sigma_noise : [type], optional
+
observation noise standard deviation if should be changed
+
+

Returns

+
+
log_marglik : torch.Tensor
+
 
+
+
+
+def predictive_samples(self, x, pred_type='glm', n_samples=100) +
+
+

Sample from the posterior predictive on input data x. +Can be used, for example, for Thompson sampling.

+

Parameters

+
+
x : torch.Tensor
+
input data (batch_size, input_shape)
+
pred_type : {'glm', 'nn'}, default='glm'
+
type of posterior predictive, linearized GLM predictive or neural +network sampling predictive. The GLM predictive is consistent with +the curvature approximations used here.
+
n_samples : int
+
number of samples
+
+

Returns

+
+
samples : torch.Tensor
+
samples (n_samples, batch_size, output_shape)
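A brief usage sketch (assuming a fitted classification `la` and an input batch `x`; names are placeholders):

```python
# Hypothetical sketch: Monte Carlo predictive from posterior predictive samples.
samples = la.predictive_samples(x, pred_type='glm', n_samples=100)  # (n_samples, batch, outputs)
pred_mean = samples.mean(dim=0)   # Monte Carlo predictive estimate, (batch, outputs)
pred_std = samples.std(dim=0)     # per-output predictive spread, (batch, outputs)
```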
+
+
+
+def functional_variance(self, Jacs) +
+
+

Compute functional variance for the 'glm' predictive: +f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T, which is an output x output +predictive covariance matrix. +Mathematically, we have for a single Jacobian +\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}} +the output covariance matrix + \mathcal{J} P^{-1} \mathcal{J}^T .

+

Parameters

+
+
Jacs : torch.Tensor
+
Jacobians of model output wrt parameters +(batch, outputs, parameters)
+
+

Returns

+
+
f_var : torch.Tensor
+
output covariance (batch, outputs, outputs)
+
+
+
+def sample(self, n_samples=100) +
+
+

Sample from the Laplace posterior approximation, i.e., + \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).

+

Parameters

+
+
n_samples : int, default=100
+
number of samples
+
+
+
+def optimize_prior_precision(self, method='marglik', pred_type='glm', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=<function get_nll>, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, link_approx='probit', n_samples=100, verbose=False, cv_loss_with_var=False) +
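Since no description is rendered for this method here, a brief usage sketch based only on the signature above (assuming a fitted `la`, and a `val_loader` for the CV variant):

```python
# Hypothetical sketch: two ways to tune the prior precision post hoc.
la.optimize_prior_precision(method='marglik', n_steps=100, lr=0.1)  # gradient-based marginal-likelihood maximization
la.optimize_prior_precision(method='CV', val_loader=val_loader)     # grid search scored on validation data
```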
+
+
+
+
+

Inherited members

+ +
class FullLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None) @@ -190,6 +437,7 @@

Ancestors

Subclasses

Instance variables

@@ -233,11 +481,13 @@

Inherited members

  • log_det_ratio
  • log_likelihood
  • log_marginal_likelihood
  • +
  • log_prob
  • optimize_prior_precision_base
  • predictive_samples
  • prior_precision_diag
  • sample
  • scatter
  • +
  • square_norm
  • @@ -252,7 +502,7 @@

    Inherited members

    Mathematically, we have for each parameter group, e.g., torch.nn.Module, that \P\approx Q \otimes H. See BaseLaplace for the full interface and see -Kron and KronDecomposed for the structure of +Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up and KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature), and computing posterior covariances, marginal likelihood, etc. @@ -273,7 +523,7 @@

    Instance variables

    Kronecker factored Posterior precision P.

    Returns

    -
    precision : KronDecomposed
    +
    precision : KronDecomposed
     
    @@ -293,11 +543,13 @@

    Inherited members

  • log_det_ratio
  • log_likelihood
  • log_marginal_likelihood
  • +
  • log_prob
  • optimize_prior_precision_base
  • predictive_samples
  • prior_precision_diag
  • sample
  • scatter
  • +
  • square_norm
  • @@ -361,218 +613,80 @@

    Inherited members

  • log_det_ratio
  • log_likelihood
  • log_marginal_likelihood
  • +
  • log_prob
  • optimize_prior_precision_base
  • predictive_samples
  • prior_precision_diag
  • sample
  • scatter
  • +
  • square_norm
  • -
    -class ParametricLaplace -(model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None) +
    +class LowRankLaplace +(model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0, temperature=1, backend=laplace.curvature.asdl.AsdlHessian, backend_kwargs=None)
    -

    Parametric Laplace class.

    -

    Subclasses need to specify how the Hessian approximation is initialized, -how to add up curvature over training data, how to sample from the -Laplace approximation, and how to compute the functional variance.

    -

    A Laplace approximation is represented by a MAP which is given by the -model parameter and a posterior precision or covariance specifying -a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). -The goal of this class is to compute the posterior precision P -which sums as - -P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) -\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. - -Every subclass implements different approximations to the log likelihood Hessians, -for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have -a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 . -In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in -all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.

    +

Laplace approximation with low-rank log likelihood Hessian (approximation). The low-rank matrix is represented by an eigendecomposition (vecs, values). Based on the chosen backend, either a true Hessian or, for example, a GGN approximation can be used. The posterior precision is computed as P = V diag(l) V^T + P_0. For sampling, computing the functional variance, and computing the log determinant, algebraic tricks are used to reduce the cost of inversion to that of a K \times K matrix for a rank-K approximation.

    +

    See BaseLaplace for the full interface.
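A brief usage sketch (model, data, and inputs are placeholders): the low-rank variant is most conveniently reached through the `Laplace` factory function, which per the README supports it only over `'all'` weights.

```python
from laplace import Laplace

# Hypothetical sketch: low-rank LA over all weights via the factory function.
la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='lowrank')
la.fit(train_loader)
pred = la(x, link_approx='probit')
```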

    Ancestors

    -

    Subclasses

    -

    Instance variables

    -
    var scatter
    -
    -

    Computes the scatter, a term of the log marginal likelihood that -corresponds to L-2 regularization: -scatter = (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) .

    -

    Returns

    -

    [type] -[description]

    -
    -
    var log_det_prior_precision
    -
    -

    Compute log determinant of the prior precision -\log \det P_0

    -

    Returns

    -
    -
    log_det : torch.Tensor
    -
     
    -
    -
    -
    var log_det_posterior_precision
    +
    var V
    -

    Compute log determinant of the posterior precision -\log \det P which depends on the subclasses structure -used for the Hessian approximation.

    -

    Returns

    -
    -
    log_det : torch.Tensor
    -
     
    -
    -
    -
    var log_det_ratio
    -
    -

    Compute the log determinant ratio, a part of the log marginal likelihood. - -\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0 -

    -

    Returns

    -
    -
    log_det_ratio : torch.Tensor
    -
     
    -
    -
    -
    var posterior_precision
    -
    -

    Compute or return the posterior precision P.

    -

    Returns

    -
    -
    posterior_prec : torch.Tensor
    -
     
    -
    -
    -
    -

    Methods

    -
    -
    -def fit(self, train_loader) -
    -
    -

    Fit the local Laplace approximation at the parameters of the model.

    -

    Parameters

    -
    -
    train_loader : torch.data.utils.DataLoader
    -
    each iterate is a training batch (X, y); -train_loader.dataset needs to be set to access N, size of the data set
    -
    -
    -
    -def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None) -
    -
    -

    Compute the Laplace approximation to the log marginal likelihood subject -to specific Hessian approximations that subclasses implement. -Requires that the Laplace approximation has been fit before. -The resulting torch.Tensor is differentiable in prior_precision and -sigma_noise if these have gradients enabled. -By passing prior_precision or sigma_noise, the current value is -overwritten. This is useful for iterating on the log marginal likelihood.

    -

    Parameters

    -
    -
    prior_precision : torch.Tensor, optional
    -
    prior precision if should be changed from current prior_precision value
    -
    sigma_noise : [type], optional
    -
    observation noise standard deviation if should be changed
    -
    -

    Returns

    -
    -
    log_marglik : torch.Tensor
    -
     
    -
    +
    -
    -def predictive_samples(self, x, pred_type='glm', n_samples=100) -
    +
    var Kinv
    -

    Sample from the posterior predictive on input data x. -Can be used, for example, for Thompson sampling.

    -

    Parameters

    -
    -
    x : torch.Tensor
    -
    input data (batch_size, input_shape)
    -
    pred_type : {'glm', 'nn'}, default='glm'
    -
    type of posterior predictive, linearized GLM predictive or neural -network sampling predictive. The GLM predictive is consistent with -the curvature approximations used here.
    -
    n_samples : int
    -
    number of samples
    -
    -

    Returns

    -
    -
    samples : torch.Tensor
    -
    samples (n_samples, batch_size, output_shape)
    -
    +
    -
    -def functional_variance(self, Jacs) -
    +
    var posterior_precision
    -

    Compute functional variance for the 'glm' predictive: -f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T, which is a output x output -predictive covariance matrix. -Mathematically, we have for a single Jacobian -\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}} -the output covariance matrix - \mathcal{J} P^{-1} \mathcal{J}^T .

    -

    Parameters

    -
    -
    Jacs : torch.Tensor
    -
    Jacobians of model output wrt parameters -(batch, outputs, parameters)
    -
    +

    Return correctly scaled posterior precision that would be constructed +as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.

    Returns

    -
    f_var : torch.Tensor
    -
    output covariance (batch, outputs, outputs)
    -
    -
    -
    -def sample(self, n_samples=100) -
    -
    -

    Sample from the Laplace posterior approximation, i.e., - \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).

    -

    Parameters

    -
    -
    n_samples : int, default=100
    -
    number of samples
    +
    H : tuple(eigenvectors, eigenvalues)
    +
    scaled self.H with temperature and loss factors.
    +
    prior_precision_diag : torch.Tensor
    +
    diagonal prior precision shape parameters to be added to H.
    -
    -def optimize_prior_precision(self, method='marglik', pred_type='glm', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=<function get_nll>, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, link_approx='probit', n_samples=100, verbose=False, cv_loss_with_var=False) -
    -
    -
    -

    Inherited members

    @@ -603,18 +717,11 @@

    FullLaplace

    - -
  • -

    KronLaplace

    -
  • -
  • -

    DiagLaplace

    -
  • -
  • ParametricLaplace

  • +
  • +

    FullLaplace

    +
  • +
  • +

    KronLaplace

    +
  • +
  • +

    DiagLaplace

    +
  • +
  • +

    LowRankLaplace

    +
  • diff --git a/docs/curvature/asdl.html b/docs/curvature/asdl.html index 23b40a34..ecba76dd 100644 --- a/docs/curvature/asdl.html +++ b/docs/curvature/asdl.html @@ -35,7 +35,7 @@

    Classes

    class AsdlInterface -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Interface for asdfghjkl backend.

    @@ -47,19 +47,18 @@

    Subclasses

    -

    Static methods

    +

    Methods

    -def jacobians(model, x) +def jacobians(self, x)

    Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta using asdfghjkl's gradient per output dimension.

    Parameters

    -
    model : torch.nn.Module
    -
     
    x : torch.Tensor
    input data (batch, input_shape) on compatible device with model.
    @@ -71,9 +70,6 @@

    Returns

    output function (batch, outputs)
    -
    -

    Methods

    -
    def gradients(self, x, y)
    @@ -108,9 +104,43 @@

    Inherited members

    +
    +class AsdlHessian +(model, likelihood, last_layer=False, low_rank=10) +
    +
    +

    Interface for asdfghjkl backend.

    +

    Ancestors

    + +

    Methods

    +
    +
    +def eig_lowrank(self, data_loader) +
    +
    +
    +
    +
    +

    Inherited members

    + +
    class AsdlGGN -(model, likelihood, last_layer=False, stochastic=False) +(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

    Implementation of the GGNInterface using asdfghjkl.

    @@ -184,6 +214,12 @@

    AsdlHessian

    + + +
  • AsdlGGN

  • diff --git a/docs/curvature/backpack.html b/docs/curvature/backpack.html index 0e610d54..1ae69561 100644 --- a/docs/curvature/backpack.html +++ b/docs/curvature/backpack.html @@ -35,7 +35,7 @@

    Classes

    class BackPackInterface -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Interface for Backpack backend.

    @@ -48,18 +48,16 @@

    Subclasses

  • BackPackEF
  • BackPackGGN
  • -

    Static methods

    +

    Methods

    -def jacobians(model, x) +def jacobians(self, x)

    Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta using backpack's BatchGrad per output dimension.

    Parameters

    -
    model : torch.nn.Module
    -
     
    x : torch.Tensor
    input data (batch, input_shape) on compatible device with model.
    @@ -71,9 +69,6 @@

    Returns

    output function (batch, outputs)
    -
    -

    Methods

    -
    def gradients(self, x, y)
    @@ -110,7 +105,7 @@

    Inherited members

  • class BackPackGGN -(model, likelihood, last_layer=False, stochastic=False) +(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

    Implementation of the GGNInterface using Backpack.

    @@ -136,7 +131,7 @@

    Inherited members

    class BackPackEF -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Implementation of EFInterface using Backpack.

    diff --git a/docs/curvature/curvature.html b/docs/curvature/curvature.html index 084432df..645baae7 100644 --- a/docs/curvature/curvature.html +++ b/docs/curvature/curvature.html @@ -35,7 +35,7 @@

    Classes

    class CurvatureInterface -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Interface to access curvature for a model and corresponding likelihood. @@ -45,12 +45,15 @@

    Classes

    structures, for example, a block-diagonal one.

    Parameters

    -
    model : torch.nn.Module or FeatureExtractor
    +
    model : torch.nn.Module or FeatureExtractor
    torch model (neural network)
    likelihood : {'classification', 'regression'}
     
    last_layer : bool, default=False
    only consider curvature of last layer
    +
    subnetwork_indices : torch.Tensor, default=None
    +
    indices of the vectorized model parameters that define the subnetwork +to apply the Laplace approximation over

    Attributes

    @@ -67,17 +70,15 @@

    Subclasses

  • EFInterface
  • GGNInterface
  • -

    Static methods

    +

    Methods

    -def jacobians(model, x) +def jacobians(self, x)

    Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta.

    Parameters

    -
    model : torch.nn.Module
    -
     
    x : torch.Tensor
    input data (batch, input_shape) on compatible device with model.
    @@ -90,15 +91,13 @@

    Returns

    -def last_layer_jacobians(model, x) +def last_layer_jacobians(self, x)

    Compute Jacobians \nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last}) only at current last-layer parameter \theta_{\textrm{last}}.

    Parameters

    -
    model : FeatureExtractor
    -
     
    x : torch.Tensor
     
    @@ -110,9 +109,6 @@

    Returns

    output function (batch, outputs)
    -
    -

    Methods

    -
    def gradients(self, x, y)
    @@ -175,7 +171,7 @@

    Returns

    loss : torch.Tensor
     
    -
    H : Kron
    +
    H : Kron
    Kronecker factored Hessian approximation.
    @@ -204,7 +200,7 @@

    Returns

    class GGNInterface -(model, likelihood, last_layer=False, stochastic=False) +(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

    Generalized Gauss-Newton or Fisher Curvature Interface. @@ -212,12 +208,15 @@

    Returns

    In addition to CurvatureInterface, methods for Jacobians are required by subclasses.

    Parameters

    -
    model : torch.nn.Module or FeatureExtractor
    +
    model : torch.nn.Module or FeatureExtractor
    torch model (neural network)
    likelihood : {'classification', 'regression'}
     
    last_layer : bool, default=False
    only consider curvature of last layer
    +
    subnetwork_indices : torch.Tensor, default=None
    +
    indices of the vectorized model parameters that define the subnetwork +to apply the Laplace approximation over
    stochastic : bool, default=False
    Fisher if stochastic else GGN
    @@ -270,19 +269,22 @@

    Inherited members

    class EFInterface -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Interface for Empirical Fisher as Hessian approximation. In addition to CurvatureInterface, methods for gradients are required by subclasses.

    Parameters

    -
    model : torch.nn.Module or FeatureExtractor
    +
    model : torch.nn.Module or FeatureExtractor
    torch model (neural network)
    likelihood : {'classification', 'regression'}
     
    last_layer : bool, default=False
    only consider curvature of last layer
    +
    subnetwork_indices : torch.Tensor, default=None
    +
    indices of the vectorized model parameters that define the subnetwork +to apply the Laplace approximation over

    Attributes

    diff --git a/docs/curvature/index.html b/docs/curvature/index.html index 72e1203b..00001e2f 100644 --- a/docs/curvature/index.html +++ b/docs/curvature/index.html @@ -50,7 +50,7 @@

    Classes

    class CurvatureInterface -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Interface to access curvature for a model and corresponding likelihood. @@ -60,12 +60,15 @@

    Classes

    structures, for example, a block-diagonal one.

    Parameters

    -
    model : torch.nn.Module or FeatureExtractor
    +
    model : torch.nn.Module or FeatureExtractor
    torch model (neural network)
    likelihood : {'classification', 'regression'}
     
    last_layer : bool, default=False
    only consider curvature of last layer
    +
    subnetwork_indices : torch.Tensor, default=None
    +
    indices of the vectorized model parameters that define the subnetwork +to apply the Laplace approximation over

    Attributes

    @@ -82,17 +85,15 @@

    Subclasses

  • EFInterface
  • GGNInterface
  • -

    Static methods

    +

    Methods

    -def jacobians(model, x) +def jacobians(self, x)

    Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta.

    Parameters

    -
    model : torch.nn.Module
    -
     
    x : torch.Tensor
    input data (batch, input_shape) on compatible device with model.
    @@ -105,15 +106,13 @@

    Returns

    -def last_layer_jacobians(model, x) +def last_layer_jacobians(self, x)

    Compute Jacobians \nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last}) only at current last-layer parameter \theta_{\textrm{last}}.

    Parameters

    -
    model : FeatureExtractor
    -
     
    x : torch.Tensor
     
    @@ -125,9 +124,6 @@

    Returns

    output function (batch, outputs)
    -
    -

    Methods

    -
    def gradients(self, x, y)
    @@ -190,7 +186,7 @@

    Returns

    loss : torch.Tensor
     
    -
    H : Kron
    +
    H : Kron
    Kronecker factored Hessian approximation.
    @@ -219,7 +215,7 @@

    Returns

    class GGNInterface -(model, likelihood, last_layer=False, stochastic=False) +(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

    Generalized Gauss-Newton or Fisher Curvature Interface. @@ -227,12 +223,15 @@

    Returns

    In addition to CurvatureInterface, methods for Jacobians are required by subclasses.

    Parameters

    -
    model : torch.nn.Module or FeatureExtractor
    +
    model : torch.nn.Module or FeatureExtractor
    torch model (neural network)
    likelihood : {'classification', 'regression'}
     
    last_layer : bool, default=False
    only consider curvature of last layer
    +
    subnetwork_indices : torch.Tensor, default=None
    +
    indices of the vectorized model parameters that define the subnetwork +to apply the Laplace approximation over
    stochastic : bool, default=False
    Fisher if stochastic else GGN
    @@ -285,19 +284,22 @@

    Inherited members

    class EFInterface -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Interface for Empirical Fisher as Hessian approximation. In addition to CurvatureInterface, methods for gradients are required by subclasses.

    Parameters

    -
    model : torch.nn.Module or FeatureExtractor
    +
    model : torch.nn.Module or FeatureExtractor
    torch model (neural network)
    likelihood : {'classification', 'regression'}
     
    last_layer : bool, default=False
    only consider curvature of last layer
    +
    subnetwork_indices : torch.Tensor, default=None
    +
    indices of the vectorized model parameters that define the subnetwork +to apply the Laplace approximation over

    Attributes

    @@ -356,7 +358,7 @@

    Inherited members

    class BackPackInterface -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Interface for Backpack backend.

    @@ -369,18 +371,16 @@

    Subclasses

  • BackPackEF
  • BackPackGGN
  • -

    Static methods

    +

    Methods

    -def jacobians(model, x) +def jacobians(self, x)

    Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta using backpack's BatchGrad per output dimension.

    Parameters

    -
    model : torch.nn.Module
    -
     
    x : torch.Tensor
    input data (batch, input_shape) on compatible device with model.
    @@ -392,9 +392,6 @@

    Returns

    output function (batch, outputs)
    -
    -

    Methods

    -
    def gradients(self, x, y)
    @@ -431,7 +428,7 @@

    Inherited members

    class BackPackGGN -(model, likelihood, last_layer=False, stochastic=False) +(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

    Implementation of the GGNInterface using Backpack.

    @@ -457,7 +454,7 @@

    Inherited members

    class BackPackEF -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Implementation of EFInterface using Backpack.

    @@ -483,7 +480,7 @@

    Inherited members

    class AsdlInterface -(model, likelihood, last_layer=False) +(model, likelihood, last_layer=False, subnetwork_indices=None)

    Interface for asdfghjkl backend.

    @@ -495,19 +492,18 @@

    Subclasses

    -

    Static methods

    +

    Methods

    -def jacobians(model, x) +def jacobians(self, x)

    Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta using asdfghjkl's gradient per output dimension.

    Parameters

    -
    model : torch.nn.Module
    -
     
    x : torch.Tensor
    input data (batch, input_shape) on compatible device with model.
    @@ -519,9 +515,6 @@

    Returns

    output function (batch, outputs)
    -
    -

    Methods

    -
    def gradients(self, x, y)
    @@ -558,7 +551,7 @@

    Inherited members

    class AsdlGGN -(model, likelihood, last_layer=False, stochastic=False) +(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)

    Implementation of the GGNInterface using asdfghjkl.

    @@ -608,6 +601,40 @@

    Inherited members

    +
    +class AsdlHessian +(model, likelihood, last_layer=False, low_rank=10) +
    +
    +

    Interface for asdfghjkl backend.

    +

    Ancestors

    + +

    Methods

    +
    +
    +def eig_lowrank(self, data_loader) +
    +
    +
    +
    +
    +

    Inherited members

    + +
    @@ -680,6 +707,12 @@

    AsdlEF

    +
  • +

    AsdlHessian

    + +
  • diff --git a/docs/index.html b/docs/index.html index 4f80e44c..2819a3ce 100644 --- a/docs/index.html +++ b/docs/index.html @@ -29,15 +29,15 @@

    Package laplace

    Laplace

    Main

    -

    The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer. +

    The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer. The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations. The library documentation is available at https://aleximmer.github.io/Laplace.

    There is also a corresponding paper, Laplace Redux — Effortless Bayesian Deep Learning, which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:

    -
    @article{daxberger2021laplace,
    -  title={Laplace Redux--Effortless Bayesian Deep Learning},
    -  author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
    -          and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
    -  journal={arXiv preprint arXiv:2106.14806},
    +
    @inproceedings{laplace2021,
    +  title={Laplace Redux--Effortless {B}ayesian Deep Learning},
    +  author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer 
    +          and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
    +  booktitle={{N}eur{IPS}},
       year={2021}
     }
     
    @@ -56,34 +56,39 @@

    Setup

    Structure

    The laplace package consists of two main components:

      -
    1. The subclasses of laplace.BaseLaplace that implement different sparsity structures: different subsets of weights ('all' and 'last_layer') and different structures of the Hessian approximation ('full', 'kron', and 'diag'). This results in six currently available options: FullLaplace, KronLaplace, DiagLaplace, and the corresponding last-layer variations FullLLLaplace, KronLLLaplace, -and DiagLLLaplace, which are all subclasses of laplace.LLLaplace. All of these can be conveniently accessed via the laplace.Laplace function.
    2. +
    3. The subclasses of laplace.BaseLaplace that implement different sparsity structures: different subsets of weights ('all', 'subnetwork' and 'last_layer') and different structures of the Hessian approximation ('full', 'kron', 'lowrank' and 'diag'). This results in eight currently available options: FullLaplace, KronLaplace, DiagLaplace, the corresponding last-layer variations FullLLLaplace, KronLLLaplace, +and DiagLLLaplace (which are all subclasses of laplace.LLLaplace), laplace.SubnetLaplace (which only supports a 'full' Hessian approximation) and LowRankLaplace (which only supports inference over 'all' weights). All of these can be conveniently accessed via the laplace.Laplace function.
    4. The backends in laplace.curvature which provide access to Hessian approximations of the corresponding sparsity structures, for example, the diagonal GGN.

    Additionally, the package provides utilities for -decomposing a neural network into feature extractor and last layer for LLLaplace subclasses (laplace.feature_extractor) +decomposing a neural network into feature extractor and last layer for LLLaplace subclasses (laplace.utils.feature_extractor) and -effectively dealing with Kronecker factors (laplace.matrix).

    +effectively dealing with Kronecker factors (laplace.utils.matrix).

    +

    Finally, the package implements several options to select/specify a subnetwork for SubnetLaplace (as subclasses of laplace.utils.subnetmask.SubnetMask). +Automatic subnetwork selection strategies include: uniformly at random (RandomSubnetMask), by largest parameter magnitudes (LargestMagnitudeSubnetMask), and by largest marginal parameter variances (LargestVarianceDiagLaplaceSubnetMask and LargestVarianceSWAGSubnetMask). +In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (ParamNameSubnetMask) or modules (ModuleNameSubnetMask) to perform Laplace inference over.

    Extendability

    To extend the laplace package, new BaseLaplace subclasses can be designed, for example, -a block-diagonal structure or subset-of-weights Laplace. -Alternatively, extending or integrating backends (subclasses of curvature.curvature) allows to provide different Hessian +Laplace with a block-diagonal Hessian structure. +One can also implement custom subnetwork selection strategies as new subclasses of SubnetMask.

    +

    Alternatively, extending or integrating backends (subclasses of curvature.curvature) allows to provide different Hessian approximations to the Laplace approximations. For example, currently the curvature.BackPackInterface based on BackPACK and curvature.AsdlInterface based on ASDL are available. The AsdlInterface provides a Kronecker factored empirical Fisher while the BackPackInterface does not, and only the BackPackInterface provides access to Hessian approximations for a regression (MSELoss) loss function.

    Example usage

    -

    Post-hoc prior precision tuning of last-layer LA

    +

    Post-hoc prior precision tuning of diagonal LA

    In the following example, a pre-trained model is loaded, -then the Laplace approximation is fit to the training data, +then the Laplace approximation is fit to the training data +(using a diagonal Hessian approximation over all parameters), and the prior precision is optimized with cross-validation 'CV'. After that, the resulting LA is used for prediction with the 'probit' predictive for classification.

    from laplace import Laplace
     
    -# pre-trained model
    +# Pre-trained model
     model = load_map_model()  
     
     # User-specified LA flavor
    @@ -97,7 +102,7 @@ 

    Post-hoc prio pred = la(x, link_approx='probit')

    Differentiating the log marginal likelihood w.r.t. hyperparameters

    -

    The marginal likelihood can be used for model selection and is differentiable +

    The marginal likelihood can be used for model selection [10] and is differentiable for continuous hyperparameters like the prior precision or observation noise. Here, we fit the library default, KFAC last-layer LA and differentiate the log marginal likelihood.

    @@ -114,6 +119,41 @@

    Differe ml = la.log_marginal_likelihood(prior_prec, obs_noise) ml.backward()

    +

    Applying the LA over only a subset of the model parameters

    +

    This example shows how to fit the Laplace approximation over only +a subnetwork within a neural network (while keeping all other parameters +fixed at their MAP estimates), as proposed in [11]. It also exemplifies +different ways to specify the subnetwork to perform inference over.

    +
    from laplace import Laplace
    +
    +# Pre-trained model
    +model = load_model()
    +
    +# Examples of different ways to specify the subnetwork
    +# via indices of the vectorized model parameters
    +#
    +# Example 1: select the 128 parameters with the largest magnitude
    +from laplace.utils import LargestMagnitudeSubnetMask
    +subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
    +subnetwork_indices = subnetwork_mask.select()
    +
    +# Example 2: specify the layers that define the subnetwork
    +from laplace.utils import ModuleNameSubnetMask
    +subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3'])
    +subnetwork_mask.select()
    +subnetwork_indices = subnetwork_mask.indices
    +
    +# Example 3: manually define the subnetwork via custom subnetwork indices
    +import torch
    +subnetwork_indices = torch.tensor([0, 4, 11, 42, 123, 2021])
    +
    +# Define and fit subnetwork LA using the specified subnetwork indices
    +la = Laplace(model, 'classification',
    +             subset_of_weights='subnetwork',
    +             hessian_structure='full',
    +             subnetwork_indices=subnetwork_indices)
    +la.fit(train_loader)
    +

    Documentation

    The documentation is available here or can be generated and/or viewed locally:

    # assuming the repository was cloned
    @@ -124,7 +164,7 @@ 

    Documentation

    pdoc --http 0.0.0.0:8080 laplace --template-dir template

    References

    -

    This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].

    +

    This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. Please consider citing the respective papers if you use any of their proposed methods via our laplace library.

    Full example: Optimization of the marginal likelihood and prediction

    Sinusoidal toy data

    @@ -326,10 +368,6 @@

    Sub-modules

    -
    laplace.feature_extractor
    -
    -
    -
    laplace.laplace
    @@ -338,11 +376,11 @@

    Sub-modules

    -
    laplace.matrix
    +
    laplace.subnetlaplace
    -
    laplace.utils
    +
    laplace.utils
    @@ -364,9 +402,9 @@

    Parameters

     
    likelihood : {'classification', 'regression'}
     
    -
    subset_of_weights : {'last_layer', 'all'}, default='last_layer'
    +
    subset_of_weights : {'last_layer', 'subnetwork', 'all'}, default='last_layer'
    subset of weights to consider for inference
    -
    hessian_structure : {'diag', 'kron', 'full'}, default='kron'
    +
    hessian_structure : {'diag', 'kron', 'full', 'lowrank'}, default='kron'
    structure of the Hessian approximation

    Returns

    @@ -636,7 +674,8 @@

    Subclasses

  • DiagLaplace
  • FullLaplace
  • KronLaplace
  • -
  • laplace.lllaplace.LLLaplace
  • +
  • LowRankLaplace
  • +
  • LLLaplace
  • Instance variables

    @@ -697,7 +736,7 @@

    Returns

    Methods

    -def fit(self, train_loader) +def fit(self, train_loader, override=True)

    Fit the local Laplace approximation at the parameters of the model.

    @@ -706,6 +745,45 @@

    Parameters

    train_loader : torch.data.utils.DataLoader
    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access N, size of the data set
    +
    override : bool, default=True
    +
    whether to initialize H, loss, and n_data again; setting to False is useful for +online learning settings to accumulate a sequential posterior approximation.
    +
    + +
    +def square_norm(self, value) +
    +
    +

Compute the square norm under the posterior precision, with Δ = value - self.mean: \Delta^\top P \Delta

Returns

    +
    +
    +
    square_form
    +
     
    +
    +
    +
    +def log_prob(self, value, normalized=True) +
    +
    +

    Compute the log probability under the (current) Laplace approximation.

    +

    Parameters

    +
    +
    normalized : bool, default=True
    +
    whether to return log of a properly normalized Gaussian or just the +terms that depend on value.
    +
    +

    Returns

    +
    +
    log_prob : torch.Tensor
    +
     
    @@ -826,6 +904,7 @@

    Ancestors

    Subclasses

    Instance variables

    @@ -869,11 +948,13 @@

    Inherited members

  • log_det_ratio
  • log_likelihood
  • log_marginal_likelihood
  • +
  • log_prob
  • optimize_prior_precision_base
  • predictive_samples
  • prior_precision_diag
  • sample
  • scatter
  • +
  • square_norm
  • @@ -888,7 +969,7 @@

    Inherited members

    Mathematically, we have for each parameter group, e.g., torch.nn.Module, that \P\approx Q \otimes H. See BaseLaplace for the full interface and see -Kron and KronDecomposed for the structure of +Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up and KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature), and computing posterior covariances, marginal likelihood, etc. @@ -909,7 +990,7 @@

    Instance variables

    Kronecker factored Posterior precision P.

    Returns

    -
    precision : KronDecomposed
    +
    precision : KronDecomposed
     
    @@ -929,11 +1010,13 @@

    Inherited members

  • log_det_ratio
  • log_likelihood
  • log_marginal_likelihood
  • +
  • log_prob
  • optimize_prior_precision_base
  • predictive_samples
  • prior_precision_diag
  • sample
  • scatter
  • +
  • square_norm
  • @@ -997,11 +1080,80 @@

    Inherited members

  • log_det_ratio
  • log_likelihood
  • log_marginal_likelihood
  • +
  • log_prob
  • optimize_prior_precision_base
  • predictive_samples
  • prior_precision_diag
  • sample
  • scatter
  • +
  • square_norm
  • + + + + +
    +class LowRankLaplace +(model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0, temperature=1, backend=laplace.curvature.asdl.AsdlHessian, backend_kwargs=None) +
    +
    +

Laplace approximation with low-rank log likelihood Hessian (approximation). The low-rank matrix is represented by an eigendecomposition (vecs, values). Based on the chosen backend, either a true Hessian or, for example, a GGN approximation can be used. The posterior precision is computed as P = V diag(l) V^T + P_0. For sampling, computing the functional variance, and computing the log determinant, algebraic tricks are used to reduce the cost of inversion to that of a K \times K matrix for a rank-K approximation.

    +

    See BaseLaplace for the full interface.

    +

    Ancestors

    + +

    Instance variables

    +
    +
    var V
    +
    +
    +
    +
    var Kinv
    +
    +
    +
    +
    var posterior_precision
    +
    +

    Return correctly scaled posterior precision that would be constructed +as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.

    +

    Returns

    +
    +
    H : tuple(eigenvectors, eigenvalues)
    +
    scaled self.H with temperature and loss factors.
    +
    prior_precision_diag : torch.Tensor
    +
    diagonal prior precision shape parameters to be added to H.
    +
    +
    +
    +

    Inherited members

    + @@ -1035,7 +1187,7 @@

    Inherited members

    all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.

    Parameters

    -
    model : torch.nn.Module or FeatureExtractor
    +
    model : torch.nn.Module or FeatureExtractor
     
    likelihood : {'classification', 'regression'}
    determines the log likelihood Hessian approximation
    @@ -1092,11 +1244,13 @@

    Inherited members

  • log_det_ratio
  • log_likelihood
  • log_marginal_likelihood
  • +
  • log_prob
  • optimize_prior_precision_base
  • posterior_precision
  • predictive_samples
  • sample
  • scatter
  • +
  • square_norm
  • @@ -1113,30 +1267,36 @@

    Inherited members

    See FullLaplace, LLLaplace, and BaseLaplace for the full interface.

    Ancestors

    Inherited members

    @@ -1151,35 +1311,37 @@

    Inherited members

    Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that \P\approx Q \otimes H. See KronLaplace, LLLaplace, and BaseLaplace for the full interface and see -Kron and KronDecomposed for the structure of +Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up and KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature), and computing posterior covariances, marginal likelihood, etc. Use of damping is possible by initializing or setting damping=True.

    Ancestors

    Inherited members

    @@ -1195,30 +1357,143 @@

    Inherited members

    See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.

    Ancestors

    Inherited members

    +
    +
    +class SubnetLaplace +(model, likelihood, subnetwork_indices, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None) +
    +
    +

    Class for subnetwork Laplace, which computes the Laplace approximation over +just a subset of the model parameters (i.e. a subnetwork within the neural network), +as proposed in [1]. Subnetwork Laplace only supports a full Hessian approximation; other +approximations could be used in theory, but would not make as much sense conceptually.

    +

    A Laplace approximation is represented by a MAP which is given by the +model parameter and a posterior precision or covariance specifying +a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). +Here, only a subset of the model parameters (i.e. a subnetwork of the +neural network) are treated probabilistically. +The goal of this class is to compute the posterior precision P +which sums as + +P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) +\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. + +The prior is assumed to be Gaussian and therefore we have a simple form for +\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 . +In particular, we assume a scalar or diagonal prior precision so that in +all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.

    +

    The subnetwork Laplace approximation only supports a full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace and BaseLaplace for the full interface.

    +

    References

    +

    [1] Daxberger, E., Nalisnick, E., Allingham, J. U., Antorán, J., Hernández-Lobato, J. M. Bayesian Deep Learning via Subnetwork Inference. ICML 2021.

    +

    Parameters

    model : torch.nn.Module or FeatureExtractor
    likelihood : {'classification', 'regression'}
        determines the log likelihood Hessian approximation
    subnetwork_indices : torch.LongTensor
        indices of the vectorized model parameters (i.e. torch.nn.utils.parameters_to_vector(model.parameters())) that define the subnetwork to apply the Laplace approximation over
    sigma_noise : torch.Tensor or float, default=1
        observation noise for the regression setting; must be 1 for classification
    prior_precision : torch.Tensor or float, default=1
        prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
    prior_mean : torch.Tensor or float, default=0
        prior mean of a Gaussian prior, useful for continual learning
    temperature : float, default=1
        temperature of the likelihood; lower temperature leads to a more concentrated posterior and vice versa.
    backend : subclasses of CurvatureInterface
        backend for access to curvature/Hessian approximations
    backend_kwargs : dict, default=None
        arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.

    Ancestors

    + +

    Instance variables

    var prior_precision_diag

        Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.

    Returns

    prior_precision_diag : torch.Tensor

    Inherited members

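As a usage sketch (the model, data loader, and chosen indices below are placeholders, not prescriptions), a subnetwork posterior over a hand-picked set of weights could be fit as follows:

```python
import torch
from laplace import SubnetLaplace

# Toy classifier and data; in practice these come from your own pipeline.
model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
X, y = torch.randn(64, 4), torch.randint(2, (64,))
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, y), batch_size=16)

# Indices into the vectorized parameters, e.g. the first 100 entries of
# torch.nn.utils.parameters_to_vector(model.parameters()).
subnetwork_indices = torch.arange(100)

la = SubnetLaplace(model, 'classification', subnetwork_indices=subnetwork_indices)
la.fit(train_loader)
```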
    + @@ -1234,8 +1509,9 @@

    Index

  • Structure
  • Extendability
  • Example usage
  • Documentation
  • @@ -1262,11 +1538,10 @@

    Index

  • Functions

    @@ -1290,6 +1565,8 @@

    BaseLaplace

    ParametricLaplace

  • diff --git a/docs/laplace.html b/docs/laplace.html index d72602d6..99dae2b8 100644 --- a/docs/laplace.html +++ b/docs/laplace.html @@ -42,9 +42,9 @@

    Parameters

     
    likelihood : {'classification', 'regression'}
     
    -
    subset_of_weights : {'last_layer', 'all'}, default='last_layer'
    +
    subset_of_weights : {'last_layer', 'subnetwork', 'all'}, default='last_layer'
    subset of weights to consider for inference
    -
    hessian_structure : {'diag', 'kron', 'full'}, default='kron'
    +
    hessian_structure : {'diag', 'kron', 'full', 'lowrank'}, default='kron'
    structure of the Hessian approximation

    Returns

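For illustration (a sketch; model and subnetwork_indices are assumed to exist already), the different option combinations map to different posterior structures:

```python
from laplace import Laplace

# Default: Kronecker-factored last-layer Laplace.
la = Laplace(model, 'classification')

# Full-covariance Laplace over all weights.
la_full = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')

# Subnetwork Laplace requires hessian_structure='full' and explicit indices
# (a torch.LongTensor into the vectorized parameters, prepared beforehand).
la_sub = Laplace(model, 'classification', subset_of_weights='subnetwork',
                 hessian_structure='full', subnetwork_indices=subnetwork_indices)
```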
    diff --git a/docs/lllaplace.html b/docs/lllaplace.html index 108e9b0b..6ea940b2 100644 --- a/docs/lllaplace.html +++ b/docs/lllaplace.html @@ -33,6 +33,103 @@

    Module laplace.lllaplace

    Classes

    class LLLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)

    Baseclass for all last-layer Laplace approximations in this library. Subclasses specify the structure of the Hessian approximation. See BaseLaplace for the full interface.

    +

    A Laplace approximation is represented by a MAP which is given by the model parameters and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). Here, only the parameters of the last layer of the neural network are treated probabilistically. The goal of this class is to compute the posterior precision P which sums as

    P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.

    Every subclass implements different approximations to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.

    +

    Parameters

    model : torch.nn.Module or FeatureExtractor
    likelihood : {'classification', 'regression'}
        determines the log likelihood Hessian approximation
    sigma_noise : torch.Tensor or float, default=1
        observation noise for the regression setting; must be 1 for classification
    prior_precision : torch.Tensor or float, default=1
        prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
    prior_mean : torch.Tensor or float, default=0
        prior mean of a Gaussian prior, useful for continual learning
    temperature : float, default=1
        temperature of the likelihood; lower temperature leads to a more concentrated posterior and vice versa.
    backend : subclasses of CurvatureInterface
        backend for access to curvature/Hessian approximations
    last_layer_name : str, default=None
        name of the model's last layer, if None it will be determined automatically
    backend_kwargs : dict, default=None
        arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.

    Ancestors

    + +

    Subclasses

    + +

    Instance variables

    var prior_precision_diag

        Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.

    Returns

    prior_precision_diag : torch.Tensor

    Inherited members

    + +
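For context, a typical last-layer workflow looks roughly like the following sketch (model, train_loader, and X_test are placeholders):

```python
from laplace import Laplace

# Kronecker-factored last-layer Laplace (the library default).
la = Laplace(model, 'classification', subset_of_weights='last_layer', hessian_structure='kron')
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

# Posterior predictive for new inputs.
probs = la(X_test, link_approx='probit')
```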
    class FullLLLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None) @@ -42,33 +139,39 @@

    Classes

    and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace, LLLaplace, and BaseLaplace for the full interface.

    Ancestors

    Inherited members

    @@ -82,36 +185,38 @@

    Inherited members

    and hence posterior precision. Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that P \approx Q \otimes H. See KronLaplace, LLLaplace, and BaseLaplace for the full interface and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, and KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature), and to compute posterior covariances, marginal likelihood, etc. Use of damping is possible by initializing or setting damping=True.

    Ancestors

    Inherited members

    @@ -124,33 +229,39 @@

    Inherited members

    Last-layer Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.

    Ancestors

    Inherited members

    @@ -172,6 +283,9 @@

    Index

  • Classes

    • +

      LLLaplace

      +
    • +
    • FullLLLaplace

    • diff --git a/docs/regression_example.png b/docs/regression_example.png index c6a94587..94f94c34 100644 Binary files a/docs/regression_example.png and b/docs/regression_example.png differ diff --git a/docs/regression_example_online.png b/docs/regression_example_online.png index 06f66afd..2f30ee5f 100644 Binary files a/docs/regression_example_online.png and b/docs/regression_example_online.png differ diff --git a/docs/subnetlaplace.html b/docs/subnetlaplace.html new file mode 100644 index 00000000..31d93975 --- /dev/null +++ b/docs/subnetlaplace.html @@ -0,0 +1,171 @@ + + + + + + +laplace.subnetlaplace API documentation + + + + + + + + + + + + +
      +
      +
      +

      Module laplace.subnetlaplace

      +
      +
      +
      +
      +
      +
      +
      +
      +
      +
      +

      Classes

      +
      +
      +class SubnetLaplace +(model, likelihood, subnetwork_indices, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None) +
      +
      +

      Class for subnetwork Laplace, which computes the Laplace approximation over +just a subset of the model parameters (i.e. a subnetwork within the neural network), +as proposed in [1]. Subnetwork Laplace only supports a full Hessian approximation; other +approximations could be used in theory, but would not make as much sense conceptually.

      +

      A Laplace approximation is represented by a MAP which is given by the +model parameter and a posterior precision or covariance specifying +a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). +Here, only a subset of the model parameters (i.e. a subnetwork of the +neural network) are treated probabilistically. +The goal of this class is to compute the posterior precision P +which sums as + +P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) +\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. + +The prior is assumed to be Gaussian and therefore we have a simple form for +\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 . +In particular, we assume a scalar or diagonal prior precision so that in +all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.

      +

      The subnetwork Laplace approximation only supports a full, i.e., dense, log likelihood +Hessian approximation and hence posterior precision. +Based on the chosen backend +parameter, the full approximation can be, for example, a generalized Gauss-Newton +matrix. +Mathematically, we have P \in \mathbb{R}^{P \times P}. +See FullLaplace and BaseLaplace for the full interface.

      +

      References

      +

      [1] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. +Bayesian Deep Learning via Subnetwork Inference. +ICML 2021.

      +

      Parameters

      +
      +
      model : torch.nn.Module or FeatureExtractor
      +
       
      +
      likelihood : {'classification', 'regression'}
      +
      determines the log likelihood Hessian approximation
      +
      subnetwork_indices : torch.LongTensor
      +
      indices of the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters())) +that define the subnetwork to apply the Laplace approximation over
      +
      sigma_noise : torch.Tensor or float, default=1
      +
      observation noise for the regression setting; must be 1 for classification
      +
      prior_precision : torch.Tensor or float, default=1
      +
      prior precision of a Gaussian prior (= weight decay); +can be scalar, per-layer, or diagonal in the most general case
      +
      prior_mean : torch.Tensor or float, default=0
      +
      prior mean of a Gaussian prior, useful for continual learning
      +
      temperature : float, default=1
      +
      temperature of the likelihood; lower temperature leads to more +concentrated posterior and vice versa.
      +
      backend : subclasses of CurvatureInterface
      +
      backend for access to curvature/Hessian approximations
      +
      backend_kwargs : dict, default=None
      +
      arguments passed to the backend on initialization, for example to +set the number of MC samples for stochastic approximations.
      +
      +

      Ancestors

      + +

      Instance variables

      +
      +
      var prior_precision_diag
      +
      +

      Obtain the diagonal prior precision p_0 constructed from either +a scalar or diagonal prior precision.

      +

      Returns

      +
      +
      prior_precision_diag : torch.Tensor
      +
       
      +
      +
      +
      +

      Inherited members

      + +
      +
      +
      +
      + +
      + + + \ No newline at end of file diff --git a/docs/feature_extractor.html b/docs/utils/feature_extractor.html similarity index 84% rename from docs/feature_extractor.html rename to docs/utils/feature_extractor.html index 1c7ef071..9599128b 100644 --- a/docs/feature_extractor.html +++ b/docs/utils/feature_extractor.html @@ -4,7 +4,7 @@ -laplace.feature_extractor API documentation +laplace.utils.feature_extractor API documentation @@ -20,7 +20,7 @@
      -

      Module laplace.feature_extractor

      +

      Module laplace.utils.feature_extractor

      @@ -33,7 +33,7 @@

      Module laplace.feature_extractor

      Classes

      -
      +
      class FeatureExtractor (model: torch.nn.modules.module.Module, last_layer_name: Optional[str] = None)
      @@ -61,18 +61,18 @@

      Ancestors

    Class variables

    -
    var dump_patches : bool
    +
    var dump_patches : bool
    -
    var training : bool
    +
    var training : bool

    Methods

    -
    +
    def forward(self, x: torch.Tensor) ‑> torch.Tensor
    @@ -84,7 +84,7 @@

    Parameters

    one batch of data to use as input for the forward pass
    -
    +
    def forward_with_features(self, x: torch.Tensor) ‑> Tuple[torch.Tensor, torch.Tensor]
    @@ -97,7 +97,7 @@

    Parameters

    one batch of data to use as input for the forward pass
  • -
    +
    def set_last_layer(self, last_layer_name: str) ‑> None
    @@ -109,7 +109,7 @@

    Parameters

    the name of the last layer (fixed in model.named_modules()).
    -
    +
    def find_last_layer(self, x: torch.Tensor) ‑> torch.Tensor
    @@ -138,18 +138,18 @@

    Index

    • Super-module

    • Classes

      diff --git a/docs/utils/index.html b/docs/utils/index.html new file mode 100644 index 00000000..2848898a --- /dev/null +++ b/docs/utils/index.html @@ -0,0 +1,1017 @@ + + + + + + +laplace.utils API documentation + + + + + + + + + + + + +
      +
      +
      +

      Module laplace.utils

      +
      +
      +
      +
      +

      Sub-modules

      +
      +
      laplace.utils.feature_extractor
      +
      +
      +
      +
      laplace.utils.matrix
      +
      +
      +
      +
      laplace.utils.subnetmask
      +
      +
      +
      +
      laplace.utils.swag
      +
      +
      +
      +
      laplace.utils.utils
      +
      +
      +
      +
      +
      +
      +
      +
      +

      Functions

      +
      +
      +def get_nll(out_dist, targets) +
      +
      +
      +
      +
      +def validate(laplace, val_loader, pred_type='glm', link_approx='probit', n_samples=100) +
      +
      +
      +
      +
      +def parameters_per_layer(model) +
      +
      +

      Get number of parameters per layer.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      +

      Returns

      +
      +
      params_per_layer : list[int]
      +
       
      +
      +
      +
      +def invsqrt_precision(M) +
      +
      +

      Compute M^{-0.5} as a lower-triangular matrix.

      +

      Parameters

      +
      +
      M : torch.Tensor
      +
       
      +
      +

      Returns

      +
      +
      M_invsqrt : torch.Tensor
      +
       
      +
      +
      +
      +def kron(t1, t2) +
      +
      +

      Computes the Kronecker product between two tensors.

      +

      Parameters

      +
      +
      t1 : torch.Tensor
      +
       
      +
      t2 : torch.Tensor
      +
       
      +
      +

      Returns

      +
      +
      kron_product : torch.Tensor
      +
       
      +
      +
      +
      +def diagonal_add_scalar(X, value) +
      +
      +

      Add the scalar value to the diagonal of X.

      +

      Parameters

      +
      +
      X : torch.Tensor
      +
       
      +
      value : torch.Tensor or float
      +
       
      +
      +

      Returns

      +
      +
      X_add_scalar : torch.Tensor
      +
       
      +
      +
      +
      +def symeig(M) +
      +
      +

      Symmetric eigendecomposition, avoiding failure cases by adding and removing jitter on the diagonal.

      +

      Parameters

      +
      +
      M : torch.Tensor
      +
       
      +
      +

      Returns

      +
      +
      L : torch.Tensor
      +
      eigenvalues
      +
      W : torch.Tensor
      +
      eigenvectors
      +
      +
      +
      +def block_diag(blocks) +
      +
      +

      Compose block-diagonal matrix of individual blocks.

      +

      Parameters

      +
      +
      blocks : list[torch.Tensor]
      +
       
      +
      +

      Returns

      +
      +
      M : torch.Tensor
      +
       
      +
      +
      +
      +def expand_prior_precision(prior_prec, model) +
      +
      +

      Expand prior precision to match the shape of the model parameters.

      +

      Parameters

      +
      +
      prior_prec : torch.Tensor 1-dimensional
      +
      prior precision
      +
      model : torch.nn.Module
      +
      torch model with parameters that are regularized by prior_prec
      +
      +

      Returns

      +
      +
      expanded_prior_prec : torch.Tensor
      +
      expanded prior precision has the same shape as model parameters
      +
      +
      +
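A small sketch of the expansion behaviour (the toy model and precision values are illustrative only):

```python
import torch
from laplace.utils import expand_prior_precision

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))

# One precision per parameter group (two weight matrices and two biases here).
prior_prec = torch.tensor([1.0, 1.0, 0.5, 0.5])
expanded = expand_prior_precision(prior_prec, model)
# Expected: one value per individual parameter, i.e. 4*8 + 8 + 8*2 + 2 = 58 entries.
```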
      +def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=0.0003, min_var=1e-30) +
      +
      +

      Fit diagonal SWAG [1], which estimates marginal variances of model parameters by computing the first and second moment of SGD iterates with a large learning rate.

      Implementation partly adapted from:
      - https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
      - https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py

      References

      [1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, A. G. A Simple Baseline for Bayesian Uncertainty in Deep Learning. NeurIPS 2019.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      train_loader : torch.data.utils.DataLoader
      +
      training data loader to use for snapshot collection
      +
      criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
      +
      loss function to use for snapshot collection
      +
      n_snapshots_total : int
      +
      total number of model snapshots to collect
      +
      snapshot_freq : int
      +
      snapshot collection frequency (in epochs)
      +
      lr : float
      +
      SGD learning rate for collecting snapshots
      +
      momentum : float
      +
      SGD momentum
      +
      weight_decay : float
      +
      SGD weight decay
      +
      min_var : float
      +
      minimum parameter variance to clamp to (for numerical stability)
      +
      +

      Returns

      +
      +
      param_variances : torch.Tensor
      +
      vector of marginal variances for each model parameter
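As a usage sketch (the criterion and loader are placeholders; the hyperparameters shown are just the defaults spelled out):

```python
import torch
from laplace.utils import fit_diagonal_swag_var

criterion = torch.nn.CrossEntropyLoss()
param_variances = fit_diagonal_swag_var(
    model, train_loader, criterion,
    n_snapshots_total=40, snapshot_freq=1, lr=0.01,
)
# param_variances is a flat vector with one marginal variance per model parameter.
```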
      +
      +
      +
      +
      +
      +

      Classes

      +
      +
      +class FeatureExtractor +(model: torch.nn.modules.module.Module, last_layer_name: Optional[str] = None) +
      +
      +

      Feature extractor for a PyTorch neural network. A wrapper which can return the output of the penultimate layer in addition to the output of the last layer for each forward pass. If the name of the last layer is not known, it can determine it automatically. It assumes that the last layer is linear and that for every forward pass the last layer is the same. If the name of the last layer is known, it can be passed as a parameter at initialization; this is the safest way to use this class. Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
      PyTorch model
      +
      last_layer_name : str, default=None
      +
      if the name of the last layer is already known, otherwise it will +be determined automatically.
      +
      +

      Initializes internal Module state, shared by both nn.Module and ScriptModule.

      +

      Ancestors

      +
        +
      • torch.nn.modules.module.Module
      • +
      +

      Class variables

      +
      +
      var dump_patches : bool
      +
      +
      +
      +
      var training : bool
      +
      +
      +
      +
      +

      Methods

      +
      +
      +def forward(self, x: torch.Tensor) ‑> torch.Tensor +
      +
      +

      Forward pass. If the last layer is not known yet, it will be +determined when this function is called for the first time.

      +

      Parameters

      +
      +
      x : torch.Tensor
      +
      one batch of data to use as input for the forward pass
      +
      +
      +
      +def forward_with_features(self, x: torch.Tensor) ‑> Tuple[torch.Tensor, torch.Tensor] +
      +
      +

      Forward pass which returns the output of the penultimate layer along +with the output of the last layer. If the last layer is not known yet, +it will be determined when this function is called for the first time.

      +

      Parameters

      +
      +
      x : torch.Tensor
      +
      one batch of data to use as input for the forward pass
      +
      +
      +
      +def set_last_layer(self, last_layer_name: str) ‑> None +
      +
      +

      Set the last layer of the model by its name. This sets the forward +hook to get the output of the penultimate layer.

      +

      Parameters

      +
      +
      last_layer_name : str
      +
      the name of the last layer (fixed in model.named_modules()).
      +
      +
      +
      +def find_last_layer(self, x: torch.Tensor) ‑> torch.Tensor +
      +
      +

      Automatically determines the last layer of the model with one +forward pass. It assumes that the last layer is the same for every +forward pass and that it is an instance of torch.nn.Linear. +Might not work with every architecture, but is tested with all PyTorch +torchvision classification models (besides SqueezeNet, which has no +linear last layer).

      +

      Parameters

      +
      +
      x : torch.Tensor
      +
      one batch of data to use as input for the forward pass
      +
      +
      +
      +
      +
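A short usage sketch (toy model; the shapes in the comments are assumptions about this example only):

```python
import torch
from laplace.utils import FeatureExtractor

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
fe = FeatureExtractor(model)  # last layer determined automatically on the first forward pass

x = torch.randn(5, 4)
out, phi = fe.forward_with_features(x)  # out: (5, 2) last-layer output, phi: (5, 8) penultimate features
```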
      +class Kron +(kfacs) +
      +
      +

      Kronecker factored approximate curvature representation for a corresponding neural network. Each element in kfacs is either a tuple or a single matrix. A tuple represents two Kronecker factors Q and H, and a single element is just a full block Hessian approximation.

      +

      Parameters

      +
      +
      kfacs : list[Tuple]
      +
      each element in the list is a Tuple of two Kronecker factors Q, H +or a single matrix approximating the Hessian (in case of bias, for example)
      +
      +

      Static methods

      +
      +
      +def init_from_model(model, device) +
      +
      +

      Initialize Kronecker factors based on a model's architecture.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      device : torch.device
      +
       
      +
      +

      Returns

      +
      +
      kron : Kron
      +
       
      +
      +
      +
      +

      Methods

      +
      +
      +def decompose(self, damping=False) +
      +
      +

      Eigendecompose Kronecker factors and turn into KronDecomposed.

      Parameters

      damping : bool
        use damping

      Returns

      +
      +
      kron_decomposed : KronDecomposed
      +
       
      +
      +
      +
      +def bmm(self, W: torch.Tensor, exponent: float = 1) ‑> torch.Tensor +
      +
      +

      Batched matrix multiplication with the Kronecker factors. +If Kron is H, we compute H @ W. +This is useful for computing the predictive or a regularization +based on Kronecker factors as in continual learning.

      +

      Parameters

      +
      +
      W : torch.Tensor
      +
      matrix (batch, classes, params)
      +
      exponent : float, default=1
      +
      only can be 1 for Kron, requires KronDecomposed for other +exponent values of the Kronecker factors.
      +
      +

      Returns

      +
      +
      SW : torch.Tensor
      +
      result (batch, classes, params)
      +
      +
      +
      +def logdet(self) ‑> torch.Tensor +
      +
      +

      Compute log determinant of the Kronecker factors and sums them up. +This corresponds to the log determinant of the entire Hessian approximation.

      +

      Returns

      +
      +
      logdet : torch.Tensor
      +
       
      +
      +
      +
      +def diag(self) ‑> torch.Tensor +
      +
      +

      Extract diagonal of the entire Kronecker factorization.

      +

      Returns

      +
      +
      diag : torch.Tensor
      +
       
      +
      +
      +
      +def to_matrix(self) ‑> torch.Tensor +
      +
      +

      Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes as it will allocate large amounts of memory for big architectures.

      +

      Returns

      +
      +
      block_diag : torch.Tensor
      +
       
      +
      +
      +
      +
      +
      +class KronDecomposed +(eigenvectors, eigenvalues, deltas=None, damping=False) +
      +
      +

      Decomposed Kronecker factored approximate curvature representation +for a corresponding neural network. +Each matrix in Kron is decomposed to obtain KronDecomposed. +Front-loading decomposition allows cheap repeated computation +of inverses and log determinants. +In contrast to Kron, we can add scalar or layerwise scalars but +we cannot add other Kron or KronDecomposed anymore.

      +

      Parameters

      +
      +
      eigenvectors : list[Tuple[torch.Tensor]]
      +
      eigenvectors corresponding to matrices in a corresponding Kron
      +
      eigenvalues : list[Tuple[torch.Tensor]]
      +
      eigenvalues corresponding to matrices in a corresponding Kron
      +
      deltas : torch.Tensor
      +
      addend for each group of Kronecker factors representing, for example, +a prior precision
      +
      damping : bool, default=False
        use damping approximation, mixing prior and Kron partially multiplicatively
      +
      +

      Methods

      +
      +
      +def detach(self) +
      +
      +
      +
      +
      +def logdet(self) ‑> torch.Tensor +
      +
      +

      Compute log determinant of the Kronecker factors and sums them up. +This corresponds to the log determinant of the entire Hessian approximation. +In contrast to Kron.logdet(), additive deltas corresponding to prior +precisions are added.

      +

      Returns

      +
      +
      logdet : torch.Tensor
      +
       
      +
      +
      +
      +def inv_square_form(self, W: torch.Tensor) ‑> torch.Tensor +
      +
      +
      +
      +
      +def bmm(self, W: torch.Tensor, exponent: float = -1) ‑> torch.Tensor +
      +
      +

      Batched matrix multiplication with the decomposed Kronecker factors. +This is useful for computing the predictive or a regularization loss. +Compared to Kron.bmm(), a prior can be added here in form of deltas +and the exponent can be other than just 1. +Computes H^{exponent} W.

      +

      Parameters

      +
      +
      W : torch.Tensor
      +
      matrix (batch, classes, params)
      +
      exponent : float, default=-1
      +
       
      +
      +

      Returns

      +
      +
      SW : torch.Tensor
      +
      result (batch, classes, params)
      +
      +
      +
      +def to_matrix(self, exponent: float = 1) ‑> torch.Tensor +
      +
      +

      Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes as it will allocate large amounts of memory for big architectures.

      +

      Returns

      +
      +
      block_diag : torch.Tensor
      +
       
      +
      +
      +
      +
      +
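To illustrate how Kron and KronDecomposed relate, here is a sketch with tiny, arbitrary factors (the deltas added below stand in for a per-group prior precision; values are illustrative only):

```python
import torch
from laplace.utils import Kron

# Two parameter groups: one with a Kronecker factor pair (Q, H), one with a single block.
Q = 2.0 * torch.eye(3)
H = 0.5 * torch.eye(4)
B = torch.eye(5)
K = Kron([[Q, H], [B]])

dense = K.to_matrix()                      # dense block-diagonal matrix, for testing only
K_dec = K.decompose()                      # KronDecomposed via eigendecomposition
K_post = K_dec + torch.tensor([1.0, 1.0])  # add a delta (e.g. prior precision) per group
log_det = K_post.logdet()
```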
      +class SubnetMask +(model) +
      +
      +

      Baseclass for all subnetwork masks in this library (for subnetwork Laplace).

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      +

      Subclasses

      + +

      Instance variables

      +
      +
      var indices
      +
      +
      +
      +
      var n_params_subnet
      +
      +
      +
      +
      +

      Methods

      +
      +
      +def convert_subnet_mask_to_indices(self, subnet_mask) +
      +
      +

      Converts a subnetwork mask into subnetwork indices.

      +

      Parameters

      +
      +
      subnet_mask : torch.Tensor
      +
      a binary vector of size (n_params) where 1s locate the subnetwork parameters +within the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
      +
      +

      Returns

      +
      +
      subnet_mask_indices : torch.LongTensor
      +
      a vector of indices of the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters())) +that define the subnetwork
      +
      +
      +
      +def select(self, train_loader=None) +
      +
      +

      Select the subnetwork mask.

      +

      Parameters

      +
      +
      train_loader : torch.data.utils.DataLoader, default=None
      +
      each iterate is a training batch (X, y); +train_loader.dataset needs to be set to access N, size of the data set
      +
      +

      Returns

      +
      +
      subnet_mask_indices : torch.LongTensor
      +
      a vector of indices of the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters())) +that define the subnetwork
      +
      +
      +
      +def get_subnet_mask(self, train_loader) +
      +
      +

      Get the subnetwork mask.

      +

      Parameters

      +
      +
      train_loader : torch.data.utils.DataLoader
      +
      each iterate is a training batch (X, y); +train_loader.dataset needs to be set to access N, size of the data set
      +
      +

      Returns

      +
      +
      subnet_mask : torch.Tensor
      +
      a binary vector of size (n_params) where 1s locate the subnetwork parameters +within the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
      +
      +
      +
      +
      +
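Custom selection strategies can be added as new SubnetMask subclasses by implementing get_subnet_mask. The bias-only strategy below is a hypothetical sketch, not part of the library:

```python
import torch
from laplace.utils import SubnetMask

class BiasOnlySubnetMask(SubnetMask):
    """Hypothetical mask that puts every bias parameter into the subnetwork."""

    def get_subnet_mask(self, train_loader):
        mask = []
        for name, param in self.model.named_parameters():
            value = 1.0 if name.endswith('bias') else 0.0
            mask.append(torch.full_like(param.detach().flatten(), value))
        # Binary vector of length n_params, 1s marking subnetwork parameters.
        return torch.cat(mask)
```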
      +class RandomSubnetMask +(model, n_params_subnet) +
      +
      +

      Subnetwork mask of parameters sampled uniformly at random.

      +

      Ancestors

      +
        +
      • laplace.utils.subnetmask.ScoreBasedSubnetMask
      • +
      • SubnetMask
      • +
      +

      Methods

      +
      +
      +def compute_param_scores(self, train_loader) +
      +
      +
      +
      +
      +

      Inherited members

      + +
      +
      +class LargestMagnitudeSubnetMask +(model, n_params_subnet) +
      +
      +

      Subnetwork mask identifying the parameters with the largest magnitude.

      +

      Ancestors

      +
        +
      • laplace.utils.subnetmask.ScoreBasedSubnetMask
      • +
      • SubnetMask
      • +
      +

      Methods

      +
      +
      +def compute_param_scores(self, train_loader) +
      +
      +
      +
      +
      +

      Inherited members

      + +
      +
      +class LargestVarianceDiagLaplaceSubnetMask +(model, n_params_subnet, diag_laplace_model) +
      +
      +

      Subnetwork mask identifying the parameters with the largest marginal variances +(estimated using a diagonal Laplace approximation over all model parameters).

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      n_params_subnet : int
      +
      number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
      +
      diag_laplace_model : DiagLaplace
      +
      diagonal Laplace model to use for variance estimation
      +
      +

      Ancestors

      +
        +
      • laplace.utils.subnetmask.ScoreBasedSubnetMask
      • +
      • SubnetMask
      • +
      +

      Methods

      +
      +
      +def compute_param_scores(self, train_loader) +
      +
      +
      +
      +
      +

      Inherited members

      + +
      +
      +class LargestVarianceSWAGSubnetMask +(model, n_params_subnet, likelihood='classification', swag_n_snapshots=40, swag_snapshot_freq=1, swag_lr=0.01) +
      +
      +

      Subnetwork mask identifying the parameters with the largest marginal variances +(estimated using diagonal SWAG over all model parameters).

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      n_params_subnet : int
      +
      number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
      +
      likelihood : str
      +
      'classification' or 'regression'
      +
      swag_n_snapshots : int
      +
      number of model snapshots to collect for SWAG
      +
      swag_snapshot_freq : int
      +
      SWAG snapshot collection frequency (in epochs)
      +
      swag_lr : float
      +
      learning rate for SWAG snapshot collection
      +
      +

      Ancestors

      +
        +
      • laplace.utils.subnetmask.ScoreBasedSubnetMask
      • +
      • SubnetMask
      • +
      +

      Methods

      +
      +
      +def compute_param_scores(self, train_loader) +
      +
      +
      +
      +
      +

      Inherited members

      + +
      +
      +class ParamNameSubnetMask +(model, parameter_names) +
      +
      +

      Subnetwork mask corresponding to the specified parameters of the neural network.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      parameter_names : List[str]
      +
      list of names of the parameters (as in model.named_parameters()) +that define the subnetwork
      +
      +

      Ancestors

      + +

      Methods

      +
      +
      +def get_subnet_mask(self, train_loader) +
      +
      +

      Get the subnetwork mask identifying the specified parameters.

      +
      +
      +

      Inherited members

      + +
      +
      +class ModuleNameSubnetMask +(model, module_names) +
      +
      +

      Subnetwork mask corresponding to the specified modules of the neural network.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      module_names : List[str]
        list of names of the modules (as in model.named_modules()) that define the subnetwork; the modules cannot have children, i.e. need to be leaf modules
      +
      +

      Ancestors

      + +

      Subclasses

      + +

      Methods

      +
      +
      +def get_subnet_mask(self, train_loader) +
      +
      +

      Get the subnetwork mask identifying the specified modules.

      +
      +
      +

      Inherited members

      + +
      +
      +class LastLayerSubnetMask +(model, last_layer_name=None) +
      +
      +

      Subnetwork mask corresponding to the last layer of the neural network.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      last_layer_name : str, default=None
      +
      name of the model's last layer, if None it will be determined automatically
      +
      +

      Ancestors

      + +

      Methods

      +
      +
      +def get_subnet_mask(self, train_loader) +
      +
      +

      Get the subnetwork mask identifying the last layer.

      +
      +
      +

      Inherited members

      + +
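Putting the pieces together (a sketch; model and train_loader are placeholders), a selection strategy produces the indices that subnetwork Laplace consumes:

```python
from laplace import Laplace
from laplace.utils import LargestMagnitudeSubnetMask

# Select the 512 weights with the largest magnitude as the subnetwork.
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=512)
subnetwork_indices = subnetwork_mask.select()

# Subnetwork Laplace only supports a full Hessian structure.
la = Laplace(model, 'classification', subset_of_weights='subnetwork',
             hessian_structure='full', subnetwork_indices=subnetwork_indices)
la.fit(train_loader)
```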
      +
      +
      +
      + +
      + + + \ No newline at end of file diff --git a/docs/matrix.html b/docs/utils/matrix.html similarity index 75% rename from docs/matrix.html rename to docs/utils/matrix.html index 6323f7cf..caa3e688 100644 --- a/docs/matrix.html +++ b/docs/utils/matrix.html @@ -4,7 +4,7 @@ -laplace.matrix API documentation +laplace.utils.matrix API documentation @@ -20,7 +20,7 @@
      -

      Module laplace.matrix

      +

      Module laplace.utils.matrix

      @@ -33,7 +33,7 @@

      Module laplace.matrix

      Classes

      -
      +
      class Kron (kfacs)
      @@ -51,7 +51,7 @@

      Parameters

      Static methods

      -
      +
      def init_from_model(model, device)
      @@ -65,18 +65,18 @@

      Parameters

      Returns

      -
      kron : Kron
      +
      kron : Kron
       

    Methods

    -
    +
    def decompose(self, damping=False)
    -

    Eigendecompose Kronecker factors and turn into KronDecomposed. +

    Eigendecompose Kronecker factors and turn into KronDecomposed. Parameters


    @@ -85,11 +85,11 @@

    Methods

    Returns

    -
    kron_decomposed : KronDecomposed
    +
    kron_decomposed : KronDecomposed
     
    -
    +
    def bmm(self, W: torch.Tensor, exponent: float = 1) ‑> torch.Tensor
    @@ -102,7 +102,7 @@

    Parameters

    W : torch.Tensor
    matrix (batch, classes, params)
    exponent : float, default=1
    -
    only can be 1 for Kron, requires KronDecomposed for other +
    only can be 1 for Kron, requires KronDecomposed for other exponent values of the Kronecker factors.

    Returns

    @@ -111,7 +111,7 @@

    Returns

    result (batch, classes, params)
    -
    +
    def logdet(self) ‑> torch.Tensor
    @@ -123,7 +123,7 @@

    Returns

     
    -
    +
    def diag(self) ‑> torch.Tensor
    @@ -134,7 +134,7 @@

    Returns

     
    -
    +
    def to_matrix(self) ‑> torch.Tensor
    @@ -149,24 +149,24 @@

    Returns

    -
    +
    class KronDecomposed (eigenvectors, eigenvalues, deltas=None, damping=False)

    Decomposed Kronecker factored approximate curvature representation for a corresponding neural network. -Each matrix in Kron is decomposed to obtain KronDecomposed. +Each matrix in Kron is decomposed to obtain KronDecomposed. Front-loading decomposition allows cheap repeated computation of inverses and log determinants. -In contrast to Kron, we can add scalar or layerwise scalars but -we cannot add other Kron or KronDecomposed anymore.

    +In contrast to Kron, we can add scalar or layerwise scalars but +we cannot add other Kron or KronDecomposed anymore.

    Parameters

    eigenvectors : list[Tuple[torch.Tensor]]
    -
    eigenvectors corresponding to matrices in a corresponding Kron
    +
    eigenvectors corresponding to matrices in a corresponding Kron
    eigenvalues : list[Tuple[torch.Tensor]]
    -
    eigenvalues corresponding to matrices in a corresponding Kron
    +
    eigenvalues corresponding to matrices in a corresponding Kron
    deltas : torch.Tensor
    addend for each group of Kronecker factors representing, for example, a prior precision
    @@ -175,19 +175,19 @@

    Parameters

    Methods

    -
    +
    def detach(self)
    -
    +
    def logdet(self) ‑> torch.Tensor

    Compute log determinant of the Kronecker factors and sums them up. This corresponds to the log determinant of the entire Hessian approximation. -In contrast to Kron.logdet(), additive deltas corresponding to prior +In contrast to Kron.logdet(), additive deltas corresponding to prior precisions are added.

    Returns

    @@ -195,19 +195,19 @@

    Returns

     
    -
    +
    def inv_square_form(self, W: torch.Tensor) ‑> torch.Tensor
    -
    +
    def bmm(self, W: torch.Tensor, exponent: float = -1) ‑> torch.Tensor

    Batched matrix multiplication with the decomposed Kronecker factors. This is useful for computing the predictive or a regularization loss. -Compared to Kron.bmm(), a prior can be added here in form of deltas +Compared to Kron.bmm(), a prior can be added here in form of deltas and the exponent can be other than just 1. Computes H^{exponent} W.

    Parameters

    @@ -223,7 +223,7 @@

    Returns

    result (batch, classes, params)
    -
    +
    def to_matrix(self, exponent: float = 1) ‑> torch.Tensor
    @@ -249,30 +249,30 @@

    Index

    • Super-module

    • Classes

      diff --git a/docs/utils/subnetmask.html b/docs/utils/subnetmask.html new file mode 100644 index 00000000..15781ff4 --- /dev/null +++ b/docs/utils/subnetmask.html @@ -0,0 +1,466 @@ + + + + + + +laplace.utils.subnetmask API documentation + + + + + + + + + + + + +
      +
      +
      +

      Module laplace.utils.subnetmask

      +
      +
      +
      +
      +
      +
      +
      +
      +
      +
      +

      Classes

      +
      +
      +class SubnetMask +(model) +
      +
      +

      Baseclass for all subnetwork masks in this library (for subnetwork Laplace).

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      +

      Subclasses

      + +

      Instance variables

      +
      +
      var indices
      +
      +
      +
      +
      var n_params_subnet
      +
      +
      +
      +
      +

      Methods

      +
      +
      +def convert_subnet_mask_to_indices(self, subnet_mask) +
      +
      +

      Converts a subnetwork mask into subnetwork indices.

      +

      Parameters

      +
      +
      subnet_mask : torch.Tensor
      +
      a binary vector of size (n_params) where 1s locate the subnetwork parameters +within the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
      +
      +

      Returns

      +
      +
      subnet_mask_indices : torch.LongTensor
      +
      a vector of indices of the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters())) +that define the subnetwork
      +
      +
      +
      +def select(self, train_loader=None) +
      +
      +

      Select the subnetwork mask.

      +

      Parameters

      +
      +
      train_loader : torch.data.utils.DataLoader, default=None
      +
      each iterate is a training batch (X, y); +train_loader.dataset needs to be set to access N, size of the data set
      +
      +

      Returns

      +
      +
      subnet_mask_indices : torch.LongTensor
      +
      a vector of indices of the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters())) +that define the subnetwork
      +
      +
      +
      +def get_subnet_mask(self, train_loader) +
      +
      +

      Get the subnetwork mask.

      +

      Parameters

      +
      +
      train_loader : torch.data.utils.DataLoader
      +
      each iterate is a training batch (X, y); +train_loader.dataset needs to be set to access N, size of the data set
      +
      +

      Returns

      +
      +
      subnet_mask : torch.Tensor
      +
      a binary vector of size (n_params) where 1s locate the subnetwork parameters +within the vectorized model parameters +(i.e. torch.nn.utils.parameters_to_vector(model.parameters()))
      +
      +
      +
      +
      +
      +class RandomSubnetMask +(model, n_params_subnet) +
      +
      +

      Subnetwork mask of parameters sampled uniformly at random.

      +

      Ancestors

      +
        +
      • laplace.utils.subnetmask.ScoreBasedSubnetMask
      • +
      • SubnetMask
      • +
      +

      Methods

      +
      +
      +def compute_param_scores(self, train_loader) +
      +
      +
      +
      +
      +

      Inherited members

      + +
      +
      +class LargestMagnitudeSubnetMask +(model, n_params_subnet) +
      +
      +

      Subnetwork mask identifying the parameters with the largest magnitude.

      +

      Ancestors

      +
        +
      • laplace.utils.subnetmask.ScoreBasedSubnetMask
      • +
      • SubnetMask
      • +
      +

      Methods

      +
      +
      +def compute_param_scores(self, train_loader) +
      +
      +
      +
      +
      +

      Inherited members

      + +
      +
      +class LargestVarianceDiagLaplaceSubnetMask +(model, n_params_subnet, diag_laplace_model) +
      +
      +

      Subnetwork mask identifying the parameters with the largest marginal variances +(estimated using a diagonal Laplace approximation over all model parameters).

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      n_params_subnet : int
      +
      number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
      +
      diag_laplace_model : DiagLaplace
      +
      diagonal Laplace model to use for variance estimation
      +
      +

      Ancestors

      +
        +
      • laplace.utils.subnetmask.ScoreBasedSubnetMask
      • +
      • SubnetMask
      • +
      +

      Methods

      +
      +
      +def compute_param_scores(self, train_loader) +
      +
      +
      +
      +
      +

      Inherited members

      + +
      +
      +class LargestVarianceSWAGSubnetMask +(model, n_params_subnet, likelihood='classification', swag_n_snapshots=40, swag_snapshot_freq=1, swag_lr=0.01) +
      +
      +

      Subnetwork mask identifying the parameters with the largest marginal variances +(estimated using diagonal SWAG over all model parameters).

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      n_params_subnet : int
      +
      number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
      +
      likelihood : str
      +
      'classification' or 'regression'
      +
      swag_n_snapshots : int
      +
      number of model snapshots to collect for SWAG
      +
      swag_snapshot_freq : int
      +
      SWAG snapshot collection frequency (in epochs)
      +
      swag_lr : float
      +
      learning rate for SWAG snapshot collection
      +
      +

      Ancestors

      +
        +
      • laplace.utils.subnetmask.ScoreBasedSubnetMask
      • +
      • SubnetMask
      • +
      +

      Methods

      +
      +
      +def compute_param_scores(self, train_loader) +
      +
      +
      +
      +
      +

      Inherited members

      + +
      +
      +class ParamNameSubnetMask +(model, parameter_names) +
      +
      +

      Subnetwork mask corresponding to the specified parameters of the neural network.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      parameter_names : List[str]
      +
      list of names of the parameters (as in model.named_parameters()) +that define the subnetwork
      +
      +

      Ancestors

      + +

      Methods

      +
      +
      +def get_subnet_mask(self, train_loader) +
      +
      +

      Get the subnetwork mask identifying the specified parameters.

      +
      +
      +

      Inherited members

      + +
      +
      +class ModuleNameSubnetMask +(model, module_names) +
      +
      +

      Subnetwork mask corresponding to the specified modules of the neural network.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      module_names : List[str]
        list of names of the modules (as in model.named_modules()) that define the subnetwork; the modules cannot have children, i.e. need to be leaf modules
      +
      +

      Ancestors

      + +

      Subclasses

      + +

      Methods

      +
      +
      +def get_subnet_mask(self, train_loader) +
      +
      +

      Get the subnetwork mask identifying the specified modules.

      +
      +
      +

      Inherited members

      + +
      +
      +class LastLayerSubnetMask +(model, last_layer_name=None) +
      +
      +

      Subnetwork mask corresponding to the last layer of the neural network.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      last_layer_name : str, default=None
      +
      name of the model's last layer, if None it will be determined automatically
      +
      +

      Ancestors

      + +

      Methods

      +
      +
      +def get_subnet_mask(self, train_loader) +
      +
      +

      Get the subnetwork mask identifying the last layer.

      +
      +
      +

      Inherited members

      + +
      +
      +
      +
      + +
      + + + \ No newline at end of file diff --git a/docs/utils/swag.html b/docs/utils/swag.html new file mode 100644 index 00000000..9f1e1843 --- /dev/null +++ b/docs/utils/swag.html @@ -0,0 +1,102 @@ + + + + + + +laplace.utils.swag API documentation + + + + + + + + + + + + +
      +
      +
      +

      Module laplace.utils.swag

      +
      +
      +
      +
      +
      +
      +
      +
      +

      Functions

      +
      +
      +def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=0.0003, min_var=1e-30) +
      +
      +

      Fit diagonal SWAG [1], which estimates marginal variances of model parameters by +computing the first and second moment of SGD iterates with a large learning rate.

      +

      Implementation partly adapted from: +- https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py +- https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py

      +

      References

      +

      [1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. +A Simple Baseline for Bayesian Uncertainty in Deep Learning. +NeurIPS 2019.

      +

      Parameters

      +
      +
      model : torch.nn.Module
      +
       
      +
      train_loader : torch.data.utils.DataLoader
      +
      training data loader to use for snapshot collection
      +
      criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
      +
      loss function to use for snapshot collection
      +
      n_snapshots_total : int
      +
      total number of model snapshots to collect
      +
      snapshot_freq : int
      +
      snapshot collection frequency (in epochs)
      +
      lr : float
      +
      SGD learning rate for collecting snapshots
      +
      momentum : float
      +
      SGD momentum
      +
      weight_decay : float
      +
      SGD weight decay
      +
      min_var : float
      +
      minimum parameter variance to clamp to (for numerical stability)
      +
      +

      Returns

      +
      +
      param_variances : torch.Tensor
      +
      vector of marginal variances for each model parameter
      +
      +
      +
      +
      +
      +
      +
      + +
      + + + \ No newline at end of file diff --git a/docs/utils.html b/docs/utils/utils.html similarity index 84% rename from docs/utils.html rename to docs/utils/utils.html index 633a3565..aa721218 100644 --- a/docs/utils.html +++ b/docs/utils/utils.html @@ -4,7 +4,7 @@ -laplace.utils API documentation +laplace.utils.utils API documentation @@ -20,7 +20,7 @@
      -

      Module laplace.utils

      +

      Module laplace.utils.utils

      @@ -31,19 +31,19 @@

      Module laplace.utils

      Functions

      -
      +
      def get_nll(out_dist, targets)
      -
      +
      def validate(laplace, val_loader, pred_type='glm', link_approx='probit', n_samples=100)
      -
      +
      def parameters_per_layer(model)
      @@ -59,7 +59,7 @@

      Returns

       
    -
    +
    def invsqrt_precision(M)
    @@ -75,7 +75,7 @@

    Returns

     
    -
    +
    def kron(t1, t2)
    @@ -93,7 +93,7 @@

    Returns

     
    -
    +
    def diagonal_add_scalar(X, value)
    @@ -111,7 +111,7 @@

    Returns

     
    -
    +
    def symeig(M)
    @@ -130,7 +130,7 @@

    Returns

    eigenvectors
    -
    +
    def block_diag(blocks)
    @@ -146,7 +146,7 @@

    Returns

     
    -
    +
    def expand_prior_precision(prior_prec, model)
    @@ -160,7 +160,7 @@

    Parameters

    Returns

    -
    expanded_prior_prec : torch.Tensor
    +
    expanded_prior_prec : torch.Tensor
    expanded prior precision has the same shape as model parameters
    @@ -177,20 +177,20 @@

    Index

    diff --git a/laplace/__init__.py b/laplace/__init__.py index fd71fb1b..429b9f0d 100644 --- a/laplace/__init__.py +++ b/laplace/__init__.py @@ -9,6 +9,7 @@ from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace +from laplace.subnetlaplace import SubnetLaplace from laplace.laplace import Laplace from laplace.marglik_training import marglik_training @@ -17,4 +18,5 @@ 'FullLaplace', 'KronLaplace', 'DiagLaplace', 'LowRankLaplace', # all-weights 'LLLaplace', # base-class last-layer 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace', # last-layer + 'SubnetLaplace', # subnetwork 'marglik_training'] # methods diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 55b66a54..09c05ca3 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -4,8 +4,7 @@ from torch.nn.utils import parameters_to_vector, vector_to_parameters from torch.distributions import MultivariateNormal, Dirichlet, Normal -from laplace.utils import parameters_per_layer, invsqrt_precision, get_nll, validate -from laplace.matrix import Kron +from laplace.utils import parameters_per_layer, invsqrt_precision, get_nll, validate, Kron from laplace.curvature import BackPackGGN, AsdlHessian @@ -594,7 +593,7 @@ def predictive_samples(self, x, pred_type='glm', n_samples=100): @torch.enable_grad() def _glm_predictive_distribution(self, X): - Js, f_mu = self.backend.jacobians(self.model, X) + Js, f_mu = self.backend.jacobians(X) f_var = self.functional_variance(Js) return f_mu.detach(), f_var.detach() @@ -754,7 +753,7 @@ class KronLaplace(ParametricLaplace): Mathematically, we have for each parameter group, e.g., torch.nn.Module, that \\P\\approx Q \\otimes H\\. See `BaseLaplace` for the full interface and see - `laplace.matrix.Kron` and `laplace.matrix.KronDecomposed` for the structure of + `laplace.utils.matrix.Kron` and `laplace.utils.matrix.KronDecomposed` for the structure of the Kronecker factors. `Kron` is used to aggregate factors by summing up and `KronDecomposed` is used to add the prior, a Hessian factor (e.g. temperature), and computing posterior covariances, marginal likelihood, etc. @@ -812,7 +811,7 @@ def posterior_precision(self): Returns ------- - precision : `laplace.matrix.KronDecomposed` + precision : `laplace.utils.matrix.KronDecomposed` """ self._check_H_init() return self.H * self._H_factor + self.prior_precision diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py index f9900dfd..a49b98c1 100644 --- a/laplace/curvature/asdl.py +++ b/laplace/curvature/asdl.py @@ -9,8 +9,7 @@ from asdfghjkl.gradient import batch_gradient from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface -from laplace.matrix import Kron -from laplace.utils import _is_batchnorm +from laplace.utils import Kron, _is_batchnorm EPS = 1e-6 @@ -19,14 +18,12 @@ class AsdlInterface(CurvatureInterface): """Interface for asdfghjkl backend. """ - @staticmethod - def jacobians(model, x): + def jacobians(self, x): """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\) using asdfghjkl's gradient per output dimension. Parameters ---------- - model : torch.nn.Module x : torch.Tensor input data `(batch, input_shape)` on compatible device with model. 
@@ -38,12 +35,15 @@ def jacobians(model, x): output function `(batch, outputs)` """ Js = list() - for i in range(model.output_size): + for i in range(self.model.output_size): def loss_fn(outputs, targets): return outputs[:, i].sum() - f = batch_gradient(model, loss_fn, x, None).detach() - Js.append(_get_batch_grad(model)) + f = batch_gradient(self.model, loss_fn, x, None).detach() + Jk = _get_batch_grad(self.model) + if self.subnetwork_indices is not None: + Jk = Jk[:, self.subnetwork_indices] + Js.append(Jk) Js = torch.stack(Js, dim=1) return Js, f @@ -65,6 +65,8 @@ def gradients(self, x, y): """ f = batch_gradient(self.model, self.lossfunc, x, y).detach() Gs = _get_batch_grad(self._model) + if self.subnetwork_indices is not None: + Gs = Gs[:, self.subnetwork_indices] loss = self.lossfunc(f, y) return Gs, loss @@ -163,10 +165,10 @@ def eig_lowrank(self, data_loader): class AsdlGGN(AsdlInterface, GGNInterface): """Implementation of the `GGNInterface` using asdfghjkl. """ - def __init__(self, model, likelihood, last_layer=False, stochastic=False): + def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False): if likelihood != 'classification': raise ValueError('This backend only supports classification currently.') - super().__init__(model, likelihood, last_layer) + super().__init__(model, likelihood, last_layer, subnetwork_indices) self.stochastic = stochastic @property diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py index 885ee2b9..8cffc154 100644 --- a/laplace/curvature/backpack.py +++ b/laplace/curvature/backpack.py @@ -5,25 +5,23 @@ from backpack.context import CTX from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface -from laplace.matrix import Kron +from laplace.utils import Kron class BackPackInterface(CurvatureInterface): """Interface for Backpack backend. """ - def __init__(self, model, likelihood, last_layer=False): - super().__init__(model, likelihood, last_layer) + def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): + super().__init__(model, likelihood, last_layer, subnetwork_indices) extend(self._model) extend(self.lossfunc) - @staticmethod - def jacobians(model, x): + def jacobians(self, x): """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\) using backpack's BatchGrad per output dimension. Parameters ---------- - model : torch.nn.Module x : torch.Tensor input data `(batch, input_shape)` on compatible device with model. @@ -34,7 +32,7 @@ def jacobians(model, x): f : torch.Tensor output function `(batch, outputs)` """ - model = extend(model) + model = extend(self.model) to_stack = [] for i in range(model.output_size): model.zero_grad() @@ -49,6 +47,8 @@ def jacobians(model, x): to_cat.append(param.grad_batch.detach().reshape(x.shape[0], -1)) delattr(param, 'grad_batch') Jk = torch.cat(to_cat, dim=1) + if self.subnetwork_indices is not None: + Jk = Jk[:, self.subnetwork_indices] to_stack.append(Jk) if i == 0: f = out.detach() @@ -83,14 +83,16 @@ def gradients(self, x, y): loss.backward() Gs = torch.cat([p.grad_batch.data.flatten(start_dim=1) for p in self._model.parameters()], dim=1) + if self.subnetwork_indices is not None: + Gs = Gs[:, self.subnetwork_indices] return Gs, loss class BackPackGGN(BackPackInterface, GGNInterface): """Implementation of the `GGNInterface` using Backpack. 
""" - def __init__(self, model, likelihood, last_layer=False, stochastic=False): - super().__init__(model, likelihood, last_layer) + def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False): + super().__init__(model, likelihood, last_layer, subnetwork_indices) self.stochastic = stochastic def _get_diag_ggn(self): diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py index 96043066..72b0a041 100644 --- a/laplace/curvature/curvature.py +++ b/laplace/curvature/curvature.py @@ -11,11 +11,14 @@ class CurvatureInterface: Parameters ---------- - model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor` + model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` torch model (neural network) likelihood : {'classification', 'regression'} last_layer : bool, default=False only consider curvature of last layer + subnetwork_indices : torch.Tensor, default=None + indices of the vectorized model parameters that define the subnetwork + to apply the Laplace approximation over Attributes ---------- @@ -24,11 +27,12 @@ class CurvatureInterface: conversion factor between torch losses and base likelihoods For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss. """ - def __init__(self, model, likelihood, last_layer=False): + def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): assert likelihood in ['regression', 'classification'] self.likelihood = likelihood self.model = model self.last_layer = last_layer + self.subnetwork_indices = subnetwork_indices if likelihood == 'regression': self.lossfunc = MSELoss(reduction='sum') self.factor = 0.5 @@ -40,13 +44,11 @@ def __init__(self, model, likelihood, last_layer=False): def _model(self): return self.model.last_layer if self.last_layer else self.model - @staticmethod - def jacobians(model, x): + def jacobians(self, x): """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\). Parameters ---------- - model : torch.nn.Module x : torch.Tensor input data `(batch, input_shape)` on compatible device with model. @@ -59,14 +61,12 @@ def jacobians(model, x): """ raise NotImplementedError - @staticmethod - def last_layer_jacobians(model, x): + def last_layer_jacobians(self, x): """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\) only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\). Parameters ---------- - model : laplace.feature_extractor.FeatureExtractor x : torch.Tensor Returns @@ -76,7 +76,7 @@ def last_layer_jacobians(model, x): f : torch.Tensor output function `(batch, outputs)` """ - f, phi = model.forward_with_features(x) + f, phi = self.model.forward_with_features(x) bsize = phi.shape[0] output_size = f.shape[-1] @@ -84,7 +84,7 @@ def last_layer_jacobians(model, x): identity = torch.eye(output_size, device=x.device).unsqueeze(0).tile(bsize, 1, 1) # Jacobians are batch x output x params Js = torch.einsum('kp,kij->kijp', phi, identity).reshape(bsize, output_size, -1) - if model.last_layer.bias is not None: + if self.model.last_layer.bias is not None: Js = torch.cat([Js, identity], dim=2) return Js, f.detach() @@ -143,7 +143,7 @@ def kron(self, x, y, **kwargs): Returns ------- loss : torch.Tensor - H : `laplace.matrix.Kron` + H : `laplace.utils.matrix.Kron` Kronecker factored Hessian approximation. 
""" raise NotImplementedError @@ -175,17 +175,20 @@ class GGNInterface(CurvatureInterface): Parameters ---------- - model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor` + model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` torch model (neural network) likelihood : {'classification', 'regression'} last_layer : bool, default=False only consider curvature of last layer + subnetwork_indices : torch.Tensor, default=None + indices of the vectorized model parameters that define the subnetwork + to apply the Laplace approximation over stochastic : bool, default=False Fisher if stochastic else GGN """ - def __init__(self, model, likelihood, last_layer=False, stochastic=False): + def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False): self.stochastic = stochastic - super().__init__(model, likelihood, last_layer) + super().__init__(model, likelihood, last_layer, subnetwork_indices) def _get_full_ggn(self, Js, f, y): """Compute full GGN from Jacobians. @@ -237,9 +240,9 @@ def full(self, x, y, **kwargs): raise ValueError('Stochastic approximation not implemented for full GGN.') if self.last_layer: - Js, f = self.last_layer_jacobians(self.model, x) + Js, f = self.last_layer_jacobians(x) else: - Js, f = self.jacobians(self.model, x) + Js, f = self.jacobians(x) loss, H_ggn = self._get_full_ggn(Js, f, y) return loss, H_ggn @@ -251,11 +254,14 @@ class EFInterface(CurvatureInterface): Parameters ---------- - model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor` + model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` torch model (neural network) likelihood : {'classification', 'regression'} last_layer : bool, default=False only consider curvature of last layer + subnetwork_indices : torch.Tensor, default=None + indices of the vectorized model parameters that define the subnetwork + to apply the Laplace approximation over Attributes ---------- diff --git a/laplace/laplace.py b/laplace/laplace.py index a006f170..9631f7c9 100644 --- a/laplace/laplace.py +++ b/laplace/laplace.py @@ -10,7 +10,7 @@ def Laplace(model, likelihood, subset_of_weights='last_layer', hessian_structure ---------- model : torch.nn.Module likelihood : {'classification', 'regression'} - subset_of_weights : {'last_layer', 'all'}, default='last_layer' + subset_of_weights : {'last_layer', 'subnetwork', 'all'}, default='last_layer' subset of weights to consider for inference hessian_structure : {'diag', 'kron', 'full', 'lowrank'}, default='kron' structure of the Hessian approximation @@ -20,6 +20,9 @@ def Laplace(model, likelihood, subset_of_weights='last_layer', hessian_structure laplace : ParametricLaplace chosen subclass of ParametricLaplace instantiated with additional arguments """ + if subset_of_weights == 'subnetwork' and hessian_structure != 'full': + raise ValueError('Subnetwork Laplace requires using a full Hessian approximation!') + laplace_map = {subclass._key: subclass for subclass in _all_subclasses(ParametricLaplace) if hasattr(subclass, '_key')} laplace_class = laplace_map[(subset_of_weights, hessian_structure)] diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py index b232c262..73c552df 100644 --- a/laplace/lllaplace.py +++ b/laplace/lllaplace.py @@ -3,9 +3,7 @@ from torch.nn.utils import parameters_to_vector, vector_to_parameters from laplace.baselaplace import ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace -from laplace.feature_extractor import FeatureExtractor - -from laplace.matrix 
import Kron +from laplace.utils import FeatureExtractor, Kron from laplace.curvature import BackPackGGN @@ -36,7 +34,7 @@ class LLLaplace(ParametricLaplace): Parameters ---------- - model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor` + model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` likelihood : {'classification', 'regression'} determines the log likelihood Hessian approximation sigma_noise : torch.Tensor or float, default=1 @@ -117,7 +115,7 @@ def fit(self, train_loader, override=True): self.mean = parameters_to_vector(self.model.last_layer.parameters()).detach() def _glm_predictive_distribution(self, X): - Js, f_mu = self.backend.last_layer_jacobians(self.model, X) + Js, f_mu = self.backend.last_layer_jacobians(X) f_var = self.functional_variance(Js) return f_mu.detach(), f_var.detach() @@ -168,7 +166,7 @@ class KronLLLaplace(LLLaplace, KronLaplace): Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that \\P\\approx Q \\otimes H\\. See `KronLaplace`, `LLLaplace`, and `BaseLaplace` for the full interface and see - `laplace.matrix.Kron` and `laplace.matrix.KronDecomposed` for the structure of + `laplace.utils.matrix.Kron` and `laplace.utils.matrix.KronDecomposed` for the structure of the Kronecker factors. `Kron` is used to aggregate factors by summing up and `KronDecomposed` is used to add the prior, a Hessian factor (e.g. temperature), and computing posterior covariances, marginal likelihood, etc. diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py new file mode 100644 index 00000000..86178ba6 --- /dev/null +++ b/laplace/subnetlaplace.py @@ -0,0 +1,130 @@ +import torch +from torch.distributions import MultivariateNormal + +from laplace.baselaplace import FullLaplace +from laplace.curvature import BackPackGGN + + +__all__ = ['SubnetLaplace'] + + +class SubnetLaplace(FullLaplace): + """Class for subnetwork Laplace, which computes the Laplace approximation over + just a subset of the model parameters (i.e. a subnetwork within the neural network), + as proposed in [1]. Subnetwork Laplace only supports a full Hessian approximation; other + approximations could be used in theory, but would not make as much sense conceptually. + + A Laplace approximation is represented by a MAP which is given by the + `model` parameter and a posterior precision or covariance specifying + a Gaussian distribution \\(\\mathcal{N}(\\theta_{MAP}, P^{-1})\\). + Here, only a subset of the model parameters (i.e. a subnetwork of the + neural network) are treated probabilistically. + The goal of this class is to compute the posterior precision \\(P\\) + which sums as + \\[ + P = \\sum_{n=1}^N \\nabla^2_\\theta \\log p(\\mathcal{D}_n \\mid \\theta) + \\vert_{\\theta_{MAP}} + \\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}}. + \\] + The prior is assumed to be Gaussian and therefore we have a simple form for + \\(\\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}} = P_0 \\). + In particular, we assume a scalar or diagonal prior precision so that in + all cases \\(P_0 = \\textrm{diag}(p_0)\\) and the structure of \\(p_0\\) can be varied. + + The subnetwork Laplace approximation only supports a full, i.e., dense, log likelihood + Hessian approximation and hence posterior precision. Based on the chosen `backend` + parameter, the full approximation can be, for example, a generalized Gauss-Newton + matrix. Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\). 
+ See `FullLaplace` and `BaseLaplace` for the full interface. + + References + ---------- + [1] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. + [*Bayesian Deep Learning via Subnetwork Inference*](https://arxiv.org/abs/2010.14689). + ICML 2021. + + Parameters + ---------- + model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` + likelihood : {'classification', 'regression'} + determines the log likelihood Hessian approximation + subnetwork_indices : torch.LongTensor + indices of the vectorized model parameters + (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`) + that define the subnetwork to apply the Laplace approximation over + sigma_noise : torch.Tensor or float, default=1 + observation noise for the regression setting; must be 1 for classification + prior_precision : torch.Tensor or float, default=1 + prior precision of a Gaussian prior (= weight decay); + can be scalar, per-layer, or diagonal in the most general case + prior_mean : torch.Tensor or float, default=0 + prior mean of a Gaussian prior, useful for continual learning + temperature : float, default=1 + temperature of the likelihood; lower temperature leads to more + concentrated posterior and vice versa. + backend : subclasses of `laplace.curvature.CurvatureInterface` + backend for access to curvature/Hessian approximations + backend_kwargs : dict, default=None + arguments passed to the backend on initialization, for example to + set the number of MC samples for stochastic approximations. + """ + # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) + _key = ('subnetwork', 'full') + + def __init__(self, model, likelihood, subnetwork_indices, sigma_noise=1., prior_precision=1., + prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None): + self.H = None + super().__init__(model, likelihood, sigma_noise=sigma_noise, + prior_precision=prior_precision, prior_mean=prior_mean, + temperature=temperature, backend=backend, backend_kwargs=backend_kwargs) + # check validity of subnetwork indices and pass them to backend + self._check_subnetwork_indices(subnetwork_indices) + self.backend.subnetwork_indices = subnetwork_indices + self.n_params_subnet = len(subnetwork_indices) + self._init_H() + + def _init_H(self): + self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device) + + def _check_subnetwork_indices(self, subnetwork_indices): + """Check that subnetwork indices are valid indices of the vectorized model parameters + (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`). + """ + if subnetwork_indices is None: + raise ValueError('Subnetwork indices cannot be None.') + elif not (isinstance(subnetwork_indices, torch.LongTensor) and + subnetwork_indices.numel() > 0 and len(subnetwork_indices.shape) == 1): + raise ValueError('Subnetwork indices must be non-empty 1-dimensional torch.LongTensor.') + elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and + len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0): + raise ValueError(f'Subnetwork indices must lie between 0 and n_params={self.n_params}.') + elif not (len(subnetwork_indices.unique()) == len(subnetwork_indices)): + raise ValueError('Subnetwork indices must not contain duplicate entries.') + + @property + def prior_precision_diag(self): + """Obtain the diagonal prior precision \\(p_0\\) constructed from either + a scalar or diagonal prior precision. 
+ + Returns + ------- + prior_precision_diag : torch.Tensor + """ + if len(self.prior_precision) == 1: # scalar + return self.prior_precision * torch.ones(self.n_params_subnet, device=self._device) + + elif len(self.prior_precision) == self.n_params_subnet: # diagonal + return self.prior_precision + + else: + raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.') + + def sample(self, n_samples=100): + # sample parameters just of the subnetwork + subnet_mean = self.mean[self.backend.subnetwork_indices] + dist = MultivariateNormal(loc=subnet_mean, scale_tril=self.posterior_scale) + subnet_samples = dist.sample((n_samples,)) + + # set all other parameters to their MAP estimates + full_samples = self.mean.repeat(n_samples, 1) + full_samples[:, self.backend.subnetwork_indices] = subnet_samples + return full_samples diff --git a/laplace/utils/__init__.py b/laplace/utils/__init__.py new file mode 100644 index 00000000..10f559e0 --- /dev/null +++ b/laplace/utils/__init__.py @@ -0,0 +1,14 @@ +from laplace.utils.utils import get_nll, validate, parameters_per_layer, invsqrt_precision, _is_batchnorm, _is_valid_scalar, kron, diagonal_add_scalar, symeig, block_diag, expand_prior_precision +from laplace.utils.feature_extractor import FeatureExtractor +from laplace.utils.matrix import Kron, KronDecomposed +from laplace.utils.swag import fit_diagonal_swag_var +from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask + + +__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron', + 'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision', + 'FeatureExtractor', + 'Kron', 'KronDecomposed', + 'fit_diagonal_swag_var', + 'SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask', + 'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask'] diff --git a/laplace/feature_extractor.py b/laplace/utils/feature_extractor.py similarity index 100% rename from laplace/feature_extractor.py rename to laplace/utils/feature_extractor.py diff --git a/laplace/matrix.py b/laplace/utils/matrix.py similarity index 99% rename from laplace/matrix.py rename to laplace/utils/matrix.py index 61c07ab5..14a84bfe 100644 --- a/laplace/matrix.py +++ b/laplace/utils/matrix.py @@ -6,6 +6,9 @@ from laplace.utils import _is_valid_scalar, symeig, kron, block_diag +__all__ = ['Kron', 'KronDecomposed'] + + class Kron: """Kronecker factored approximate curvature representation for a corresponding neural network. diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py new file mode 100644 index 00000000..00d73ff4 --- /dev/null +++ b/laplace/utils/subnetmask.py @@ -0,0 +1,359 @@ +from copy import deepcopy + +import torch +from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn.utils import parameters_to_vector + +from laplace.utils import FeatureExtractor, fit_diagonal_swag_var + + +__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', + 'LargestVarianceDiagLaplaceSubnetMask', 'LargestVarianceSWAGSubnetMask', + 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask'] + + +class SubnetMask: + """Baseclass for all subnetwork masks in this library (for subnetwork Laplace). 
+ + Parameters + ---------- + model : torch.nn.Module + """ + def __init__(self, model): + self.model = model + self.parameter_vector = parameters_to_vector(self.model.parameters()).detach() + self._n_params = len(self.parameter_vector) + self._device = next(self.model.parameters()).device + self._indices = None + self._n_params_subnet = None + + def _check_select(self): + if self._indices is None: + raise AttributeError('Subnetwork mask not selected. Run select() first.') + + @property + def indices(self): + self._check_select() + return self._indices + + @property + def n_params_subnet(self): + if self._n_params_subnet is None: + self._check_select() + self._n_params_subnet = len(self._indices) + return self._n_params_subnet + + def convert_subnet_mask_to_indices(self, subnet_mask): + """Converts a subnetwork mask into subnetwork indices. + + Parameters + ---------- + subnet_mask : torch.Tensor + a binary vector of size (n_params) where 1s locate the subnetwork parameters + within the vectorized model parameters + (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`) + + Returns + ------- + subnet_mask_indices : torch.LongTensor + a vector of indices of the vectorized model parameters + (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`) + that define the subnetwork + """ + if not isinstance(subnet_mask, torch.Tensor): + raise ValueError('Subnetwork mask needs to be torch.Tensor!') + elif subnet_mask.dtype not in [torch.int64, torch.int32, torch.int16, torch.int8, + torch.uint8, torch.bool] or len(subnet_mask.shape) != 1: + raise ValueError( + 'Subnetwork mask needs to be 1-dimensional integral or boolean tensor!') + elif (len(subnet_mask) != self._n_params or len(subnet_mask[subnet_mask == 0]) + + len(subnet_mask[subnet_mask == 1]) != self._n_params): + raise ValueError('Subnetwork mask needs to be a binary vector of ' + 'size (n_params) where 1s locate the subnetwork ' + 'parameters within the vectorized model parameters ' + '(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!') + + subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0] + return subnet_mask_indices + + def select(self, train_loader=None): + """ Select the subnetwork mask. + + Parameters + ---------- + train_loader : torch.data.utils.DataLoader, default=None + each iterate is a training batch (X, y); + `train_loader.dataset` needs to be set to access \\(N\\), size of the data set + + Returns + ------- + subnet_mask_indices : torch.LongTensor + a vector of indices of the vectorized model parameters + (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`) + that define the subnetwork + """ + if self._indices is not None: + raise ValueError('Subnetwork mask already selected.') + + subnet_mask = self.get_subnet_mask(train_loader) + self._indices = self.convert_subnet_mask_to_indices(subnet_mask) + return self._indices + + def get_subnet_mask(self, train_loader): + """ Get the subnetwork mask. + + Parameters + ---------- + train_loader : torch.data.utils.DataLoader + each iterate is a training batch (X, y); + `train_loader.dataset` needs to be set to access \\(N\\), size of the data set + + Returns + ------- + subnet_mask: torch.Tensor + a binary vector of size (n_params) where 1s locate the subnetwork parameters + within the vectorized model parameters + (i.e.
`torch.nn.utils.parameters_to_vector(model.parameters())`) + """ + raise NotImplementedError + + +class ScoreBasedSubnetMask(SubnetMask): + """Baseclass for subnetwork masks defined by selecting + the top-scoring parameters according to some criterion. + + Parameters + ---------- + model : torch.nn.Module + n_params_subnet : int + number of parameters in the subnetwork (i.e. number of top-scoring parameters to select) + """ + def __init__(self, model, n_params_subnet): + super().__init__(model) + + if n_params_subnet is None: + raise ValueError( + 'Need to pass number of subnetwork parameters when using subnetwork Laplace.') + if n_params_subnet > self._n_params: + raise ValueError( + f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).') + self._n_params_subnet = n_params_subnet + self._param_scores = None + + def compute_param_scores(self, train_loader): + raise NotImplementedError + + def _check_param_scores(self): + if self._param_scores.shape != self.parameter_vector.shape: + raise ValueError('Parameter scores need to be of same shape as parameter vector.') + + def get_subnet_mask(self, train_loader): + """ Get the subnetwork mask by (descendingly) ranking parameters based on their scores.""" + + if self._param_scores is None: + self._param_scores = self.compute_param_scores(train_loader) + self._check_param_scores() + + idx = torch.argsort(self._param_scores, descending=True)[:self._n_params_subnet] + idx = idx.sort()[0] + subnet_mask = torch.zeros_like(self.parameter_vector).bool() + subnet_mask[idx] = 1 + return subnet_mask + + +class RandomSubnetMask(ScoreBasedSubnetMask): + """Subnetwork mask of parameters sampled uniformly at random.""" + def compute_param_scores(self, train_loader): + return torch.rand_like(self.parameter_vector) + + +class LargestMagnitudeSubnetMask(ScoreBasedSubnetMask): + """Subnetwork mask identifying the parameters with the largest magnitude. """ + def compute_param_scores(self, train_loader): + return self.parameter_vector.abs() + + +class LargestVarianceDiagLaplaceSubnetMask(ScoreBasedSubnetMask): + """Subnetwork mask identifying the parameters with the largest marginal variances + (estimated using a diagonal Laplace approximation over all model parameters). + + Parameters + ---------- + model : torch.nn.Module + n_params_subnet : int + number of parameters in the subnetwork (i.e. number of top-scoring parameters to select) + diag_laplace_model : `laplace.baselaplace.DiagLaplace` + diagonal Laplace model to use for variance estimation + """ + def __init__(self, model, n_params_subnet, diag_laplace_model): + super().__init__(model, n_params_subnet) + self.diag_laplace_model = diag_laplace_model + + def compute_param_scores(self, train_loader): + if train_loader is None: + raise ValueError('Need to pass train loader for subnet selection.') + + self.diag_laplace_model.fit(train_loader) + return self.diag_laplace_model.posterior_variance + + +class LargestVarianceSWAGSubnetMask(ScoreBasedSubnetMask): + """Subnetwork mask identifying the parameters with the largest marginal variances + (estimated using diagonal SWAG over all model parameters). + + Parameters + ---------- + model : torch.nn.Module + n_params_subnet : int + number of parameters in the subnetwork (i.e. 
number of top-scoring parameters to select) + likelihood : str + 'classification' or 'regression' + swag_n_snapshots : int + number of model snapshots to collect for SWAG + swag_snapshot_freq : int + SWAG snapshot collection frequency (in epochs) + swag_lr : float + learning rate for SWAG snapshot collection + """ + def __init__(self, model, n_params_subnet, likelihood='classification', + swag_n_snapshots=40, swag_snapshot_freq=1, swag_lr=0.01): + super().__init__(model, n_params_subnet) + self.likelihood = likelihood + self.swag_n_snapshots = swag_n_snapshots + self.swag_snapshot_freq = swag_snapshot_freq + self.swag_lr = swag_lr + + def compute_param_scores(self, train_loader): + if train_loader is None: + raise ValueError('Need to pass train loader for subnet selection.') + + if self.likelihood == 'classification': + criterion = CrossEntropyLoss(reduction='mean') + elif self.likelihood == 'regression': + criterion = MSELoss(reduction='mean') + param_variances = fit_diagonal_swag_var(self.model, train_loader, criterion, + n_snapshots_total=self.swag_n_snapshots, + snapshot_freq=self.swag_snapshot_freq, + lr=self.swag_lr) + return param_variances + + +class ParamNameSubnetMask(SubnetMask): + """Subnetwork mask corresponding to the specified parameters of the neural network. + + Parameters + ---------- + model : torch.nn.Module + parameter_names: List[str] + list of names of the parameters (as in `model.named_parameters()`) + that define the subnetwork + """ + def __init__(self, model, parameter_names): + super().__init__(model) + self._parameter_names = parameter_names + self._n_params_subnet = None + + def _check_param_names(self): + param_names = deepcopy(self._parameter_names) + if len(param_names) == 0: + raise ValueError(f'Parameter name list cannot be empty.') + + for name, _ in self.model.named_parameters(): + if name in param_names: + param_names.remove(name) + if len(param_names) > 0: + raise ValueError(f'Parameters {param_names} do not exist in model.') + + def get_subnet_mask(self, train_loader): + """ Get the subnetwork mask identifying the specified parameters.""" + + self._check_param_names() + + subnet_mask_list = [] + for name, param in self.model.named_parameters(): + if name in self._parameter_names: + mask_method = torch.ones_like + else: + mask_method = torch.zeros_like + subnet_mask_list.append(mask_method(parameters_to_vector(param))) + subnet_mask = torch.cat(subnet_mask_list).bool() + return subnet_mask + + +class ModuleNameSubnetMask(SubnetMask): + """Subnetwork mask corresponding to the specified modules of the neural network. + + Parameters + ---------- + model : torch.nn.Module + module_names: List[str] + list of names of the modules (as in `model.named_modules()`) that define the subnetwork; + the modules cannot have children, i.e.
need to be leaf modules + """ + def __init__(self, model, module_names): + super().__init__(model) + self._module_names = module_names + self._n_params_subnet = None + + def _check_module_names(self): + module_names = deepcopy(self._module_names) + if len(module_names) == 0: + raise ValueError(f'Module name list cannot be empty.') + + for name, module in self.model.named_modules(): + if name in module_names: + if len(list(module.children())) > 0: + raise ValueError(f'Module "{name}" has children, which is not supported.') + elif len(list(module.parameters())) == 0: + raise ValueError(f'Module "{name}" does not have any parameters.') + else: + module_names.remove(name) + if len(module_names) > 0: + raise ValueError(f'Modules {module_names} do not exist in model.') + + def get_subnet_mask(self, train_loader): + """ Get the subnetwork mask identifying the specified modules.""" + + self._check_module_names() + + subnet_mask_list = [] + for name, module in self.model.named_modules(): + if len(list(module.children())) > 0 or len(list(module.parameters())) == 0: + continue + if name in self._module_names: + mask_method = torch.ones_like + else: + mask_method = torch.zeros_like + subnet_mask_list.append(mask_method(parameters_to_vector(module.parameters()))) + subnet_mask = torch.cat(subnet_mask_list).bool() + return subnet_mask + + +class LastLayerSubnetMask(ModuleNameSubnetMask): + """Subnetwork mask corresponding to the last layer of the neural network. + + Parameters + ---------- + model : torch.nn.Module + last_layer_name: str, default=None + name of the model's last layer, if None it will be determined automatically + """ + def __init__(self, model, last_layer_name=None): + super().__init__(model, None) + self._feature_extractor = FeatureExtractor(self.model, last_layer_name=last_layer_name) + self._n_params_subnet = None + + def get_subnet_mask(self, train_loader): + """ Get the subnetwork mask identifying the last layer.""" + + if train_loader is None: + raise ValueError('Need to pass train loader for subnet selection.') + + self._feature_extractor.eval() + if self._feature_extractor.last_layer is None: + X = next(iter(train_loader))[0] + with torch.no_grad(): + self._feature_extractor.find_last_layer(X[:1].to(self._device)) + self._module_names = [self._feature_extractor._last_layer_name] + + return super().get_subnet_mask(train_loader) diff --git a/laplace/utils/swag.py b/laplace/utils/swag.py new file mode 100644 index 00000000..a6aba701 --- /dev/null +++ b/laplace/utils/swag.py @@ -0,0 +1,87 @@ +from copy import deepcopy + +import torch +from torch.nn.utils import parameters_to_vector + + +__all__ = ['fit_diagonal_swag_var'] + + +def _param_vector(model): + return parameters_to_vector(model.parameters()).detach() + + +def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, + lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30): + """ + Fit diagonal SWAG [1], which estimates marginal variances of model parameters by + computing the first and second moment of SGD iterates with a large learning rate. + + Implementation partly adapted from: + - https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py + - https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py + + References + ---------- + [1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. + [*A Simple Baseline for Bayesian Uncertainty in Deep Learning*](https://arxiv.org/abs/1902.02476). + NeurIPS 2019. 
+ + Parameters + ---------- + model : torch.nn.Module + train_loader : torch.data.utils.DataLoader + training data loader to use for snapshot collection + criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss + loss function to use for snapshot collection + n_snapshots_total : int + total number of model snapshots to collect + snapshot_freq : int + snapshot collection frequency (in epochs) + lr : float + SGD learning rate for collecting snapshots + momentum : float + SGD momentum + weight_decay : float + SGD weight decay + min_var : float + minimum parameter variance to clamp to (for numerical stability) + + Returns + ------- + param_variances : torch.Tensor + vector of marginal variances for each model parameter + """ + + # create a copy of the model to avoid undesired changes to the original model parameters + _model = deepcopy(model) + _model.train() + device = next(_model.parameters()).device + + # initialize running estimates of first and second moment of model parameters + mean = torch.zeros_like(_param_vector(_model)) + sq_mean = torch.zeros_like(_param_vector(_model)) + n_snapshots = 0 + + # run SGD to collect model snapshots + optimizer = torch.optim.SGD( + _model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay) + n_epochs = snapshot_freq * n_snapshots_total + for epoch in range(n_epochs): + for inputs, targets in train_loader: + inputs, targets = inputs.to(device), targets.to(device) + optimizer.zero_grad() + loss = criterion(_model(inputs), targets) + loss.backward() + optimizer.step() + + if epoch % snapshot_freq == 0: + # update running estimates of first and second moment of model parameters + old_fac, new_fac = n_snapshots / (n_snapshots + 1), 1 / (n_snapshots + 1) + mean = mean * old_fac + _param_vector(_model) * new_fac + sq_mean = sq_mean * old_fac + _param_vector(_model) ** 2 * new_fac + n_snapshots += 1 + + # compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2 + param_variances = torch.clamp(sq_mean - mean ** 2, min_var) + return param_variances diff --git a/laplace/utils.py b/laplace/utils/utils.py similarity index 97% rename from laplace/utils.py rename to laplace/utils/utils.py index 5b059d31..a00dc2f4 100644 --- a/laplace/utils.py +++ b/laplace/utils/utils.py @@ -8,6 +8,10 @@ from torch.distributions.multivariate_normal import _precision_to_scale_tril +__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron', + 'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision'] + + def get_nll(out_dist, targets): return F.nll_loss(torch.log(out_dist), targets) diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py index 75529be8..36fe8a16 100644 --- a/tests/test_baselaplace.py +++ b/tests/test_baselaplace.py @@ -12,7 +12,7 @@ from torchvision.models import wide_resnet50_2 from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace -from laplace.matrix import KronDecomposed +from laplace.utils import KronDecomposed from tests.utils import jacobians_naive diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py index 37494d76..d3b95ad5 100644 --- a/tests/test_feature_extractor.py +++ b/tests/test_feature_extractor.py @@ -2,7 +2,7 @@ import torch.nn as nn import torchvision.models as models -from laplace.feature_extractor import FeatureExtractor +from laplace.utils import FeatureExtractor class CNN(nn.Module): diff --git a/tests/test_jacobians.py b/tests/test_jacobians.py index 7a5a22ef..13d2466e 100644 --- a/tests/test_jacobians.py +++ 
b/tests/test_jacobians.py @@ -1,10 +1,9 @@ import pytest import torch from torch import nn -from torch.nn.utils import parameters_to_vector from laplace.curvature import AsdlInterface, BackPackInterface -from laplace.feature_extractor import FeatureExtractor +from laplace.utils import FeatureExtractor from tests.utils import jacobians_naive @@ -35,10 +34,11 @@ def X(): return torch.randn(200, 3) -@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface]) -def test_linear_jacobians(linear_model, X, backend): +@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface]) +def test_linear_jacobians(linear_model, X, backend_cls): # jacobian of linear model is input X. - Js, f = backend.jacobians(linear_model, X) + backend = backend_cls(linear_model, 'classification') + Js, f = backend.jacobians(X) # into Jacs shape (batch_size, output_size, params) true_Js = X.reshape(len(X), 1, -1) assert true_Js.shape == Js.shape @@ -46,10 +46,11 @@ def test_linear_jacobians(linear_model, X, backend): assert torch.allclose(f, linear_model(X), atol=1e-5) -@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface]) -def test_jacobians_singleoutput(singleoutput_model, X, backend): +@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface]) +def test_jacobians_singleoutput(singleoutput_model, X, backend_cls): model = singleoutput_model - Js, f = backend.jacobians(model, X) + backend = backend_cls(model, 'classification') + Js, f = backend.jacobians(X) Js_naive, f_naive = jacobians_naive(model, X) assert Js.shape == Js_naive.shape assert torch.abs(Js-Js_naive).max() < 1e-6 @@ -57,10 +58,11 @@ def test_jacobians_singleoutput(singleoutput_model, X, backend): assert torch.allclose(f, f_naive) -@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface]) -def test_jacobians_multioutput(multioutput_model, X, backend): +@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface]) +def test_jacobians_multioutput(multioutput_model, X, backend_cls): model = multioutput_model - Js, f = backend.jacobians(model, X) + backend = backend_cls(model, 'classification') + Js, f = backend.jacobians(X) Js_naive, f_naive = jacobians_naive(model, X) assert Js.shape == Js_naive.shape assert torch.abs(Js-Js_naive).max() < 1e-6 @@ -68,10 +70,11 @@ def test_jacobians_multioutput(multioutput_model, X, backend): assert torch.allclose(f, f_naive) -@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface]) -def test_last_layer_jacobians_singleoutput(singleoutput_model, X, backend): +@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface]) +def test_last_layer_jacobians_singleoutput(singleoutput_model, X, backend_cls): model = FeatureExtractor(singleoutput_model) - Js, f = backend.last_layer_jacobians(model, X) + backend = backend_cls(model, 'classification') + Js, f = backend.last_layer_jacobians(X) _, phi = model.forward_with_features(X) Js_naive, f_naive = jacobians_naive(model.last_layer, phi) assert Js.shape == Js_naive.shape @@ -80,10 +83,11 @@ def test_last_layer_jacobians_singleoutput(singleoutput_model, X, backend): assert torch.allclose(f, f_naive) -@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface]) -def test_last_layer_jacobians_multioutput(multioutput_model, X, backend): +@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface]) +def test_last_layer_jacobians_multioutput(multioutput_model, X, backend_cls): model = FeatureExtractor(multioutput_model) - Js, f = 
backend.last_layer_jacobians(model, X) + backend = backend_cls(model, 'classification') + Js, f = backend.last_layer_jacobians(X) _, phi = model.forward_with_features(X) Js_naive, f_naive = jacobians_naive(model.last_layer, phi) assert Js.shape == Js_naive.shape diff --git a/tests/test_lllaplace.py b/tests/test_lllaplace.py index ccf581c5..0e6855aa 100644 --- a/tests/test_lllaplace.py +++ b/tests/test_lllaplace.py @@ -8,8 +8,8 @@ from torch.distributions import Normal, Categorical from torchvision.models import wide_resnet50_2 -from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace -from laplace.feature_extractor import FeatureExtractor +from laplace.lllaplace import FullLLLaplace, KronLLLaplace, DiagLLLaplace +from laplace.utils import FeatureExtractor from tests.utils import jacobians_naive @@ -309,7 +309,7 @@ def test_laplace_functionality(laplace, lh, model, reg_loader, class_loader): Js, f = jacobians_naive(feature_extractor.last_layer, phi) true_f_var = torch.einsum('mkp,pq,mcq->mkc', Js, Sigma, Js) # test last-layer Jacobians - comp_Js, comp_f = lap.backend.last_layer_jacobians(lap.model, X) + comp_Js, comp_f = lap.backend.last_layer_jacobians(X) assert torch.allclose(Js, comp_Js) assert torch.allclose(f, comp_f) comp_f_var = lap.functional_variance(comp_Js) diff --git a/tests/test_matrix.py b/tests/test_matrix.py index fb5bef1e..7c366990 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -4,10 +4,9 @@ from torch import nn from torch.nn.utils import parameters_to_vector -from laplace.matrix import Kron, KronDecomposed +from laplace.utils import Kron, block_diag from laplace.utils import kron as kron_prod from laplace.curvature import BackPackGGN -from laplace.utils import block_diag from tests.utils import get_psd_matrix, jacobians_naive diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py new file mode 100644 index 00000000..f51f5a5e --- /dev/null +++ b/tests/test_subnetlaplace.py @@ -0,0 +1,553 @@ +import pytest +from itertools import product + +import torch +from torch import nn +from torch.nn.utils import parameters_to_vector +from torch.utils.data import DataLoader, TensorDataset +from torchvision.models import wide_resnet50_2 + +from laplace import Laplace, SubnetLaplace +from laplace.baselaplace import DiagLaplace +from laplace.utils import (SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, + LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, + ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask) + + +torch.manual_seed(240) +torch.set_default_tensor_type(torch.DoubleTensor) +score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask, + LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask] +layer_subnet_masks = [ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask] +all_subnet_masks = score_based_subnet_masks + layer_subnet_masks +likelihoods = ['classification', 'regression'] + + +@pytest.fixture +def model(): + model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2)) + model_params = list(model.parameters()) + setattr(model, 'n_params', len(parameters_to_vector(model_params))) + return model + + +@pytest.fixture +def large_model(): + model = wide_resnet50_2() + return model + + +@pytest.fixture +def class_loader(): + X = torch.randn(10, 3) + y = torch.randint(2, (10,)) + return DataLoader(TensorDataset(X, y), batch_size=3) + + +@pytest.fixture +def reg_loader(): + X = torch.randn(10, 3) + y = torch.randn(10, 2) + return 
DataLoader(TensorDataset(X, y), batch_size=3) + + +@pytest.mark.parametrize('likelihood', likelihoods) +def test_subnet_laplace_init(model, likelihood): + # use random subnet mask for this test + subnetwork_mask = RandomSubnetMask + subnetmask_kwargs = dict(model=model, n_params_subnet=10) + subnetmask = subnetwork_mask(**subnetmask_kwargs) + subnetmask.select() + + # subnet Laplace with full Hessian should work + hessian_structure = 'full' + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + assert isinstance(lap, SubnetLaplace) + + # subnet Laplace without specifying subnetwork indices should raise an error + with pytest.raises(TypeError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + hessian_structure=hessian_structure) + + # subnet Laplace with diag, kron or lowrank Hessians should raise errors + hessian_structure = 'diag' + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + hessian_structure = 'kron' + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + hessian_structure = 'lowrank' + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + + +@pytest.mark.parametrize('likelihood', likelihoods) +def test_subnet_laplace_large_init(large_model, likelihood): + # use random subnet mask for this test + subnetwork_mask = RandomSubnetMask + n_param_subnet = 10 + subnetmask_kwargs = dict(model=large_model, n_params_subnet=n_param_subnet) + subnetmask = subnetwork_mask(**subnetmask_kwargs) + subnetmask.select() + + lap = Laplace(large_model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, hessian_structure='full') + assert lap.n_params_subnet == n_param_subnet + assert lap.H.shape == (lap.n_params_subnet, lap.n_params_subnet) + H = lap.H.clone() + lap._init_H() + assert torch.allclose(H, lap.H) + + +@pytest.mark.parametrize('likelihood', likelihoods) +def test_custom_subnetwork_indices(model, likelihood, class_loader, reg_loader): + loader = class_loader if likelihood == 'classification' else reg_loader + + # subnetwork indices that are None should raise an error + subnetwork_indices = None + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are not PyTorch tensors should raise an error + subnetwork_indices = [0, 5, 11, 42] + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are empty tensors should raise an error + subnetwork_indices = torch.LongTensor([]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are scalar tensors should raise an error + subnetwork_indices = torch.LongTensor(11) + with pytest.raises(ValueError): + lap = 
Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are not 1D PyTorch tensors should raise an error + subnetwork_indices = torch.LongTensor([[0, 5], [11, 42]]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are double tensors should raise an error + subnetwork_indices = torch.DoubleTensor([0.0, 5.0, 11.0, 42.0]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are float tensors should raise an error + subnetwork_indices = torch.FloatTensor([0.0, 5.0, 11.0, 42.0]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are half tensors should raise an error + subnetwork_indices = torch.HalfTensor([0.0, 5.0, 11.0, 42.0]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are int tensors should raise an error + subnetwork_indices = torch.IntTensor([0, 5, 11, 42]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are short tensors should raise an error + subnetwork_indices = torch.ShortTensor([0, 5, 11, 42]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are char tensors should raise an error + subnetwork_indices = torch.CharTensor([0, 5, 11, 42]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that are bool tensors should raise an error + subnetwork_indices = torch.BoolTensor([0, 5, 11, 42]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that contain elements smaller than zero should raise an error + subnetwork_indices = torch.LongTensor([0, -1, -11]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that contain elements larger than n_params should raise an error + subnetwork_indices = torch.LongTensor([model.n_params + 1, model.n_params + 42]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # subnetwork indices that contain duplicate entries should raise an error + subnetwork_indices = torch.LongTensor([0, 0, 5, 11, 11, 42]) + with pytest.raises(ValueError): + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + 
subnetwork_indices=subnetwork_indices, hessian_structure='full') + + # Non-empty, 1-dimensional torch.LongTensor with valid entries should work + subnetwork_indices = torch.LongTensor([0, 5, 11, 42]) + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, hessian_structure='full') + lap.fit(loader) + assert isinstance(lap, SubnetLaplace) + assert lap.n_params_subnet == 4 + assert lap.H.shape == (4, 4) + assert lap.backend.subnetwork_indices.equal(subnetwork_indices) + + +@pytest.mark.parametrize('subnetwork_mask,likelihood', product(score_based_subnet_masks, likelihoods)) +def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader): + loader = class_loader if likelihood == 'classification' else reg_loader + model_params = parameters_to_vector(model.parameters()) + + # set subnetwork mask arguments + if subnetwork_mask == LargestVarianceDiagLaplaceSubnetMask: + diag_laplace_model = DiagLaplace(model, likelihood) + subnetmask_kwargs = dict(model=model, diag_laplace_model=diag_laplace_model) + elif subnetwork_mask == LargestVarianceSWAGSubnetMask: + subnetmask_kwargs = dict(model=model, likelihood=likelihood) + else: + subnetmask_kwargs = dict(model=model) + + # should raise error if we don't pass number of subnet parameters within the subnetmask_kwargs + with pytest.raises(TypeError): + subnetmask = subnetwork_mask(**subnetmask_kwargs) + subnetmask.select(loader) + + # should raise error if we set number of subnet parameters to None + subnetmask_kwargs.update(n_params_subnet=None) + with pytest.raises(ValueError): + subnetmask = subnetwork_mask(**subnetmask_kwargs) + subnetmask.select(loader) + + # should raise error if number of subnet parameters is larger than number of model parameters + subnetmask_kwargs.update(n_params_subnet=99999) + with pytest.raises(ValueError): + subnetmask = subnetwork_mask(**subnetmask_kwargs) + subnetmask.select(loader) + + # define subnetwork mask + n_params_subnet = 32 + subnetmask_kwargs.update(n_params_subnet=n_params_subnet) + subnetmask = subnetwork_mask(**subnetmask_kwargs) + + # should raise error if we try to access the subnet indices before the subnet has been selected + with pytest.raises(AttributeError): + subnetmask.indices + + # select subnet mask + subnetmask.select(loader) + + # should raise error if we try to select the subnet again + with pytest.raises(ValueError): + subnetmask.select(loader) + + # define valid subnet Laplace model + lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, hessian_structure='full') + assert isinstance(lap, SubnetLaplace) + + # fit Laplace model + lap.fit(loader) + + # check some parameters + assert subnetmask.indices.equal(lap.backend.subnetwork_indices) + assert subnetmask.n_params_subnet == n_params_subnet + assert lap.n_params_subnet == n_params_subnet + assert parameters_to_vector(model.parameters()).equal(model_params) + + # check that Hessian and prior precision is of correct shape + assert lap.H.shape == (n_params_subnet, n_params_subnet) + assert lap.prior_precision_diag.shape == (n_params_subnet,) + + +@pytest.mark.parametrize('subnetwork_mask,likelihood', product(layer_subnet_masks, likelihoods)) +def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader): + loader = class_loader if likelihood == 'classification' else reg_loader + subnetmask_kwargs = dict(model=model) + + # fit last-layer Laplace model + lllap = 
Laplace(model, likelihood=likelihood, subset_of_weights='last_layer',
+                    hessian_structure='full')
+    lllap.fit(loader)
+
+    # should raise error if we pass number of subnet parameters
+    subnetmask_kwargs.update(n_params_subnet=32)
+    with pytest.raises(TypeError):
+        subnetmask = subnetwork_mask(**subnetmask_kwargs)
+        subnetmask.select(loader)
+
+    subnetmask_kwargs = dict(model=model)
+    if subnetwork_mask == ParamNameSubnetMask:
+        # should raise error if we pass no parameter name list
+        subnetmask_kwargs.update()
+        with pytest.raises(TypeError):
+            subnetmask = subnetwork_mask(**subnetmask_kwargs)
+            subnetmask.select(loader)
+
+        # should raise error if we pass an empty parameter name list
+        subnetmask_kwargs.update(parameter_names=[])
+        with pytest.raises(ValueError):
+            subnetmask = subnetwork_mask(**subnetmask_kwargs)
+            subnetmask.select(loader)
+
+        # should raise error if we pass a parameter name list with invalid parameter names
+        subnetmask_kwargs.update(parameter_names=['123'])
+        with pytest.raises(ValueError):
+            subnetmask = subnetwork_mask(**subnetmask_kwargs)
+            subnetmask.select(loader)
+
+        # define last-layer Laplace model by parameter names and check that
+        # Hessian is identical to that of a full LLLaplace model
+        subnetmask_kwargs.update(parameter_names=['1.weight', '1.bias'])
+        subnetmask = subnetwork_mask(**subnetmask_kwargs)
+        subnetmask.select(loader)
+        lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+                      subnetwork_indices=subnetmask.indices, hessian_structure='full')
+        lap.fit(loader)
+        assert lllap.H.equal(lap.H)
+
+        # define valid parameter name subnet mask
+        subnetmask_kwargs.update(parameter_names=['0.weight', '1.bias'])
+        subnetmask = subnetwork_mask(**subnetmask_kwargs)
+
+        # should raise error if we access number of subnet parameters before selecting the subnet
+        n_params_subnet = 62
+        with pytest.raises(AttributeError):
+            n_params_subnet = subnetmask.n_params_subnet
+
+        # select subnet mask and fit Laplace model
+        subnetmask.select(loader)
+        lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+                      subnetwork_indices=subnetmask.indices, hessian_structure='full')
+        lap.fit(loader)
+        assert isinstance(lap, SubnetLaplace)
+
+    elif subnetwork_mask == ModuleNameSubnetMask:
+        # should raise error if we pass no module name list
+        subnetmask_kwargs.update()
+        with pytest.raises(TypeError):
+            subnetmask = subnetwork_mask(**subnetmask_kwargs)
+            subnetmask.select(loader)
+
+        # should raise error if we pass an empty module name list
+        subnetmask_kwargs.update(module_names=[])
+        with pytest.raises(ValueError):
+            subnetmask = subnetwork_mask(**subnetmask_kwargs)
+            subnetmask.select(loader)
+
+        # should raise error if we pass a module name list with invalid module names
+        subnetmask_kwargs.update(module_names=['123'])
+        with pytest.raises(ValueError):
+            subnetmask = subnetwork_mask(**subnetmask_kwargs)
+            subnetmask.select(loader)
+
+        # define last-layer Laplace model by module name and check that
+        # Hessian is identical to that of a full LLLaplace model
+        subnetmask_kwargs.update(module_names=['1'])
+        subnetmask = subnetwork_mask(**subnetmask_kwargs)
+        subnetmask.select(loader)
+        lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+                      subnetwork_indices=subnetmask.indices, hessian_structure='full')
+        lap.fit(loader)
+        assert lllap.H.equal(lap.H)
+
+        # define valid module name subnet mask
+        subnetmask_kwargs.update(module_names=['0'])
+        subnetmask = subnetwork_mask(**subnetmask_kwargs)
+
+        # should raise error if we access number of subnet parameters before selecting the subnet
+        n_params_subnet = 80
+        with pytest.raises(AttributeError):
+            n_params_subnet = subnetmask.n_params_subnet
+
+        # select subnet mask and fit Laplace model
+        subnetmask.select(loader)
+        lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+                      subnetwork_indices=subnetmask.indices, hessian_structure='full')
+        lap.fit(loader)
+        assert isinstance(lap, SubnetLaplace)
+
+    elif subnetwork_mask == LastLayerSubnetMask:
+        # should raise error if we pass invalid last-layer name
+        subnetmask_kwargs.update(last_layer_name='123')
+        with pytest.raises(KeyError):
+            subnetmask = subnetwork_mask(**subnetmask_kwargs)
+            subnetmask.select(loader)
+
+        # define valid last-layer subnet mask (without passing the last-layer name)
+        subnetmask_kwargs = dict(model=model)
+        subnetmask = subnetwork_mask(**subnetmask_kwargs)
+
+        # should raise error if we access number of subnet parameters before selecting the subnet
+        with pytest.raises(AttributeError):
+            n_params_subnet = subnetmask.n_params_subnet
+
+        # select subnet mask and fit Laplace model
+        subnetmask.select(loader)
+        lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+                      subnetwork_indices=subnetmask.indices, hessian_structure='full')
+        lap.fit(loader)
+        assert isinstance(lap, SubnetLaplace)
+
+        # check that Hessian is identical to that of a full LLLaplace model
+        assert lllap.H.equal(lap.H)
+
+        # define valid last-layer subnet mask (with passing the last-layer name)
+        subnetmask_kwargs.update(last_layer_name='1')
+        subnetmask = subnetwork_mask(**subnetmask_kwargs)
+
+        # should raise error if we access number of subnet parameters before selecting the subnet
+        n_params_subnet = 42
+        with pytest.raises(AttributeError):
+            n_params_subnet = subnetmask.n_params_subnet
+
+        # select subnet mask and fit Laplace model
+        subnetmask.select(loader)
+        lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+                      subnetwork_indices=subnetmask.indices, hessian_structure='full')
+        lap.fit(loader)
+        assert isinstance(lap, SubnetLaplace)
+
+        # check that Hessian is identical to that of a full LLLaplace model
+        assert lllap.H.equal(lap.H)
+
+    # check some parameters
+    assert subnetmask.indices.equal(lap.backend.subnetwork_indices)
+    assert subnetmask.n_params_subnet == n_params_subnet
+    assert lap.n_params_subnet == n_params_subnet
+
+    # check that Hessian and prior precision are of correct shape
+    assert lap.H.shape == (n_params_subnet, n_params_subnet)
+    assert lap.prior_precision_diag.shape == (n_params_subnet,)
+
+
+@pytest.mark.parametrize('likelihood', likelihoods)
+def test_full_subnet_mask(model, likelihood, class_loader, reg_loader):
+    loader = class_loader if likelihood == 'classification' else reg_loader
+
+    # define full model 'subnet' mask class (i.e. where all parameters are part of the subnet)
+    class FullSubnetMask(SubnetMask):
+        def get_subnet_mask(self, train_loader):
+            return torch.ones(model.n_params).byte()
+
+    # define and fit valid full subnet Laplace model
+    subnetwork_mask = FullSubnetMask
+    subnetmask = subnetwork_mask(model=model)
+    subnetmask.select(loader)
+    lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+                  subnetwork_indices=subnetmask.indices, hessian_structure='full')
+    lap.fit(loader)
+    assert isinstance(lap, SubnetLaplace)
+
+    # check some parameters
+    assert subnetmask.indices.equal(torch.tensor(list(range(model.n_params))))
+    assert subnetmask.n_params_subnet == model.n_params
+    assert lap.n_params_subnet == model.n_params
+
+    # check that the Hessian is identical to that of an all-weights FullLaplace model
+    full_lap = Laplace(model, likelihood=likelihood, subset_of_weights='all',
+                       hessian_structure='full')
+    full_lap.fit(loader)
+    assert full_lap.H.equal(lap.H)
+
+
+@pytest.mark.parametrize('subnetwork_mask', all_subnet_masks)
+def test_regression_predictive(model, reg_loader, subnetwork_mask):
+    subnetmask_kwargs = dict(model=model)
+    if subnetwork_mask in score_based_subnet_masks:
+        subnetmask_kwargs.update(n_params_subnet=32)
+        if subnetwork_mask == LargestVarianceSWAGSubnetMask:
+            subnetmask_kwargs.update(likelihood='regression')
+        elif subnetwork_mask == LargestVarianceDiagLaplaceSubnetMask:
+            diag_laplace_model = DiagLaplace(model, 'regression')
+            subnetmask_kwargs.update(diag_laplace_model=diag_laplace_model)
+    elif subnetwork_mask == ParamNameSubnetMask:
+        subnetmask_kwargs.update(parameter_names=['0.weight', '1.bias'])
+    elif subnetwork_mask == ModuleNameSubnetMask:
+        subnetmask_kwargs.update(module_names=['0'])
+
+    subnetmask = subnetwork_mask(**subnetmask_kwargs)
+    subnetmask.select(reg_loader)
+    lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork',
+                  subnetwork_indices=subnetmask.indices, hessian_structure='full')
+    assert isinstance(lap, SubnetLaplace)
+
+    lap.fit(reg_loader)
+    X, _ = reg_loader.dataset.tensors
+    f = model(X)
+
+    # error
+    with pytest.raises(ValueError):
+        lap(X, pred_type='linear')
+
+    # GLM predictive
+    f_mu, f_var = lap(X, pred_type='glm')
+    assert torch.allclose(f_mu, f)
+    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])
+    assert len(f_mu) == len(X)
+
+    # NN predictive (only diagonal variance estimation)
+    f_mu, f_var = lap(X, pred_type='nn')
+    assert f_mu.shape == f_var.shape
+    assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1]])
+    assert len(f_mu) == len(X)
+
+
+@pytest.mark.parametrize('subnetwork_mask', all_subnet_masks)
+def test_classification_predictive(model, class_loader, subnetwork_mask):
+    subnetmask_kwargs = dict(model=model)
+    if subnetwork_mask in score_based_subnet_masks:
+        subnetmask_kwargs.update(n_params_subnet=32)
+        if subnetwork_mask == LargestVarianceSWAGSubnetMask:
+            subnetmask_kwargs.update(likelihood='classification')
+        elif subnetwork_mask == LargestVarianceDiagLaplaceSubnetMask:
+            diag_laplace_model = DiagLaplace(model, 'classification')
+            subnetmask_kwargs.update(diag_laplace_model=diag_laplace_model)
+    elif subnetwork_mask == ParamNameSubnetMask:
+        subnetmask_kwargs.update(parameter_names=['0.weight', '1.bias'])
+    elif subnetwork_mask == ModuleNameSubnetMask:
+        subnetmask_kwargs.update(module_names=['0'])
+
+    subnetmask = subnetwork_mask(**subnetmask_kwargs)
+    subnetmask.select(class_loader)
+    lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork',
+                  subnetwork_indices=subnetmask.indices, hessian_structure='full')
+    assert isinstance(lap, SubnetLaplace)
+
+    lap.fit(class_loader)
+    X, _ = class_loader.dataset.tensors
+    f = torch.softmax(model(X), dim=-1)
+
+    # error
+    with pytest.raises(ValueError):
+        lap(X, pred_type='linear')
+
+    # GLM predictive
+    f_pred = lap(X, pred_type='glm', link_approx='mc', n_samples=100)
+    assert f_pred.shape == f.shape
+    assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double))  # sum up to 1
+    f_pred = lap(X, pred_type='glm', link_approx='probit')
+    assert f_pred.shape == f.shape
+    assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double))  # sum up to 1
+    f_pred = lap(X, pred_type='glm', link_approx='bridge')
+    assert f_pred.shape == f.shape
+    assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double))  # sum up to 1
+
+    # NN predictive
+    f_pred = lap(X, pred_type='nn', n_samples=100)
+    assert f_pred.shape == f.shape
+    assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double))  # sum up to 1
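
For reference, the snippet below sketches the end-to-end subnetwork-Laplace workflow that these tests exercise: construct a subnet mask, select it on a data loader, and pass its `indices` to `Laplace` with `subset_of_weights='subnetwork'`. The toy `model`, data, and `loader` are illustrative assumptions (as is the exact import path for `ModuleNameSubnetMask`); only the `laplace` calls mirror the API used in the tests above.

```python
# Minimal usage sketch; the model, data, and loader below are made up for illustration.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from laplace import Laplace
from laplace.utils.subnetmask import ModuleNameSubnetMask  # assumed import path

# toy regression data and a two-layer model whose modules are named '0' and '1'
X, y = torch.randn(128, 3), torch.randn(128, 2)
loader = DataLoader(TensorDataset(X, y), batch_size=32)
model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))

# pick the subnetwork: here, all parameters of the first module
subnetmask = ModuleNameSubnetMask(model=model, module_names=['0'])
subnetmask.select(loader)

# fit a full-Hessian Laplace approximation over the selected subnetwork only
la = Laplace(model, likelihood='regression', subset_of_weights='subnetwork',
             subnetwork_indices=subnetmask.indices, hessian_structure='full')
la.fit(loader)

# GLM predictive: mean and covariance of the predictive distribution per input
f_mu, f_var = la(X, pred_type='glm')
print(f_mu.shape, f_var.shape)
```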