
Releases: blackjax-devs/blackjax

BlackJAX v0.6.0

17 May 12:42

What's Changed

New Contributors

Full Changelog: 0.5.0...0.6.0

BlackJAX v0.5.0

29 Apr 13:50

What's Changed

New Contributors

BlackJAX 0.4.0

29 Mar 07:14

Breaking changes

⚠️ This release changes the high-level API as well as import paths ⚠️

This release simplifies the high-level API for samplers. For instance, to initialize and use an HMC kernel:

import blackjax

hmc = blackjax.hmc(logprob_fn, step_size, inverse_mass_matrix, num_integration_steps)
state = hmc.init(position)
new_state, info = hmc.step(rng_key, state)
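The `init`/`step` pair follows a closure pattern: the log-probability and the parameters are captured once when the sampler is built, so later calls only pass randomness and state. The pure-Python sketch below illustrates that pattern with a toy random-walk Metropolis sampler. It is not BlackJAX's implementation: `RWState`, `SamplingAlgorithm` and `random_walk` are hypothetical names, and Python's `random.Random` stands in for a JAX PRNG key.

```python
import math
import random
from typing import Callable, NamedTuple


class RWState(NamedTuple):
    """Hypothetical sampler state: current position and its log-density."""
    position: float
    logprob: float


class SamplingAlgorithm(NamedTuple):
    """Hypothetical container for the init/step pair."""
    init: Callable
    step: Callable


def random_walk(logprob_fn, scale):
    """Toy random-walk Metropolis sampler exposing init and step.

    logprob_fn and scale are captured once, here, so step only needs
    a source of randomness and the current state.
    """

    def init(position):
        return RWState(position, logprob_fn(position))

    def step(rng, state):
        proposal = state.position + scale * rng.gauss(0.0, 1.0)
        proposal_logprob = logprob_fn(proposal)
        # Metropolis acceptance probability, capped at 1
        accept_prob = math.exp(min(0.0, proposal_logprob - state.logprob))
        if rng.random() < accept_prob:
            return RWState(proposal, proposal_logprob), {"accept_prob": accept_prob}
        return state, {"accept_prob": accept_prob}

    return SamplingAlgorithm(init, step)


# Usage: build once, then only pass randomness and state around
sampler = random_walk(lambda x: -0.5 * x * x, scale=1.0)
state = sampler.init(0.0)
rng = random.Random(0)  # stands in for a JAX PRNG key
samples = []
for _ in range(1000):
    state, info = sampler.step(rng, state)
    samples.append(state.position)
```

Returning a named tuple keeps the sampler a plain, immutable value that can be passed around freely, which fits JAX's functional style.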

hmc is now a namedtuple with an init and a step function; unlike in the previous version, logprob_fn only needs to be passed once, when the sampler is built. The internals were simplified considerably, and the module hierarchy is now flatter. For instance, to use the base HMC kernel directly:

import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.hmc as hmc

kernel = hmc.kernel(integrators.mclachlan)
state = hmc.init(position, logprob_fn)
state, info = kernel(rng_key, state, logprob_fn, step_size, inverse_mass_matrix, num_integration_steps)

The API of the base kernels has also been changed to be more flexible.
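The split between a kernel factory parameterized by an integrator and a kernel that receives the log-probability and its parameters at every call can be sketched in pure Python. This is a toy 1-D illustration under stated assumptions, not BlackJAX's code: `HMCState`, `leapfrog`, `init`, `build_kernel` and `logprob_and_grad_fn` are hypothetical names, the gradient is supplied by hand instead of being computed with JAX, and the simpler leapfrog (velocity Verlet) scheme stands in for the McLachlan integrator used above.

```python
import math
import random
from typing import NamedTuple


class HMCState(NamedTuple):
    position: float
    logprob: float
    logprob_grad: float


def leapfrog(logprob_and_grad_fn, step_size, inverse_mass):
    """Velocity Verlet (leapfrog) integrator for a 1-D target."""

    def one_step(position, momentum, logprob_grad):
        momentum = momentum + 0.5 * step_size * logprob_grad
        position = position + step_size * inverse_mass * momentum
        logprob, logprob_grad = logprob_and_grad_fn(position)
        momentum = momentum + 0.5 * step_size * logprob_grad
        return position, momentum, logprob, logprob_grad

    return one_step


def init(position, logprob_and_grad_fn):
    logprob, logprob_grad = logprob_and_grad_fn(position)
    return HMCState(position, logprob, logprob_grad)


def build_kernel(integrator):
    """Fix the integrator once; everything else is passed at each call."""

    def kernel(rng, state, logprob_and_grad_fn, step_size,
               inverse_mass, num_integration_steps):
        one_step = integrator(logprob_and_grad_fn, step_size, inverse_mass)
        # Momentum ~ N(0, M) with M = 1 / inverse_mass
        momentum0 = rng.gauss(0.0, 1.0 / math.sqrt(inverse_mass))
        position, momentum = state.position, momentum0
        logprob, logprob_grad = state.logprob, state.logprob_grad
        for _ in range(num_integration_steps):
            position, momentum, logprob, logprob_grad = one_step(
                position, momentum, logprob_grad)
        # Hamiltonian = potential energy (-logprob) + kinetic energy
        h0 = -state.logprob + 0.5 * inverse_mass * momentum0 ** 2
        h1 = -logprob + 0.5 * inverse_mass * momentum ** 2
        accept_prob = math.exp(min(0.0, h0 - h1))
        if rng.random() < accept_prob:
            return HMCState(position, logprob, logprob_grad), accept_prob
        return state, accept_prob

    return kernel


def std_normal(x):
    """Log-density of N(0, 1) up to a constant, and its gradient."""
    return -0.5 * x * x, -x


kernel = build_kernel(leapfrog)
state = init(1.0, std_normal)
rng = random.Random(0)
samples = []
for _ in range(2000):
    state, accept_prob = kernel(rng, state, std_normal, 0.2, 1.0, 10)
    samples.append(state.position)
```

Because the step size, mass and number of steps are arguments rather than captured values, an adaptation routine can change them between calls without rebuilding the kernel.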

Performance improvements

Thanks to the work of @zaxtax, @junpenglao and @rlouf, the performance of the NUTS sampler (especially the warmup) has been greatly improved and is now at least on par with NumPyro.

What's Changed

No new algorithm in this release, but important work was done on the API, the internals and the examples.

New Contributors

Full Changelog: 0.3.0...0.4.0

BlackJAX 0.3.0

22 Nov 16:19

What changed

Breaking changes

To build an HMC or NUTS kernel in 0.2.1 and previous versions, one needed to provide a potential_fn function:

kernel = nuts.kernel(potential_fn, step_size, inverse_mass_matrix)

Instead we now ask the users to provide the more commonly used log-probability function:

kernel = nuts.kernel(logprob_fn, step_size, inverse_mass_matrix)

where logprob_fn = lambda x: -potential_fn(*x)
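The conversion is a plain negation: the log-probability is minus the potential energy. The sketch below assumes a potential that takes the position as a single argument (the `*x` in the expression above applies when the potential takes its parameters as separate positional arguments). `to_logprob_fn` and `exact_normal_logpdf` are hypothetical helpers used only for illustration.

```python
import math


def potential_fn(x):
    """0.2.1-style input: potential energy of N(0, 1), U(x) = x**2 / 2."""
    return 0.5 * x ** 2


def to_logprob_fn(potential_fn):
    """Hypothetical helper: negate a potential to get a log-probability."""
    return lambda x: -potential_fn(x)


logprob_fn = to_logprob_fn(potential_fn)


def exact_normal_logpdf(x):
    # Normalized N(0, 1) log-density, for comparison only
    return -0.5 * x ** 2 - 0.5 * math.log(2 * math.pi)
```

logprob_fn differs from the normalized density only by the constant log(sqrt(2 * pi)), which MCMC acceptance ratios cancel out, so an unnormalized log-probability is enough.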

New features

Bugs

  • Missing key splitting in trajectory integration (@wiep #53)
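In JAX, reusing a PRNG key reproduces the same random numbers, so each step of a trajectory must consume a freshly split key. The sketch below shows the general pattern with `jax.random.split`; it is an illustration of the bug class, not the patched BlackJAX code.

```python
import jax

key = jax.random.PRNGKey(0)

# Bug pattern: reusing the same key at every step yields the same draw.
reused = [jax.random.normal(key) for _ in range(3)]

# Fix: split the key before each step and consume a fresh subkey.
fresh = []
for _ in range(3):
    key, subkey = jax.random.split(key)
    fresh.append(jax.random.normal(subkey))
```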

BlackJAX 0.2.1

08 Jun 09:04

What changed

  • Momentum and position were passed to the kinetic energy in the wrong order, which led to biased sampling, as reported in #46. We corrected this behavior and added a new test.
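The sketch below shows why such a swap goes unnoticed at runtime: both arguments are numbers, so the call succeeds but computes the wrong Hamiltonian. The `(momentum, position)` signature of the hypothetical `kinetic_energy` is an assumption made for illustration; a Euclidean kinetic energy does not actually depend on the position.

```python
def kinetic_energy(momentum, position, inverse_mass=1.0):
    """Hypothetical Euclidean kinetic energy K(p) = inverse_mass * p**2 / 2.

    position is accepted only to mirror a (momentum, position) signature;
    a Euclidean kinetic energy does not depend on it.
    """
    del position
    return 0.5 * inverse_mass * momentum ** 2


def hamiltonian(logprob_fn, position, momentum):
    # Correct argument order: momentum first, then position
    return -logprob_fn(position) + kinetic_energy(momentum, position)


def logprob_fn(x):
    return -0.5 * x * x


position, momentum = 3.0, 1.0
correct = hamiltonian(logprob_fn, position, momentum)
# Swapping the arguments silently computes the "kinetic energy" of the position
swapped = -logprob_fn(position) + kinetic_energy(position, momentum)
```

A regression test that checks the energy against a hand-computed value (here 4.5 + 0.5) catches this kind of error, where a test that only checks shapes would not.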

BlackJAX 0.2

03 Jun 08:07

What changed

  • The Stan adaptation scheme, including dual averaging, computing covariance with Welford's algorithm and the schedule (@rlouf)
  • Recursive implementation of NUTS (@junpenglao)
  • Many bug fixes on NUTS (@junpenglao)
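Welford's algorithm updates a running mean and a running sum of outer products one sample at a time, which avoids storing the warmup samples and is numerically more stable than accumulating raw sums of squares. The pure-Python sketch below computes a full covariance this way; `WelfordCovariance` is a hypothetical name and this is not BlackJAX's implementation.

```python
class WelfordCovariance:
    """Online estimate of the mean and covariance of d-dimensional samples."""

    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim
        # Running sum of (x - old_mean)(x - new_mean)^T
        self.m2 = [[0.0] * dim for _ in range(dim)]

    def update(self, x):
        self.n += 1
        delta = [xi - mi for xi, mi in zip(x, self.mean)]
        self.mean = [mi + di / self.n for mi, di in zip(self.mean, delta)]
        delta2 = [xi - mi for xi, mi in zip(x, self.mean)]
        for i in range(len(x)):
            for j in range(len(x)):
                self.m2[i][j] += delta[i] * delta2[j]

    def covariance(self):
        # Unbiased sample covariance (divide by n - 1)
        return [[v / (self.n - 1) for v in row] for row in self.m2]


acc = WelfordCovariance(dim=2)
for sample in [[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]]:
    acc.update(sample)
cov = acc.covariance()
```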

BlackJAX 0.1

23 May 19:47
8209172

New features

  • hmc kernel
  • nuts kernel
  • Notebook with examples of how to sample one or multiple chains with HMC and NUTS