Efficient sampling from the truncated multivariate Normal distribution.
A Python reimplementation of the minimax tilting algorithm of Botev (2016) for simulation and iid sampling from the truncated multivariate Normal distribution. The author's original MATLAB implementation can be found here.
The main features of this algorithm are:
- simulation from the truncated multivariate Normal distribution using an accept-reject algorithm based on minimax exponential tilting
- (quasi-)Monte Carlo estimation of the distribution function using separation of variables together with exponential tilting, with provable performance and a theoretical upper bound on the error
- Cholesky decomposition using the variable reordering algorithm of Gibson, Glasbey and Elston (1994)
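The last two features can be illustrated with a minimal sketch written independently of this package. The helper names `gge_cholesky` and `sov_probability` are hypothetical, X is assumed to have zero mean (for a mean mu, shift the bounds to lb - mu and ub - mu), and the estimator is the plain separation-of-variables Monte Carlo estimator *without* the exponential tilting that this package adds. The reordering greedily places the most tightly constrained remaining variable next, plugging in the truncated-Normal mean of each variable already placed.

```python
import numpy as np
from scipy.stats import norm


def gge_cholesky(cov, lb, ub):
    """Cholesky factor with greedy (Gibson-Glasbey-Elston style) variable reordering.

    Returns (L, lb_perm, ub_perm, perm) with L @ L.T == cov[perm][:, perm].
    """
    cov = np.array(cov, dtype=float)
    lb = np.array(lb, dtype=float)
    ub = np.array(ub, dtype=float)
    d = len(lb)
    perm = np.arange(d)
    L = np.zeros((d, d))
    z = np.zeros(d)  # truncated-Normal means of the already placed standardized variables
    for k in range(d):
        # conditional std devs and standardized bounds of the remaining variables
        resid = L[k:, :k] @ z[:k]
        s = np.sqrt(np.diag(cov)[k:] - np.sum(L[k:, :k] ** 2, axis=1))
        prob = norm.cdf((ub[k:] - resid) / s) - norm.cdf((lb[k:] - resid) / s)
        j = k + int(np.argmin(prob))  # most constrained variable goes next
        # swap variables k and j in every structure
        for arr in (lb, ub, perm):
            arr[[k, j]] = arr[[j, k]]
        cov[[k, j], :] = cov[[j, k], :]
        cov[:, [k, j]] = cov[:, [j, k]]
        L[[k, j], :] = L[[j, k], :]
        # fill column k of the Cholesky factor
        L[k, k] = np.sqrt(cov[k, k] - L[k, :k] @ L[k, :k])
        L[k + 1:, k] = (cov[k + 1:, k] - L[k + 1:, :k] @ L[k, :k]) / L[k, k]
        # mean of the chosen standardized variable given its truncation interval
        a = (lb[k] - L[k, :k] @ z[:k]) / L[k, k]
        b = (ub[k] - L[k, :k] @ z[:k]) / L[k, k]
        z[k] = (norm.pdf(a) - norm.pdf(b)) / (norm.cdf(b) - norm.cdf(a))
    return L, lb, ub, perm


def sov_probability(cov, lb, ub, n=100_000, rng=None):
    """Plain Monte Carlo separation-of-variables estimate of P(lb <= X <= ub), X ~ N(0, cov)."""
    rng = np.random.default_rng(rng)
    L, lb, ub, _ = gge_cholesky(cov, lb, ub)
    d = len(lb)
    z = np.zeros((d, n))
    weights = np.ones(n)
    for k in range(d):
        shift = L[k, :k] @ z[:k]  # contribution of earlier variables, shape (n,)
        a = norm.cdf((lb[k] - shift) / L[k, k])
        b = norm.cdf((ub[k] - shift) / L[k, k])
        weights *= b - a  # conditional probability of the k-th constraint
        u = rng.uniform(size=n)
        z[k] = norm.ppf(a + u * (b - a))  # inverse-CDF draw inside the interval
    return weights.mean()
```

For a bivariate standard Normal with correlation 0.5 and both bounds at 0, the estimate should be close to the closed-form value 1/4 + arcsin(0.5)/(2π) = 1/3. Without tilting, the variance of this estimator can grow quickly in the tails, which is what the exponential tilting is meant to control.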
Feel free to use this code, but don't forget to cite Botev (2016)!
Also consider checking out our paper "On Controller Tuning with Time-Varying Bayesian Optimization" (2022) and the corresponding repo, where we use this sampling method to tune controllers efficiently.
The Python implementation provided here uses a modification of the Powell hybrid method implemented in SciPy (method='hybr') to find the optimal tilting parameter, instead of the Trust-Region-Dogleg algorithm used in the MATLAB implementation.
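The root-finding call itself can be illustrated on a toy problem. This sketch uses only the public `scipy.optimize.root` API; the two-equation system is an arbitrary stand-in adapted from the SciPy documentation, not the actual tilting equations solved by this package, and the names `residual` and `jacobian` are hypothetical.

```python
import numpy as np
from scipy.optimize import root


def residual(x):
    # toy nonlinear system F(x) = 0, standing in for the tilting equations
    return np.array([
        x[0] + 0.5 * (x[0] - x[1]) ** 3 - 1.0,
        0.5 * (x[1] - x[0]) ** 3 + x[1],
    ])


def jacobian(x):
    # analytic Jacobian dF/dx; supplying it avoids finite-difference estimates
    c = 1.5 * (x[0] - x[1]) ** 2
    return np.array([
        [1.0 + c, -c],
        [-c, 1.0 + c],
    ])


sol = root(residual, x0=np.zeros(2), jac=jacobian, method="hybr")
print(sol.success, sol.x)
```

SciPy documents method='hybr' as a modification of the Powell hybrid method as implemented in MINPACK; `sol.x` holds the root and `sol.success` reports whether the solver converged.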
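Example:

```python
import numpy as np
from minimax_tilting_sampler import TruncatedMVN  # assumes minimax_tilting_sampler.py from this repo is importable

d = 10  # dimensions

# random mean and covariance
mu = np.random.rand(d)
cov = 0.5 - np.random.rand(d ** 2).reshape((d, d))
cov = np.triu(cov)
cov += cov.T - np.diag(cov.diagonal())  # symmetrize
cov = np.dot(cov, cov)  # the square of a symmetric matrix is positive semi-definite

# constraints: lower bound of -1 and no upper bound in every dimension
lb = np.zeros_like(mu) - 1
ub = np.ones_like(mu) * np.inf

# create truncated normal and sample from it
n_samples = 100000
tmvn = TruncatedMVN(mu, cov, lb, ub)
samples = tmvn.sample(n_samples)
```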
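For small, loosely constrained problems, naive rejection sampling is a cheap way to sanity-check samples. The sketch below uses a hypothetical helper `naive_rejection_sample` (not part of this package) that draws from the untruncated N(mu, cov) and keeps only draws inside [lb, ub]. Its acceptance rate equals P(lb ≤ X ≤ ub), which for independent components is the product of the one-dimensional probabilities and therefore shrinks geometrically with d; this is the regime where the minimax tilting sampler pays off.

```python
import numpy as np


def naive_rejection_sample(mu, cov, lb, ub, n_samples, rng=None, batch=10_000, max_batches=1_000):
    """Draw from N(mu, cov) truncated to [lb, ub] by plain rejection sampling.

    Returns (samples, acceptance_rate); samples has shape (n_samples, d).
    """
    rng = np.random.default_rng(rng)
    accepted, n_drawn, n_kept = [], 0, 0
    for _ in range(max_batches):
        x = rng.multivariate_normal(mu, cov, size=batch)
        keep = np.all((x >= lb) & (x <= ub), axis=1)  # inside the box in every dimension
        accepted.append(x[keep])
        n_drawn += batch
        n_kept += int(keep.sum())
        if n_kept >= n_samples:
            break
    else:
        # loop finished without collecting enough samples
        raise RuntimeError("acceptance rate too low for naive rejection sampling")
    samples = np.concatenate(accepted)[:n_samples]
    return samples, n_kept / n_drawn
```

For example, `naive_rejection_sample(np.zeros(3), np.eye(3), -np.ones(3), np.full(3, np.inf), 5000, rng=0)` should accept roughly Φ(1)³ ≈ 0.60 of the draws.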
Plotting the samples of the first dimension and comparing them with the non-truncated Normal distribution shows the effect of the truncation.
Disregard the difference in scaling: the Normal and truncated Normal are plotted on different y-axes.
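A plot of this kind can be produced with matplotlib. In this sketch the helper `plot_first_dimension` is hypothetical; it assumes `x0` is a 1-D array holding the first coordinate of every sample, and uses the fact that the first marginal of the untruncated distribution is N(mu[0], cov[0, 0]). `twinx()` adds the second y-axis, which is why the two curves are not on a common scale.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from scipy.stats import norm


def plot_first_dimension(x0, mean0, var0, bins=100):
    """Histogram of truncated samples x0 next to the untruncated N(mean0, var0) density."""
    fig, ax_hist = plt.subplots()
    ax_hist.hist(x0, bins=bins, color="tab:blue", alpha=0.6)
    ax_hist.set_xlabel("$x_1$")
    ax_hist.set_ylabel("count (truncated samples)")
    ax_pdf = ax_hist.twinx()  # second y-axis, hence the differing scales
    std0 = np.sqrt(var0)
    grid = np.linspace(mean0 - 4 * std0, mean0 + 4 * std0, 400)
    ax_pdf.plot(grid, norm.pdf(grid, loc=mean0, scale=std0), color="tab:orange")
    ax_pdf.set_ylabel("density (untruncated)")
    return fig
```

With the example above this could be called as `plot_first_dimension(samples[0], mu[0], cov[0, 0])`, assuming `samples` stores one row per dimension; check the shape returned by `tmvn.sample` before indexing.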
The implementation is based on the MATLAB implementation by the author. An R implementation by the author can be found here, or installed from CRAN via

```r
install.packages("TruncatedNormal")
```

or from GitHub via

```r
devtools::install_github("lbelzile/TruncatedNormal")
```