Skip to content

Commit

Permalink
Merge pull request #225 from aleximmer/asdfghjkl
Browse files Browse the repository at this point in the history
Make importing `asdfghjkl` conditional to the installed optional dependency
  • Loading branch information
wiseodd authored Aug 20, 2024
2 parents 3fdda9a + 7968c66 commit 4e7b9cf
Show file tree
Hide file tree
Showing 19 changed files with 854 additions and 244 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install pytest-mock
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -e .
- name: Test without the old asdfghjkl
run: |
pip uninstall -y asdfghjkl-old
pytest tests
- name: Test with pytest
run: |
pip install pytest
pip install pytest-mock
pip install git+https://[email protected]/wiseodd/asdl@asdfghjkl
pytest tests
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ The [code](https://github.com/runame/laplace-redux) to reproduce the experiments
To install laplace with `pip`, run the following:

```bash
pip install git+https://github.com/aleximmer/laplace@update-deps
pip install laplace-torch
```

> [!WARNING]
> The ASDL dependency has recently been updated and it breaks the compatibility with
> `laplace-torch`. Please _do not_ install `laplace-torch` from `pip` or the `main`
> branch. Instead install it as above, from the `update-deps` branch.
> We're actively fixing this issue.
Additionally, if you want to use the `asdfghjkl` backend, please install it via:

```bash
pip install git+https://[email protected]/wiseodd/asdl@asdfghjkl
```

For development purposes, e.g. if you would like to make contributions,
clone the repository and then install:
Expand Down
219 changes: 217 additions & 2 deletions docs/baselaplace.html

Large diffs are not rendered by default.

182 changes: 0 additions & 182 deletions docs/curvature/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -556,169 +556,6 @@ <h3>Inherited members</h3>
</li>
</ul>
</dd>
<dt id="laplace.curvature.AsdfghjklInterface"><code class="flex name class">
<span>class <span class="ident">AsdfghjklInterface</span></span>
<span>(</span><span>model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')</span>
</code></dt>
<dd>
<div class="desc"><p>Interface for asdfghjkl backend.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.curvature.curvature.CurvatureInterface" href="curvature.html#laplace.curvature.curvature.CurvatureInterface">CurvatureInterface</a></li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.curvature.asdfghjkl.AsdfghjklEF" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklEF">AsdfghjklEF</a></li>
<li><a title="laplace.curvature.asdfghjkl.AsdfghjklGGN" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklGGN">AsdfghjklGGN</a></li>
<li><a title="laplace.curvature.asdfghjkl.AsdfghjklHessian" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklHessian">AsdfghjklHessian</a></li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="laplace.curvature.AsdfghjklInterface.jacobians"><code class="name flex">
<span>def <span class="ident">jacobians</span></span>(<span>self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], enable_backprop: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]</span>
</code></dt>
<dd>
<div class="desc"><p>Compute Jacobians <span><span class="MathJax_Preview">\nabla_\theta f(x;\theta)</span><script type="math/tex">\nabla_\theta f(x;\theta)</script></span> at current parameter <span><span class="MathJax_Preview">\theta</span><script type="math/tex">\theta</script></span>
using asdfghjkl's gradient per output dimension.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>input data <code>(batch, input_shape)</code> on compatible device with model.</dd>
<dt><strong><code>enable_backprop</code></strong> :&ensp;<code>bool</code>, default <code>= False</code></dt>
<dd>whether to enable backprop through the Js and f w.r.t. x</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>Js</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>Jacobians <code>(batch, parameters, outputs)</code></dd>
<dt><strong><code>f</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>output function <code>(batch, outputs)</code></dd>
</dl></div>
</dd>
<dt id="laplace.curvature.AsdfghjklInterface.gradients"><code class="name flex">
<span>def <span class="ident">gradients</span></span>(<span>self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor]</span>
</code></dt>
<dd>
<div class="desc"><p>Compute gradients <span><span class="MathJax_Preview">\nabla_\theta \ell(f(x;\theta, y)</span><script type="math/tex">\nabla_\theta \ell(f(x;\theta, y)</script></span> at current parameter
<span><span class="MathJax_Preview">\theta</span><script type="math/tex">\theta</script></span> using asdfghjkl's backend.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>input data <code>(batch, input_shape)</code> on compatible device with model.</dd>
<dt><strong><code>y</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>loss</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
<dt><strong><code>Gs</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>gradients <code>(batch, parameters)</code></dd>
</dl></div>
</dd>
</dl>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.curvature.curvature.CurvatureInterface" href="curvature.html#laplace.curvature.curvature.CurvatureInterface">CurvatureInterface</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.curvature.curvature.CurvatureInterface.diag" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.diag">diag</a></code></li>
<li><code><a title="laplace.curvature.curvature.CurvatureInterface.full" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.full">full</a></code></li>
<li><code><a title="laplace.curvature.curvature.CurvatureInterface.functorch_jacobians" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.functorch_jacobians">functorch_jacobians</a></code></li>
<li><code><a title="laplace.curvature.curvature.CurvatureInterface.kron" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.kron">kron</a></code></li>
<li><code><a title="laplace.curvature.curvature.CurvatureInterface.last_layer_jacobians" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.last_layer_jacobians">last_layer_jacobians</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.curvature.AsdfghjklGGN"><code class="flex name class">
<span>class <span class="ident">AsdfghjklGGN</span></span>
<span>(</span><span>model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False)</span>
</code></dt>
<dd>
<div class="desc"><p>Implementation of the <code><a title="laplace.curvature.GGNInterface" href="#laplace.curvature.GGNInterface">GGNInterface</a></code> using asdfghjkl.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface">AsdfghjklInterface</a></li>
<li><a title="laplace.curvature.curvature.GGNInterface" href="curvature.html#laplace.curvature.curvature.GGNInterface">GGNInterface</a></li>
<li><a title="laplace.curvature.curvature.CurvatureInterface" href="curvature.html#laplace.curvature.curvature.CurvatureInterface">CurvatureInterface</a></li>
</ul>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface">AsdfghjklInterface</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.diag" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.diag">diag</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.full" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.full">full</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.functorch_jacobians" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.functorch_jacobians">functorch_jacobians</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.gradients" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface.gradients">gradients</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.jacobians" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface.jacobians">jacobians</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.kron" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.kron">kron</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.last_layer_jacobians" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.last_layer_jacobians">last_layer_jacobians</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.curvature.AsdfghjklEF"><code class="flex name class">
<span>class <span class="ident">AsdfghjklEF</span></span>
<span>(</span><span>model: nn.Module, likelihood: Likelihood | None, last_layer: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')</span>
</code></dt>
<dd>
<div class="desc"><p>Implementation of the <code><a title="laplace.curvature.EFInterface" href="#laplace.curvature.EFInterface">EFInterface</a></code> using asdfghjkl.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface">AsdfghjklInterface</a></li>
<li><a title="laplace.curvature.curvature.EFInterface" href="curvature.html#laplace.curvature.curvature.EFInterface">EFInterface</a></li>
<li><a title="laplace.curvature.curvature.CurvatureInterface" href="curvature.html#laplace.curvature.curvature.CurvatureInterface">CurvatureInterface</a></li>
</ul>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface">AsdfghjklInterface</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.diag" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.diag">diag</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.full" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.full">full</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.functorch_jacobians" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.functorch_jacobians">functorch_jacobians</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.gradients" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface.gradients">gradients</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.jacobians" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface.jacobians">jacobians</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.kron" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.kron">kron</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.last_layer_jacobians" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.last_layer_jacobians">last_layer_jacobians</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.curvature.AsdfghjklHessian"><code class="flex name class">
<span>class <span class="ident">AsdfghjklHessian</span></span>
<span>(</span><span>model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', low_rank: int = 10)</span>
</code></dt>
<dd>
<div class="desc"><p>Interface for asdfghjkl backend.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface">AsdfghjklInterface</a></li>
<li><a title="laplace.curvature.curvature.CurvatureInterface" href="curvature.html#laplace.curvature.curvature.CurvatureInterface">CurvatureInterface</a></li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="laplace.curvature.AsdfghjklHessian.eig_lowrank"><code class="name flex">
<span>def <span class="ident">eig_lowrank</span></span>(<span>self, data_loader: DataLoader) ‑> tuple[torch.Tensor, torch.Tensor, torch.Tensor]</span>
</code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface">AsdfghjklInterface</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.diag" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.diag">diag</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.full" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.full">full</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.functorch_jacobians" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.functorch_jacobians">functorch_jacobians</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.gradients" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface.gradients">gradients</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.jacobians" href="asdfghjkl.html#laplace.curvature.asdfghjkl.AsdfghjklInterface.jacobians">jacobians</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.kron" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.kron">kron</a></code></li>
<li><code><a title="laplace.curvature.asdfghjkl.AsdfghjklInterface.last_layer_jacobians" href="curvature.html#laplace.curvature.curvature.CurvatureInterface.last_layer_jacobians">last_layer_jacobians</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.curvature.AsdlInterface"><code class="flex name class">
<span>class <span class="ident">AsdlInterface</span></span>
<span>(</span><span>model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')</span>
Expand Down Expand Up @@ -1056,25 +893,6 @@ <h4><code><a title="laplace.curvature.BackPackGGN" href="#laplace.curvature.Back
<h4><code><a title="laplace.curvature.BackPackEF" href="#laplace.curvature.BackPackEF">BackPackEF</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.curvature.AsdfghjklInterface" href="#laplace.curvature.AsdfghjklInterface">AsdfghjklInterface</a></code></h4>
<ul class="">
<li><code><a title="laplace.curvature.AsdfghjklInterface.jacobians" href="#laplace.curvature.AsdfghjklInterface.jacobians">jacobians</a></code></li>
<li><code><a title="laplace.curvature.AsdfghjklInterface.gradients" href="#laplace.curvature.AsdfghjklInterface.gradients">gradients</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="laplace.curvature.AsdfghjklGGN" href="#laplace.curvature.AsdfghjklGGN">AsdfghjklGGN</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.curvature.AsdfghjklEF" href="#laplace.curvature.AsdfghjklEF">AsdfghjklEF</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.curvature.AsdfghjklHessian" href="#laplace.curvature.AsdfghjklHessian">AsdfghjklHessian</a></code></h4>
<ul class="">
<li><code><a title="laplace.curvature.AsdfghjklHessian.eig_lowrank" href="#laplace.curvature.AsdfghjklHessian.eig_lowrank">eig_lowrank</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="laplace.curvature.AsdlInterface" href="#laplace.curvature.AsdlInterface">AsdlInterface</a></code></h4>
<ul class="">
<li><code><a title="laplace.curvature.AsdlInterface.jacobians" href="#laplace.curvature.AsdlInterface.jacobians">jacobians</a></code></li>
Expand Down
Loading

0 comments on commit 4e7b9cf

Please sign in to comment.