Commit e51d7b0

cc
ahnitz committed Aug 8, 2023
1 parent b843b33 commit e51d7b0
Showing 1 changed file with 102 additions and 81 deletions.
pycbc/inference/sampler/refine.py
@@ -19,8 +19,10 @@

def call_model(params):
    models._global_instance.update(**params)
-    return (models._global_instance.logposterior,
-            models._global_instance.loglikelihood)
+    return (
+        models._global_instance.logposterior,
+        models._global_instance.loglikelihood,
+    )


def resample_equal(samples, logwt, seed=0):
@@ -69,20 +71,27 @@ class RefineSampler(DummySampler):
    kde: kde
        The inital kde to use.
    """
-    name = 'refine'
-
-    def __init__(self, model, *args, nprocesses=1, use_mpi=False,
-                 num_samples=int(1e5),
-                 iterative_kde_samples=int(1e3),
-                 min_refinement_steps=5,
-                 max_refinement_steps=40,
-                 offbase_fraction=0.7,
-                 entropy=0.01,
-                 dlogz=0.01,
-                 kde=None,
-                 update_groups=None,
-                 max_kde_samples=int(5e4),
-                 **kwargs):
+
+    name = "refine"
+
+    def __init__(
+        self,
+        model,
+        *args,
+        nprocesses=1,
+        use_mpi=False,
+        num_samples=int(1e5),
+        iterative_kde_samples=int(1e3),
+        min_refinement_steps=5,
+        max_refinement_steps=40,
+        offbase_fraction=0.7,
+        entropy=0.01,
+        dlogz=0.01,
+        kde=None,
+        update_groups=None,
+        max_kde_samples=int(5e4),
+        **kwargs
+    ):
        super().__init__(model, *args)

        self.model = model
@@ -108,14 +117,14 @@ def __init__(self, model, *args, nprocesses=1, use_mpi=False,
        else:
            for gname in update_groups.split():
                gvalue = kwargs[gname]
-                if gvalue == 'all':
+                if gvalue == "all":
                    self.param_groups.append(self.vparam)
                else:
                    self.param_groups.append(gvalue.split())

    def draw_samples(self, size, update_params=None):
        """Draw new samples within the model priors"""
-        logging.info('getting from kde')
+        logging.info("getting from kde")

        params = {}
        ksamples = self.kde.resample(size=size)
@@ -127,24 +136,22 @@ def draw_samples(self, size, update_params=None):
                params[k] = ksamples[j, :]
                j += 1

-        logging.info('checking prior')
+        logging.info("checking prior")
        keep = self.model.prior_distribution.contains(params)
-        logging.info('done checking')
+        logging.info("done checking")
        r = numpy.array([params[k][keep] for k in self.vparam])
        return r

    @staticmethod
    def compare_kde(kde1, kde2, size=int(1e4)):
-        """ Calculate information difference between two kde distributions
-        """
+        """Calculate information difference between two kde distributions"""
        s = kde1.resample(size=size)
        return sentropy(kde1.pdf(s), kde2.pdf(s))

    def converged(self, step, kde_new, factor, logp):
-        """ Check that kde is converged by comparing to previous iteration
-        """
-        logging.info('checking convergence')
-        if not hasattr(self, 'old_logz'):
+        """Check that kde is converged by comparing to previous iteration"""
+        logging.info("checking convergence")
+        if not hasattr(self, "old_logz"):
            self.old_logz = numpy.inf

        entropy_diff = -1
@@ -168,40 +175,49 @@ def converged(self, step, kde_new, factor, logp):
        # check fraction that are significant deviation from peak
        frac_offbase = (logp < logp.max() - 5.0).sum() / len(logp)

-        logging.info('%s: dlogz_iter=%.4f,'
-                     'dlogz_half=%.4f, entropy=%.4f offbase fraction=%.4f',
-                     step, dlogz, dlogz2, entropy_diff, frac_offbase)
-        if (entropy_diff < self.entropy and step >= self.min_refinement_steps
+        logging.info(
+            "%s: dlogz_iter=%.4f,"
+            "dlogz_half=%.4f, entropy=%.4f offbase fraction=%.4f",
+            step,
+            dlogz,
+            dlogz2,
+            entropy_diff,
+            frac_offbase,
+        )
+        if (
+            entropy_diff < self.entropy
+            and step >= self.min_refinement_steps
            and max(abs(dlogz), abs(dlogz2)) < self.dlogz_target
-            and frac_offbase < self.offbase_fraction):
+            and frac_offbase < self.offbase_fraction
+        ):
            return True
        else:
            return False

    @classmethod
-    def from_config(cls, cp, model, output_file=None, nprocesses=1,
-                    use_mpi=False):
-        """This should initialize the sampler given a config file.
-        """
-        kwargs = {k: cp.get('sampler', k) for k in cp.options('sampler')}
+    def from_config(
+        cls, cp, model, output_file=None, nprocesses=1, use_mpi=False
+    ):
+        """This should initialize the sampler given a config file."""
+        kwargs = {k: cp.get("sampler", k) for k in cp.options("sampler")}
        obj = cls(model, nprocesses=nprocesses, use_mpi=use_mpi, **kwargs)
        obj.set_start_from_config(cp)
        setup_output(obj, output_file, check_nsamples=False, validate=False)
        return obj

    def set_start_from_config(self, cp):
-        """Sets the initial state of the sampler from config file
-        """
+        """Sets the initial state of the sampler from config file"""
        num_samples = self.iterative_kde_samples
-        if cp.has_option('sampler', 'start-file'):
-            start_file = cp.get('sampler', 'start-file')
+        if cp.has_option("sampler", "start-file"):
+            start_file = cp.get("sampler", "start-file")
            logging.info("Using file %s for initial positions", start_file)
-            f = loadfile(start_file, 'r')
-            fsamples = f.read_samples(f['samples'].keys())
+            f = loadfile(start_file, "r")
+            fsamples = f.read_samples(f["samples"].keys())
            num_samples = len(fsamples)

        init_prior = initial_dist_from_config(
-            cp, self.model.variable_params, self.model.static_params)
+            cp, self.model.variable_params, self.model.static_params
+        )
        if init_prior is not None:
            samples = init_prior.rvs(size=num_samples)
        else:
@@ -214,56 +230,57 @@ def set_start_from_config(self, cp):
                ksamples.append(fsamples[v])
            else:
                ksamples.append(samples[v])

        self.kde = gaussian_kde(numpy.array(ksamples))

    def run_samples(self, ksamples, update_params=None, iteration=False):
-        """ Calculate the likelihoods and weights for a set of samples
-        """
+        """Calculate the likelihoods and weights for a set of samples"""
        # Calculate likelihood for each sample
-        logging.info('calculating likelihoods...')
+        logging.info("calculating likelihoods...")
        args = []
        for i in range(len(ksamples[0])):
            param = {k: ksamples[j][i] for j, k in enumerate(self.vparam)}
            args.append(param)

        result = self.pool.map(call_model, args)
-        logging.info('..done with likelihood')
+        logging.info("..done with likelihood")

        logp = numpy.array([r[0] for r in result])
        logl = numpy.array([r[1] for r in result])

        if update_params is not None:
-            ksamples = numpy.array([ksamples[i, :]
-                                    for i, k in enumerate(self.vparam)
-                                    if k in update_params])
+            ksamples = numpy.array(
+                [
+                    ksamples[i, :]
+                    for i, k in enumerate(self.vparam)
+                    if k in update_params
+                ]
+            )

        # Weights for iteration
        if iteration:
            logw = logp - numpy.log(self.kde.pdf(ksamples))
            logw = logw - logsumexp(logw)

            # To avoid single samples dominating the weighting kde before
            # we will put a cap on the minimum and maximum logw
            sort = logw.argsort()
-            cap = logw[sort[-len(sort)//5]]
-            low = logw[sort[len(sort)//5]]
+            cap = logw[sort[-len(sort) // 5]]
+            low = logw[sort[len(sort) // 5]]
            logw[logw > cap] = cap
            logw[logw < low] = low
        else:
            # Weights for monte-carlo selection
            logw = logp - numpy.log(self.kde.pdf(ksamples))
            logw = logw - logsumexp(logw)

-        k = logp != - numpy.inf
+        k = logp != -numpy.inf
        ksamples = ksamples[:, k]
        logp, logl, logw = logp[k], logl[k], logw[k]
        return ksamples, logp, logl, logw

    def run(self):
-        """ Iterative sample from kde and update based on likelihood values
-        """
+        """Iterative sample from kde and update based on likelihood values"""
        self.group_kde = self.kde
        for param_group in self.param_groups:
            total_samples = None
@@ -272,22 +289,27 @@ def run(self):
            total_logl = None

            gsample = self.group_kde.resample(int(1e5))
-            gsample = [gsample[i, :] for i, k in enumerate(self.vparam)
-                       if k in param_group]
+            gsample = [
+                gsample[i, :]
+                for i, k in enumerate(self.vparam)
+                if k in param_group
+            ]
            self.kde = gaussian_kde(numpy.array(gsample))
            self.fixed_samples = self.group_kde.resample(1)

-            logging.info('updating: %s', param_group)
+            logging.info("updating: %s", param_group)
            for r in range(self.max_refinement_steps):
-                ksamples = self.draw_samples(self.iterative_kde_samples,
-                                             update_params=param_group)
-                ksamples, logp, logl, logw = self.run_samples(ksamples,
-                                             update_params=param_group,
-                                             iteration=True)
+                ksamples = self.draw_samples(
+                    self.iterative_kde_samples, update_params=param_group
+                )
+                ksamples, logp, logl, logw = self.run_samples(
+                    ksamples, update_params=param_group, iteration=True
+                )

                if total_samples is not None:
-                    total_samples = numpy.concatenate([total_samples,
-                                                       ksamples], axis=1)
+                    total_samples = numpy.concatenate(
+                        [total_samples, ksamples], axis=1
+                    )
                    total_logp = numpy.concatenate([total_logp, logp])
                    total_logw = numpy.concatenate([total_logw, logw])
                    total_logl = numpy.concatenate([total_logl, logl])
@@ -297,10 +319,11 @@ def run(self):
                    total_logw = logw
                    total_logl = logl

-                logging.info('setting up next kde iteration..')
+                logging.info("setting up next kde iteration..")
                ntotal_logw = total_logw - logsumexp(total_logw)
-                kde_new = gaussian_kde(total_samples,
-                                       weights=numpy.exp(ntotal_logw))
+                kde_new = gaussian_kde(
+                    total_samples, weights=numpy.exp(ntotal_logw)
+                )

                if self.converged(r, kde_new, total_logl + total_logw, logp):
                    break
Expand All @@ -319,14 +342,12 @@ def run(self):

self.group_kde = gaussian_kde(numpy.array(full_samples))

logging.info('Drawing final samples')
logging.info("Drawing final samples")
ksamples = self.draw_samples(self.num_samples)
logging.info('Calculating final likelihoods')
logging.info("Calculating final likelihoods")
ksamples, logp, logl, logw = self.run_samples(ksamples)
self._samples = {k: ksamples[j,:] for j, k in enumerate(self.vparam)}
self._samples['loglikelihood'] = logl
self._samples = {k: ksamples[j, :] for j, k in enumerate(self.vparam)}
self._samples["loglikelihood"] = logl
logging.info("Reweighting to equal samples")




self._samples = resample_equal(self._samples, logw)

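A note on the weighting step in run_samples above: it is standard self-normalized importance sampling. Samples drawn from the current KDE proposal are weighted by posterior over proposal density, normalized with logsumexp, and, during the iterative stage, capped near the 20th and 80th percentiles so that a few extreme weights cannot dominate the next KDE fit. A minimal standalone sketch of that step follows (plain numpy/scipy; the function name and array shapes are illustrative and not part of this commit):

import numpy
from scipy.special import logsumexp
from scipy.stats import gaussian_kde

def capped_log_weights(logp, kde, samples):
    """Self-normalized importance weights with a percentile cap and floor.

    logp: log-posterior values at the drawn samples
    kde: the gaussian_kde proposal the samples were drawn from
    samples: array of shape (ndim, nsamples)
    """
    # importance weight = target / proposal, computed in log space
    logw = logp - numpy.log(kde.pdf(samples))
    logw = logw - logsumexp(logw)  # normalize so the weights sum to one

    # cap the largest and smallest log weights (roughly the 80th/20th
    # percentiles) so single samples do not dominate the next weighted KDE fit
    sort = logw.argsort()
    cap = logw[sort[-len(sort) // 5]]
    low = logw[sort[len(sort) // 5]]
    return numpy.clip(logw, low, cap)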
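Similarly, compare_kde above measures the information difference between two KDE fits with scipy's entropy function, a KL-divergence estimate evaluated on draws from the first KDE; the converged check compares the result against the entropy threshold. A small illustration of that comparison (again a sketch with made-up data, not code from the repository):

import numpy
from scipy.stats import gaussian_kde, entropy

rng = numpy.random.default_rng(0)
kde1 = gaussian_kde(rng.normal(0.0, 1.0, size=(1, 2000)))
kde2 = gaussian_kde(rng.normal(0.1, 1.1, size=(1, 2000)))

s = kde1.resample(size=10000)             # evaluate both fits on kde1's draws
print(entropy(kde1.pdf(s), kde2.pdf(s)))  # small value means the fits agree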