Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Implement Full-Rank VI #720

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Doc: formatting
gil2rok committed Aug 16, 2024
commit b13eb12c7b35136d97d932763a0eaec76bd89318
6 changes: 3 additions & 3 deletions blackjax/vi/fullrank_vi.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@

class FRVIState(NamedTuple):
mu: ArrayTree
chol_params: ArrayTree # flattened Cholesky factor
chol_params: ArrayTree # flattened Cholesky factor
opt_state: OptState


@@ -148,8 +148,8 @@ def _unflatten_cholesky(chol_params):
"""Construct the Cholesky factor from a flattened vector of cholesky parameters.

Transforms a flattened vector representation of a lower triangular matrix
into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements
consisting of d diagonal elements followed by n - d off-diagonal elements in
into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements
consisting of d diagonal elements followed by n - d off-diagonal elements in
row-major order, where d is the dimension of the matrix.

The diagonal elements are passed through a softplus function to ensure (numerically