AutoDiff-Inference (Bijax)
This repository implements Automatic Differentiation Variational Inference (ADVI) and several variants of the Laplace approximation, based on published research papers.
- ADVI: an implementation of Automatic Differentiation Variational Inference.
- Laplace Approximation: a Laplace approximation for constrained variables that, following ADVI, uses a bijector to map them to an unconstrained space before fitting the Gaussian.
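```python
import os
import jax
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
os.makedirs("plots", exist_ok=True)  # the figures below are saved here

# Create the dataset: 100 simulated coin tosses with P(heads) = 0.7
data_dist = tfd.Bernoulli(probs=0.7)
data = data_dist.sample(sample_shape=(100,), seed=jax.random.PRNGKey(3))
prior_theta = [3.0, 5.0]  # Beta(3, 5) prior on theta = P(heads)

# Bernoulli log-likelihood of the data given theta
def likelihood_fn(theta, data):
    return tfd.Bernoulli(probs=theta).log_prob(data).sum()

# Beta-Bernoulli conjugacy gives the exact posterior Beta(alpha, beta)
alpha = prior_theta[0] + data.sum()
beta = prior_theta[1] + len(data) - data.sum()
```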
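The `alpha` and `beta` above come from Beta-Bernoulli conjugacy. Writing the prior as $\mathrm{Beta}(a, b)$ with `a, b = prior_theta` and the $n$ tosses as $x_i \in \{0, 1\}$:

$$
p(\theta \mid x_{1:n}) \propto \theta^{a-1}(1-\theta)^{b-1}\prod_{i=1}^{n}\theta^{x_i}(1-\theta)^{1-x_i}
= \theta^{\,a + \sum_i x_i - 1}\,(1-\theta)^{\,b + n - \sum_i x_i - 1},
$$

which is the kernel of $\mathrm{Beta}\big(a + \sum_i x_i,\; b + n - \sum_i x_i\big)$. This exact posterior is what `true_posterior` holds in the comparisons that follow.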
Standard Laplace Approximation (identity bijector)
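```python
# Identity bijector: the Gaussian is fitted directly to theta (standard Laplace approximation)
# LaplaceApproximation is provided by this repository
la = LaplaceApproximation(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.Identity(),
    likelihood=likelihood_fn,
)
true_posterior = tfd.Beta(alpha, beta)  # exact conjugate posterior
fig = la.plot_approx_posterior(true_posterior=true_posterior)
plt.xlim(-0.5, 1.5)  # widen the axis beyond the support (0, 1) of theta
plt.savefig("plots/la_coin_toss.png")  # save before opening a new, empty figure
plt.figure()
```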
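For intuition, here is a minimal JAX sketch of what the identity-bijector case computes, assuming the Laplace approximation is a Gaussian centred at the posterior mode whose variance is the inverse of the negative second derivative of the log posterior there. It does not use this repository's `LaplaceApproximation` class: the helper `laplace_identity` and the counts (70 heads in 100 tosses) are hypothetical, and the exact Beta log density stands in for prior times likelihood.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta as beta_dist

# Hypothetical counts for illustration: 70 heads in 100 tosses with a Beta(3, 5) prior
ALPHA, BETA = 3.0 + 70.0, 5.0 + 30.0


def log_post(theta):
    # Log posterior of theta; the exact Beta density stands in for prior x likelihood
    return beta_dist.logpdf(theta, ALPHA, BETA)


def laplace_identity(log_density, init, n_steps=20):
    """Fit N(mode, scale^2) to a 1-D log density directly in theta-space."""
    grad = jax.grad(log_density)
    hess = jax.grad(grad)  # second derivative of the scalar log density
    theta = init
    for _ in range(n_steps):
        # Newton step towards the mode
        theta = theta - grad(theta) / hess(theta)
    # Variance = inverse of the negative second derivative at the mode
    return theta, 1.0 / jnp.sqrt(-hess(theta))


mode, scale = laplace_identity(log_post, jnp.array(0.6))
print(float(mode), float(scale))
```

Newton's method suffices here because this log posterior is concave on (0, 1). The results match the closed forms $(\alpha-1)/(\alpha+\beta-2)$ for the mode and $\sqrt{(\alpha-1)(\beta-1)/(\alpha+\beta-2)^3}$ for the scale. Because the fitted Gaussian is unbounded, it places some mass outside (0, 1), which the constrained version below avoids.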
Autodiff Laplace Approximation (sigmoid bijector)
```python
# Sigmoid bijector: the Gaussian is fitted in the unconstrained space and mapped
# back to (0, 1), so the approximation respects the support of theta
la_cov = LaplaceApproximation(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.Sigmoid(),
    likelihood=likelihood_fn,
)
true_posterior = tfd.Beta(alpha, beta)
fig_cov = la_cov.plot_approx_posterior(true_posterior=true_posterior)
plt.savefig("plots/la_cov_coin_toss.png")  # save before opening a new, empty figure
plt.figure()
fig = la_cov.plot_log_approx_posterior(true_posterior=true_posterior)
plt.savefig("plots/log_la_cov_coin_toss.png")
```
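The sigmoid bijector moves the approximation to the unconstrained variable $z = \mathrm{logit}(\theta)$. The following minimal sketch of that idea assumes the Gaussian is fitted to the log density of $z$, including the log-Jacobian of the sigmoid, and then pushed back through the sigmoid. The names `laplace_newton`, `log_density_z` and `approx_density`, and the counts (70 heads in 100 tosses), are hypothetical and are not this repository's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta as beta_dist, norm

# Hypothetical counts for illustration: 70 heads in 100 tosses with a Beta(3, 5) prior
ALPHA, BETA = 73.0, 35.0


def log_density_z(z):
    # Change of variables theta = sigmoid(z):
    # log p(z) = log p(theta) + log|d theta / d z|, with d theta / d z = sigmoid(z) * sigmoid(-z)
    theta = jax.nn.sigmoid(z)
    log_jac = jax.nn.log_sigmoid(z) + jax.nn.log_sigmoid(-z)
    return beta_dist.logpdf(theta, ALPHA, BETA) + log_jac


def laplace_newton(log_density, init, n_steps=20):
    """Fit N(mode, scale^2) to a concave 1-D log density with Newton's method."""
    grad = jax.grad(log_density)
    hess = jax.grad(grad)
    x = init
    for _ in range(n_steps):
        x = x - grad(x) / hess(x)
    return x, 1.0 / jnp.sqrt(-hess(x))


# Gaussian approximation in the unconstrained space
z_mode, z_scale = laplace_newton(log_density_z, jnp.array(0.0))


def approx_density(theta):
    # Push N(z_mode, z_scale^2) through the sigmoid:
    # q(theta) = N(logit(theta); z_mode, z_scale) / (theta * (1 - theta))
    z = jnp.log(theta) - jnp.log1p(-theta)
    return norm.pdf(z, z_mode, z_scale) / (theta * (1.0 - theta))


print(float(jax.nn.sigmoid(z_mode)), float(z_scale))
```

Because the Gaussian lives on $z$, the resulting density on $\theta$ is zero outside (0, 1) and can be skewed, unlike the identity-bijector fit. Note that the mode in $z$ maps to $\theta = \alpha/(\alpha+\beta)$ rather than to the Beta mode, because the Jacobian term shifts it.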
Besides the Laplace approximation library, the repository includes two notebooks that demonstrate the diagonal and low-rank Laplace approximations.
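The notebooks are not reproduced here. The difference between the variants can still be sketched on a precision matrix $H$, the negative Hessian of the log posterior at the mode. The diagonal variant keeps only $\mathrm{diag}(H)$. For the low-rank variant, this sketch assumes one simple form: it keeps the top-$k$ eigen-directions of $H$ and assigns a fixed `damping` precision to the remaining directions. This may differ from the notebooks' construction. `laplace_covariances` and the random $H$ are hypothetical.

```python
import jax
import jax.numpy as jnp


def laplace_covariances(H, k, damping=1.0):
    """Covariance approximations from a precision matrix H (negative Hessian at the mode)."""
    # Full Laplace: invert the whole precision matrix
    full_cov = jnp.linalg.inv(H)
    # Diagonal Laplace: keep only the diagonal curvature
    diag_cov = jnp.diag(1.0 / jnp.diag(H))
    # Low-rank Laplace (one simple variant, assumed here): keep the top-k
    # eigen-directions of H and give the remaining directions precision `damping`
    eigvals, eigvecs = jnp.linalg.eigh(H)  # eigenvalues in ascending order
    U, lam = eigvecs[:, -k:], eigvals[-k:]
    rest = jnp.eye(H.shape[0]) - U @ U.T  # projector onto the discarded directions
    low_rank_cov = (U / lam) @ U.T + rest / damping
    return full_cov, diag_cov, low_rank_cov


# Hypothetical 5-dimensional precision matrix, symmetric positive definite
A = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
H = A @ A.T + jnp.eye(5)
full_cov, diag_cov, low_rank_cov = laplace_covariances(H, k=2)
print(jnp.diag(full_cov), jnp.diag(diag_cov), jnp.diag(low_rank_cov))
```

With `k` equal to the dimension, the low-rank covariance recovers the full inverse. Smaller `k` stores only `k` vectors and eigenvalues. A practical implementation would estimate these with an iterative eigensolver instead of forming and decomposing $H$ as this sketch does.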
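ADVI
```python
import jax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

# Same coin-toss dataset, prior, and likelihood as in the Laplace examples
data_dist = tfd.Bernoulli(probs=0.7)
data = data_dist.sample(sample_shape=(100,), seed=jax.random.PRNGKey(3))
prior_theta = [3.0, 5.0]

def likelihood_fn(theta, data):
    return tfd.Bernoulli(probs=theta).log_prob(data).sum()

# NormalCDF bijector maps the real line to (0, 1): the variational Gaussian lives
# in unconstrained space while theta stays a valid probability
# ADVI is provided by this repository
advi = ADVI(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.NormalCDF(),
    likelihood=likelihood_fn,
)
appx_post = advi.approx_posterior(data)
```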
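To show the kind of objective ADVI optimises, here is a minimal sketch in plain JAX. It assumes a Gaussian $q(z)$ on the unconstrained variable, $\theta = \Phi(z)$ as with the `NormalCDF` bijector, and an ELBO maximised with reparameterised gradient steps. The functions, learning rate, and sample size are illustrative assumptions, as is the Beta(73, 35) target (hypothetical counts: 70 heads in 100 tosses). None of this is the repository's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical target: Beta(73, 35) posterior (70 heads in 100 tosses, Beta(3, 5) prior)
ALPHA, BETA = 73.0, 35.0
N_SAMPLES, LEARNING_RATE = 256, 0.01


def log_density_z(z):
    # Unnormalised log density of z with theta = Phi(z), as with a NormalCDF bijector:
    # (alpha - 1) log Phi(z) + (beta - 1) log(1 - Phi(z)) + log phi(z), where the
    # last term is the log-Jacobian, since d theta / d z = phi(z)
    return ((ALPHA - 1.0) * norm.logcdf(z)
            + (BETA - 1.0) * norm.logcdf(-z)
            + norm.logpdf(z))


def neg_elbo(params, eps):
    mu, log_sigma = params[0], params[1]
    z = mu + jnp.exp(log_sigma) * eps  # reparameterisation trick
    entropy = log_sigma + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)  # entropy of N(mu, sigma^2)
    return -(jnp.mean(log_density_z(z)) + entropy)


@jax.jit
def step(params, key):
    # One stochastic gradient step on the negative ELBO
    eps = jax.random.normal(key, (N_SAMPLES,))
    return params - LEARNING_RATE * jax.grad(neg_elbo)(params, eps)


def fit_advi(n_steps=2000, seed=0):
    params = jnp.array([0.0, 0.0])  # [mu, log sigma] of q(z) = N(mu, sigma^2)
    key = jax.random.PRNGKey(seed)
    for _ in range(n_steps):
        key, subkey = jax.random.split(key)
        params = step(params, subkey)
    return params


params = fit_advi()
z = params[0] + jnp.exp(params[1]) * jax.random.normal(jax.random.PRNGKey(1), (10_000,))
theta_samples = norm.cdf(z)  # map samples back to (0, 1)
print(float(theta_samples.mean()), float(theta_samples.std()))
```

Plain stochastic gradient steps keep the sketch free of extra dependencies. With a fixed learning rate, though, the estimates keep jittering slightly around the optimum; an adaptive or decaying step size would reduce this.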