Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Nested sampling implementation #755

Draft
wants to merge 60 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
9f84393
First draft
williamjameshandley Jul 29, 2024
cc804c2
Added testing script
williamjameshandley Jul 29, 2024
70cb149
basic particle scan
yallup Jul 30, 2024
f8b54ec
refine carry
yallup Jul 30, 2024
34b594b
sketch of history
yallup Jul 30, 2024
2f20304
cleaner scan
yallup Jul 30, 2024
639b84a
Merge remote-tracking branch 'origin/ns_history' into elliptical
yallup Aug 11, 2024
ba67732
preliminary loop
yallup Aug 11, 2024
c651b09
reject from prior
yallup Aug 12, 2024
69e7770
correct contour
yallup Aug 12, 2024
80ade3e
check in progress
yallup Sep 9, 2024
d4d211e
working loop
yallup Sep 11, 2024
066f04b
cleanup
yallup Sep 12, 2024
1aec36e
example
yallup Sep 12, 2024
ab09868
document
yallup Sep 18, 2024
8ab00fc
Merge pull request #4 from handley-lab/elliptical
yallup Sep 18, 2024
bb13cd5
add basic slice code
yallup Sep 18, 2024
91b2b8a
Merge remote-tracking branch 'origin/nested_sampling' into slice
yallup Sep 18, 2024
f4cb098
examples
yallup Sep 18, 2024
de984e3
Merge pull request #7 from handley-lab/slice
yallup Sep 18, 2024
7b3aa02
add upper limit iter
yallup Sep 27, 2024
106781b
cleanup example
yallup Sep 30, 2024
012514b
Merge branch 'nested_sampling' of github.com:handley-lab/blackjax int…
yallup Sep 30, 2024
a6e318b
slice sampling nested sampling
yallup Oct 15, 2024
7aad4e2
vertical slicing draft
williamjameshandley Oct 15, 2024
ba01a2e
Updated with slice sampling loop
williamjameshandley Oct 16, 2024
fbea2ad
Fixed prior sign issue and other shape problems
williamjameshandley Oct 16, 2024
9dc4a7f
Tidying up
williamjameshandley Oct 16, 2024
f67e7a6
Working(ish)
williamjameshandley Oct 16, 2024
a1e6b4b
selection from live
yallup Oct 17, 2024
89fe002
remove duplicated normalization
yallup Oct 17, 2024
a7f59e3
attempt convergence criteria
yallup Oct 17, 2024
c91c233
remove double choice
yallup Oct 17, 2024
11e0638
Minor corrections to evidence accumulation
williamjameshandley Oct 18, 2024
982cc1e
Corrected direction choice
williamjameshandley Oct 18, 2024
78e8eb9
better accumulation
yallup Oct 18, 2024
0e2808a
better accuulation and diagnostics
yallup Oct 21, 2024
c713df1
remove old inner kernel code
yallup Oct 21, 2024
86cf581
remove spurious normalizaiton factor guess
yallup Oct 21, 2024
de62751
better accumulation
yallup Oct 22, 2024
5b33691
tweak stepping in test
yallup Oct 23, 2024
f6213a0
fix rng key passing and stop overstepping on stepping out
yallup Oct 24, 2024
94a38ba
remove unnecessary checks on steps
yallup Oct 29, 2024
a0b5164
Multiple deletion calculation (not currently working)
williamjameshandley Nov 1, 2024
03c3019
cleanup and refactor
yallup Nov 4, 2024
652fb05
Merge branch 'nested_sampling' of github.com:handley-lab/blackjax int…
yallup Nov 5, 2024
7fe4666
add example
yallup Nov 6, 2024
7bf2938
correct contour indexing
yallup Nov 6, 2024
88fb151
loosen prior bound
yallup Nov 7, 2024
5daa6c5
minor suggestions from after-lunch meeting 2024-11-07
williamjameshandley Nov 7, 2024
84b6849
cleanup and docstring
yallup Nov 11, 2024
c955f10
precommit on
yallup Nov 11, 2024
2021dbb
Merge branch 'blackjax-devs:main' into nested_sampling
yallup Nov 11, 2024
4251639
Updated evidence accumulation
williamjameshandley Nov 11, 2024
972894a
actually add files
yallup Nov 11, 2024
e69f0bb
Merge branch 'nested_sampling' of github.com:handley-lab/blackjax int…
yallup Nov 11, 2024
685e781
Merge branch 'nested_sampling' of github.com:handley-lab/blackjax int…
williamjameshandley Nov 11, 2024
70889e3
comments added on the example and fix commit hooks
yallup Nov 11, 2024
00d8a68
remove unused old slice code
yallup Nov 11, 2024
438afbd
include ns in top level api import
yallup Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import partial_posteriors_path as _partial_posteriors_smc
from .smc import tempered
from .ns import adaptive
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
from .vi import schrodinger_follmer as _schrodinger_follmer
Expand Down Expand Up @@ -122,6 +123,8 @@ def generate_top_level_api_from(module):
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc)

adaptive_ns = generate_top_level_api_from(adaptive)

smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc]
"Step_fn returning state has a .particles attribute"

Expand Down
5 changes: 5 additions & 0 deletions blackjax/ns/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import adaptive

__all__ = [
"adaptive",
]
195 changes: 195 additions & 0 deletions blackjax/ns/adaptive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from functools import partial
from typing import Callable, Dict

import jax.numpy as jnp

