Commit b843b33 ("cc")
ahnitz committed Aug 8, 2023 (1 parent: 1525f95)
Showing 1 changed file with 149 additions and 55 deletions.

pycbc/inference/sampler/refine.py: 149 additions & 55 deletions (204 changes)
@@ -30,6 +30,7 @@ def resample_equal(samples, logwt, seed=0):
idx = numpy.zeros(N, dtype=int)
cumulative_sum = numpy.cumsum(weights)
cumulative_sum /= cumulative_sum[-1]

i, j = 0, 0
while i < N:
if positions[i] < cumulative_sum[j]:
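
For context on the function this first hunk touches: resample_equal performs systematic resampling of weighted posterior samples into equal-weight samples. A minimal standalone sketch of the same algorithm, array-based and with illustrative names (the module's own version operates on its samples dict):

import numpy

def systematic_resample(samples, logwt, seed=0):
    # Normalize log-weights into probabilities.
    weights = numpy.exp(logwt - logwt.max())
    weights /= weights.sum()

    # One random offset, then N evenly spaced positions in [0, 1).
    N = len(weights)
    rng = numpy.random.default_rng(seed)
    positions = (rng.random() + numpy.arange(N)) / N

    cumulative_sum = numpy.cumsum(weights)
    cumulative_sum /= cumulative_sum[-1]

    # Walk the positions and the cumulative weights together.
    idx = numpy.zeros(N, dtype=int)
    i, j = 0, 0
    while i < N:
        if positions[i] < cumulative_sum[j]:
            idx[i] = j
            i += 1
        else:
            j += 1
    return samples[idx]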
@@ -75,9 +76,12 @@ def __init__(self, model, *args, nprocesses=1, use_mpi=False,
iterative_kde_samples=int(1e3),
min_refinement_steps=5,
max_refinement_steps=40,
entropy=0.001,
offbase_fraction=0.7,
entropy=0.01,
dlogz=0.01,
kde=None,
update_groups=None,
max_kde_samples=int(5e4),
**kwargs):
super().__init__(model, *args)

@@ -91,19 +95,43 @@ def __init__(self, model, *args, nprocesses=1, use_mpi=False,

self.num_samples = int(num_samples)
self.iterative_kde_samples = int(iterative_kde_samples)
self.max_kde_samples = int(max_kde_samples)
self.min_refinement_steps = int(min_refinement_steps)
self.max_refinement_steps = int(max_refinement_steps)
self.offbase_fraction = float(offbase_fraction)
self.entropy = float(entropy)
self.dlogz_target = float(dlogz)

def draw_samples(self, size):
self.param_groups = []
if update_groups is None:
self.param_groups.append(self.vparam)
else:
for gname in update_groups.split():
gvalue = kwargs[gname]
if gvalue == 'all':
self.param_groups.append(self.vparam)
else:
self.param_groups.append(gvalue.split())
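
A short illustration of the grouping convention the new update_groups parsing expects; the group names and parameter names below are invented for the example:

# Hypothetical inputs, mirroring how the constructor reads them.
vparam = ['mass1', 'mass2', 'distance', 'inclination']
update_groups = 'intrinsic extrinsic'
kwargs = {'intrinsic': 'mass1 mass2', 'extrinsic': 'all'}

param_groups = []
for gname in update_groups.split():
    gvalue = kwargs[gname]
    if gvalue == 'all':
        param_groups.append(vparam)           # refine every parameter together
    else:
        param_groups.append(gvalue.split())   # refine only this subset
# param_groups == [['mass1', 'mass2'],
#                  ['mass1', 'mass2', 'distance', 'inclination']]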

def draw_samples(self, size, update_params=None):
"""Draw new samples within the model priors"""
logging.info('getting from kde')

params = {}
ksamples = self.kde.resample(size=size)
params = {k: ksamples[i, :] for i, k in enumerate(self.vparam)}
j = 0
for i, k in enumerate(self.vparam):
if update_params is not None and k not in update_params:
params[k] = numpy.ones(size) * self.fixed_samples[i]
else:
params[k] = ksamples[j, :]
j += 1

logging.info('checking prior')
keep = self.model.prior_distribution.contains(params)
return ksamples[:, keep]
logging.info('done checking')
r = numpy.array([params[k][keep] for k in self.vparam])
return r
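
As a reminder of the scipy conventions draw_samples relies on: gaussian_kde takes data with shape (ndim, nsamples), and resample returns the same layout, one row per parameter. A quick standalone check:

import numpy
from scipy.stats import gaussian_kde

data = numpy.random.randn(2, 500)   # 2 parameters, 500 samples
kde = gaussian_kde(data)

draw = kde.resample(size=100)
print(draw.shape)                   # (2, 100): row i holds parameter i
print(kde.pdf(draw).shape)          # (100,): one density value per draw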

@staticmethod
def compare_kde(kde1, kde2, size=int(1e4)):
@@ -112,13 +140,16 @@ def compare_kde(kde1, kde2, size=int(1e4)):
s = kde1.resample(size=size)
return sentropy(kde1.pdf(s), kde2.pdf(s))
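
compare_kde uses sentropy, presumably scipy.stats.entropy, as a Monte Carlo divergence proxy between successive KDEs. A self-contained sketch of the same comparison under that assumption:

import numpy
from scipy.stats import gaussian_kde, entropy

def kde_divergence(kde1, kde2, size=10000):
    # Draw points from kde1 and compare the two density estimates there.
    # scipy.stats.entropy normalizes its inputs, so this acts as a
    # discrete KL-divergence proxy over the drawn sample set.
    s = kde1.resample(size=size)
    return entropy(kde1.pdf(s), kde2.pdf(s))

a = gaussian_kde(numpy.random.randn(1, 2000))
b = gaussian_kde(numpy.random.randn(1, 2000) + 0.1)
print(kde_divergence(a, b))   # near zero when the two KDEs agree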

def converged(self, step, kde_new, factor):
def converged(self, step, kde_new, factor, logp):
""" Check that kde is converged by comparing to previous iteration
"""
logging.info('checking convergence')
if not hasattr(self, 'old_logz'):
self.old_logz = numpy.inf

entropy_diff = self.compare_kde(self.kde, kde_new)
entropy_diff = -1
if self.entropy < 1:
entropy_diff = self.compare_kde(self.kde, kde_new)

# Compare how the logz changes when adding new samples
# this is guaranteed to decrease as old samples are included
@@ -133,11 +164,16 @@ def converged(self, step, kde_new, factor):
logz3 = logsumexp(choice3) - numpy.log(len(choice3))
dlogz2 = logz3 - logz2

logging.info('%s: Checking convergence: dlogz_iter=%.4f,'
'dlogz_half=%.4f, entropy=%.4f',
step, dlogz, dlogz2, entropy_diff)
# If the kde matches the posterior, the weights should be uniform;
# check the fraction that deviate significantly from the peak
frac_offbase = (logp < logp.max() - 5.0).sum() / len(logp)

logging.info('%s: dlogz_iter=%.4f, '
'dlogz_half=%.4f, entropy=%.4f, offbase fraction=%.4f',
step, dlogz, dlogz2, entropy_diff, frac_offbase)
if (entropy_diff < self.entropy and step >= self.min_refinement_steps
and max(abs(dlogz), abs(dlogz2)) < self.dlogz_target):
and max(abs(dlogz), abs(dlogz2)) < self.dlogz_target
and frac_offbase < self.offbase_fraction):
return True
else:
return False
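
The dlogz checks above compare importance-sampling estimates of the log-evidence built from the same weights. A minimal sketch of that estimator, assuming logp holds the posterior log-density at KDE draws and logq the KDE log-density at the same points:

import numpy
from scipy.special import logsumexp

def logz_estimate(logp, logq):
    # Importance sampling with proposal q:
    #   Z ~ (1/N) * sum_i p(x_i) / q(x_i),  with x_i drawn from q
    logw = logp - logq
    return logsumexp(logw) - numpy.log(len(logw))

converged then compares this estimate against the previous iteration (dlogz_iter) and against a random half of the samples (dlogz_half); both should settle near zero as the kde approaches the posterior.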
@@ -156,35 +192,69 @@ def from_config(cls, cp, model, output_file=None, nprocesses=1,
def set_start_from_config(self, cp):
"""Sets the initial state of the sampler from config file
"""
num_samples = self.iterative_kde_samples
if cp.has_option('sampler', 'start-file'):
start_file = cp.get('sampler', 'start-file')
logging.info("Using file %s for initial positions", start_file)
samples = loadfile(start_file, 'r').read_samples(self.vparam)
f = loadfile(start_file, 'r')
fsamples = f.read_samples(f['samples'].keys())
num_samples = len(fsamples)

init_prior = initial_dist_from_config(
cp, self.model.variable_params, self.model.static_params)
if init_prior is not None:
samples = init_prior.rvs(size=num_samples)
else:
init_prior = initial_dist_from_config(
cp, self.model.variable_params, self.model.static_params)
if init_prior is not None:
samples = init_prior.rvs(size=self.iterative_kde_samples)
else:
p = self.model.prior_distribution
samples = p.rvs(size=self.iterative_kde_samples)
p = self.model.prior_distribution
samples = p.rvs(size=num_samples)

ksamples = numpy.array([samples[v] for v in self.vparam])
self.kde = gaussian_kde(ksamples)
ksamples = []
for v in self.vparam:
if v in fsamples:
ksamples.append(fsamples[v])
else:
ksamples.append(samples[v])

self.kde = gaussian_kde(numpy.array(ksamples))

def run_samples(self, ksamples):
def run_samples(self, ksamples, update_params=None, iteration=False):
""" Calculate the likelihoods and weights for a set of samples
"""
# Calculate likelihood for each sample
logging.info('calculating likelihoods...')
args = []
for i in range(len(ksamples[0])):
param = {k: ksamples[j][i] for j, k in enumerate(self.vparam)}
args.append(param)

result = self.pool.map(call_model, args)
logging.info('..done with likelihood')

logp = numpy.array([r[0] for r in result])
logl = numpy.array([r[1] for r in result])
logw = logp - numpy.log(self.kde.pdf(ksamples))


if update_params is not None:
ksamples = numpy.array([ksamples[i, :]
for i, k in enumerate(self.vparam)
if k in update_params])

# Weights for iteration
if iteration:
logw = logp - numpy.log(self.kde.pdf(ksamples))
logw = logw - logsumexp(logw)

# To avoid single samples dominating the weighted kde, put a cap
# on the minimum and maximum logw
sort = logw.argsort()
cap = logw[sort[-len(sort)//5]]
low = logw[sort[len(sort)//5]]
logw[logw > cap] = cap
logw[logw < low] = low
else:
# Weights for Monte Carlo selection
logw = logp - numpy.log(self.kde.pdf(ksamples))
logw = logw - logsumexp(logw)

k = logp != - numpy.inf
ksamples = ksamples[:, k]
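
The new iteration branch tempers the weights before the next KDE is built by clipping them near their 20th and 80th percentile values. A standalone sketch of that tempering step:

import numpy
from scipy.special import logsumexp

def temper_logw(logw):
    # Normalize, then clip the log-weights so that no single sample
    # dominates the weighted kde constructed for the next iteration.
    logw = logw - logsumexp(logw)
    sort = logw.argsort()
    cap = logw[sort[-len(sort) // 5]]   # roughly the 80th percentile value
    low = logw[sort[len(sort) // 5]]    # roughly the 20th percentile value
    return numpy.clip(logw, low, cap)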
@@ -194,39 +264,60 @@ def run_samples(self, ksamples):
def run(self):
""" Iterative sample from kde and update based on likelihood values
"""
total_samples = None
total_logp = None
total_logw = None
total_logl = None

for r in range(self.max_refinement_steps):
logging.info('calculating likelihoods...')
ksamples = self.draw_samples(self.iterative_kde_samples)
ksamples, logp, logl, logw = self.run_samples(ksamples)

logging.info('..done')

if total_samples is not None:
total_samples = numpy.concatenate([total_samples,
ksamples], axis=1)
total_logp = numpy.concatenate([total_logp, logp])
total_logw = numpy.concatenate([total_logw, logw])
total_logl = numpy.concatenate([total_logl, logl])
else:
total_samples = ksamples
total_logp = logp
total_logw = logw
total_logl = logl

logging.info('setting up next kde iteration..')
ntotal_logw = total_logw - logsumexp(total_logw)
kde_new = gaussian_kde(total_samples,
weights=numpy.exp(ntotal_logw))
logging.info('done')
if self.converged(r, kde_new, total_logl + total_logw):
break

self.kde = kde_new
self.group_kde = self.kde
for param_group in self.param_groups:
total_samples = None
total_logp = None
total_logw = None
total_logl = None

gsample = self.group_kde.resample(int(1e5))
gsample = [gsample[i, :] for i, k in enumerate(self.vparam)
if k in param_group]
self.kde = gaussian_kde(numpy.array(gsample))
self.fixed_samples = self.group_kde.resample(1)

logging.info('updating: %s', param_group)
for r in range(self.max_refinement_steps):
ksamples = self.draw_samples(self.iterative_kde_samples,
update_params=param_group)
ksamples, logp, logl, logw = self.run_samples(ksamples,
update_params=param_group,
iteration=True)

if total_samples is not None:
total_samples = numpy.concatenate([total_samples,
ksamples], axis=1)
total_logp = numpy.concatenate([total_logp, logp])
total_logw = numpy.concatenate([total_logw, logw])
total_logl = numpy.concatenate([total_logl, logl])
else:
total_samples = ksamples
total_logp = logp
total_logw = logw
total_logl = logl

logging.info('setting up next kde iteration..')
ntotal_logw = total_logw - logsumexp(total_logw)
kde_new = gaussian_kde(total_samples,
weights=numpy.exp(ntotal_logw))

if self.converged(r, kde_new, total_logl + total_logw, logp):
break

self.kde = kde_new

full_samples = []
gsample = self.group_kde.resample(len(total_samples[0]))
i = 0
for j, k in enumerate(self.vparam):
if k in param_group:
full_samples.append(total_samples[i])
i += 1
else:
full_samples.append(gsample[j])

self.group_kde = gaussian_kde(numpy.array(full_samples))

logging.info('Drawing final samples')
ksamples = self.draw_samples(self.num_samples)
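
The per-group loop above rebuilds the joint KDE after each group is refined: the group's updated samples are combined with draws of the untouched parameters from the previous joint KDE. A simplified sketch of that merge, with invented parameter names and stand-in data:

import numpy
from scipy.stats import gaussian_kde

vparam = ['mass1', 'mass2', 'distance']
param_group = ['mass1', 'mass2']                          # group just refined

group_kde = gaussian_kde(numpy.random.randn(3, 2000))     # stand-in joint kde
total_samples = numpy.random.randn(2, 500)                # refined group samples

gsample = group_kde.resample(total_samples.shape[1])
full_samples = []
i = 0
for j, k in enumerate(vparam):
    if k in param_group:
        full_samples.append(total_samples[i])   # refined values for the group
        i += 1
    else:
        full_samples.append(gsample[j])          # previous kde for the rest

group_kde = gaussian_kde(numpy.array(full_samples))       # updated joint kde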
Expand All @@ -235,4 +326,7 @@ def run(self):
self._samples = {k: ksamples[j,:] for j, k in enumerate(self.vparam)}
self._samples['loglikelihood'] = logl
logging.info("Reweighting to equal samples")



self._samples = resample_equal(self._samples, logw)