from blackjax import SamplingAlgorithm
from blackjax.ns.base import NSInfo, NSState
from blackjax.ns.base import build_kernel as base_ns
from blackjax.ns.base import delete_fn
from blackjax.ns.base import init as init_base
from blackjax.ns.vectorized_slice import build_kernel as slice_kernel
from blackjax.ns.vectorized_slice import init as slice_init
from blackjax.smc.inner_kernel_tuning import StateWithParameterOverride
from blackjax.smc.tuning.from_particles import particles_covariance_matrix
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["init", "as_top_level_api", "build_kernel", "nss"]


def init(position, loglikelihood_fn, parameter_update_function):
state = init_base(position, loglikelihood_fn)
initial_parameter_value = parameter_update_function(
state, NSInfo(state, state, state, None)
)
return StateWithParameterOverride(state, initial_parameter_value)


def build_kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
delete_fn: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameter_update_fn: Callable,
num_mcmc_steps: int,
) -> Callable:
r"""Build an adaptive Nested Sampling kernel. Tunes the inner kernel parameters
at each iteration.

Parameters
----------
logprior_fn : Callable
A function that computes the log prior probability.
loglikelihood_fn : Callable
A function that computes the log likelihood.
delete_fn : Callable
Function that takes an array of log likelihoods and marks particles for deletion and updates.
mcmc_step_fn:
The initialisation of the transition kernel, should take as parameters.
kernel = mcmc_step_fn(logprior, loglikelihood, logL0 (likelihood threshold), **mcmc_parameter_update_fn())
mcmc_init_fn
A callable that initializes the inner kernel
mcmc_parameter_update_fn : Callable[[NSState, NSInfo], Dict[str, ArrayTree]]
Function that updates the parameters of the inner kernel.
num_mcmc_steps: int
Number of MCMC steps to perform. Recommended is 5 times the dimension of the parameter space.

Returns
-------
Callable
A function that takes a rng_key and a NSState that contains the current state
of the chain and returns a new state of the chain along with
information about the transition.
"""

def kernel(
rng_key: PRNGKey,
state: StateWithParameterOverride,
) -> tuple[StateWithParameterOverride, NSInfo]:
step_fn = base_ns(
logprior_fn,
loglikelihood_fn,
delete_fn,
mcmc_step_fn,
mcmc_init_fn,
num_mcmc_steps,
)
new_state, info = step_fn(
rng_key, state.sampler_state, state.parameter_override
)
new_parameter_override = mcmc_parameter_update_fn(new_state, info)
return (
StateWithParameterOverride(new_state, new_parameter_override),
info,
)

return kernel


def as_top_level_api(
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameter_update_fn: Callable[[NSState, NSInfo], Dict[str, ArrayTree]],
num_mcmc_steps: int,
n_delete: int = 1,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Adaptive Nested Sampling kernel.

Parameters
----------
logprior_fn : Callable
A function that computes the log prior probability.
loglikelihood_fn : Callable
A function that computes the log likelihood.
mcmc_step_fn:
The initialisation of the transition kernel, should take as parameters.
kernel = mcmc_step_fn(logprior, loglikelihood, logL0 (likelihood threshold), **mcmc_parameter_update_fn())
mcmc_init_fn
A callable that initializes the inner kernel
mcmc_parameter_update_fn : Callable[[NSState, NSInfo], Dict[str, ArrayTree]]
A function that updates the parameters given the current state and info.
num_mcmc_steps: int
Number of MCMC steps to perform. Recommended is 5 times the dimension of the parameter space.
n_delete : int, optional
Number of particles to delete in each iteration. Default is 1.

Returns
-------
SamplingAlgorithm
A sampling algorithm object.
"""
delete_func = partial(delete_fn, n_delete=n_delete)

kernel = build_kernel(
logprior_fn,
loglikelihood_fn,
delete_func,
mcmc_step_fn,
mcmc_init_fn,
mcmc_parameter_update_fn,
num_mcmc_steps,
)

def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return init(position, loglikelihood_fn, mcmc_parameter_update_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state)

return SamplingAlgorithm(init_fn, step_fn)


def nss(
logprior_fn: Callable,
loglikelihood_fn: Callable,
num_mcmc_steps: int,
n_delete: int = 1,
) -> SamplingAlgorithm:
"""Implements the a baseline Nested Slice Sampling kernel.

Parameters
----------
logprior_fn: Callable
A function that computes the log prior probability.
loglikelihood_fn: Callable
A function that computes the log likelihood.
num_mcmc_steps: int
Number of MCMC steps to perform. Recommended is 5 times the dimension of the parameter space.
n_delete: int, optional
Number of particles to delete in each iteration. Default is 1.

Returns
-------
SamplingAlgorithm
A sampling algorithm object.
"""
delete_func = partial(delete_fn, n_delete=n_delete)
mcmc_step_fn = slice_kernel
mcmc_init_fn = slice_init

def parameter_update_fn(state, _):
cov = jnp.atleast_2d(particles_covariance_matrix(state.particles))
return {"cov": cov}

kernel = build_kernel(
logprior_fn,
loglikelihood_fn,
delete_func,
mcmc_step_fn,
mcmc_init_fn,
parameter_update_fn,
num_mcmc_steps,
)

def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return init(position, loglikelihood_fn, parameter_update_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state)

return SamplingAlgorithm(init_fn, step_fn)
Loading
Loading