experimental: support for power spectrum data #162
Conversation
Reviewer's Guide by Sourcery

This pull request introduces support for power spectrum data analysis using the Whittle likelihood. It also refactors the AIES sampling method for improved parallelization. The Whittle likelihood is implemented with a custom exponential distribution and integrated into the existing data handling and simulation framework. The AIES sampler now uses jax.pmap for parallel execution and removes the chain_method parameter.

Sequence diagram for Whittle likelihood calculation

sequenceDiagram
participant Data as FixedData
participant Model as ModelCompiledFn
participant Likelihood as whittle
participant numpyro
Data->>Likelihood: whittle(data, model)
Likelihood->>Likelihood: power = data.net_counts
Likelihood->>Likelihood: freq_bins = data.photon_egrid
Likelihood->>Likelihood: df = jnp.diff(freq_bins)
Likelihood->>Model: pmodel = model(freq_bins, params)
Model-->>Likelihood: Returns model prediction
Likelihood->>numpyro: numpyro.deterministic(name, pmodel / df)
Likelihood->>numpyro: numpyro.deterministic(f'{name}_Non_model', pmodel)
Likelihood->>numpyro: pdata = numpyro.primitives.mutable(f'{name}_Non_data', power)
loop For each frequency bin
Likelihood->>numpyro: numpyro.sample(name=f'{name}_Non', fn=BetterExponential(1.0 / pmodel), obs=pdata)
end
Likelihood->>numpyro: loglike_on = numpyro.deterministic(name=f'{name}_Non_loglike', value=dist_on.log_prob(pdata))
Likelihood->>numpyro: numpyro.deterministic(name=f'{name}_loglike', value=loglike_on)
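For orientation, the flow above can be read as the minimal Python sketch below. It mirrors the site names and FixedData fields from the diagram, but the function signature (the params and name arguments) is assumed, and numpyro's built-in Exponential stands in for the PR's BetterExponential; treat it as an illustration rather than elisa's actual implementation.

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def whittle(data, model, params, name):
        # Periodogram powers and frequency grid taken from FixedData.
        power = data.net_counts
        freq_bins = data.photon_egrid
        df = jnp.diff(freq_bins)

        # Model power in each frequency bin, exposed as deterministic sites.
        pmodel = model(freq_bins, params)
        numpyro.deterministic(name, pmodel / df)
        numpyro.deterministic(f'{name}_Non_model', pmodel)

        # Register the observed powers as a mutable site, as in the diagram.
        pdata = numpyro.primitives.mutable(f'{name}_Non_data', power)

        # Whittle likelihood: each power is exponentially distributed with
        # mean equal to the model power, i.e. rate = 1 / pmodel.
        dist_on = dist.Exponential(1.0 / pmodel)
        numpyro.sample(f'{name}_Non', dist_on, obs=pdata)

        loglike_on = numpyro.deterministic(f'{name}_Non_loglike', dist_on.log_prob(pdata))
        numpyro.deterministic(f'{name}_loglike', loglike_on)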
Sequence diagram for AIES sampling with parallel execution

sequenceDiagram
participant User
participant AIES as aies
participant MCMC
participant jax.pmap
participant AIES_kernel
User->>AIES: aies(warmup, steps, chains, n_parallel, init, moves, **aies_kwargs)
AIES->>AIES: init = self._helper.free_default['constr_dic']
AIES->>AIES: aies_kwargs['model'] = self._helper.numpyro_model
AIES->>AIES: aies_kwargs['moves'] = {AIES.StretchMove(): 1.0}
alt n_parallel > 0
AIES->>AIES_kernel: aies_kernel = AIES(**aies_kwargs)
AIES->>jax.pmap: traces = jax.pmap(do_mcmc)(rng_keys)
jax.pmap-->>AIES: Returns traces
AIES->>MCMC: sampler = MCMC(aies_kernel, num_warmup, num_samples)
AIES->>MCMC: sampler._states = {sampler._sample_field: traces}
else n_parallel == 0
AIES->>MCMC: sampler = MCMC(AIES_kernel, num_warmup, num_samples, num_chains, chain_method='vectorized')
MCMC->>MCMC: sampler.run(rng_key, init_params=init)
end
MCMC-->>AIES: Returns sampler
AIES-->>User: Returns sampler
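The parallel branch above boils down to the standard jax.pmap pattern: split one PRNG key into one key per device and map an independent sampling run over them. The toy run_one function below is purely illustrative and stands in for a vectorized MCMC run.

    import jax

    def run_one(rng_key):
        # Stand-in for MCMC(...).run(rng_key); just draws dummy 'samples'.
        return {'x': jax.random.normal(rng_key, shape=(4, 100))}

    n_parallel = jax.local_device_count()
    rng_keys = jax.random.split(jax.random.PRNGKey(0), n_parallel)
    traces = jax.pmap(run_one)(rng_keys)  # leading axis indexes the device
    print(traces['x'].shape)              # (n_parallel, 4, 100)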
Updated class diagram for likelihood statistics

classDiagram
class FixedData {
+name: str
+net_counts: JAXArray
+photon_egrid: JAXArray
+spec_poisson: bool
+spec_counts: JAXArray
+spec_errors: JAXArray
+response_sparse: bool
+sparse_matrix: scipy.sparse.spmatrix
}
class BetterExponential {
+rate: JAXArray
+log_prob(value: JAXArray): JAXArray
}
class Exponential {
<<Abstract>>
}
BetterExponential --|> Exponential : inherits from
note for FixedData "Added photon_egrid attribute to store frequency bins for power spectrum data."
note for BetterExponential "Custom exponential distribution for Whittle likelihood, includes gof term."
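As a rough illustration of the BetterExponential class noted above, the hypothetical sketch below subclasses numpyro's Exponential and overrides log_prob; the exact goodness-of-fit (gof) adjustment made in the PR is not reproduced here.

    import jax.numpy as jnp
    from numpyro.distributions import Exponential

    class BetterExponential(Exponential):
        """Exponential distribution with an explicitly written log-density."""

        def log_prob(self, value):
            # Plain exponential log-density: log(rate) - rate * value.
            # The PR's version additionally folds in a gof-related term.
            return jnp.log(self.rate) - self.rate * value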
Hey @wcxve - I've reviewed your changes - here's some feedback:
Overall Comments:
- Consider adding a short example demonstrating how to use the new whittle likelihood.
- The BetterExponential distribution seems to be working around a gradient issue; can you add a comment explaining why it's needed?
Here's what I looked at during the review
- 🟢 General issues: all looks good
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟡 Complexity: 1 issue found
- 🟢 Documentation: all looks good
src/elisa/infer/fit.py (outdated)

@@ -964,12 +972,11 @@ def aies(
    Affine-invariant ensemble sampling [1]_ is a gradient-free method
issue (complexity): Consider extracting the parallel sampling logic into a helper function to reduce nested branching and improve code clarity.
Consider extracting the parallel sampling logic (including the nested do_mcmc function) into one or more helper functions. This reduces the deeply nested branching and clarifies the flow.
For example, you might extract the parallel MCMC run into a helper function:
def run_parallel_mcmc(kernel, warmup, steps, chains, init, seed, n_parallel, progress_bar):
    def do_mcmc(rng_key):
        mcmc = MCMC(
            kernel,
            num_warmup=warmup,
            num_samples=steps,
            num_chains=chains,
            chain_method='vectorized',
            progress_bar=progress_bar,
        )
        mcmc.run(rng_key, init_params=init)
        return mcmc.get_samples(group_by_chain=True)

    # Run one vectorized MCMC instance per device via jax.pmap.
    rng_keys = jax.random.split(
        jax.random.PRNGKey(seed), get_parallel_number(n_parallel)
    )
    traces = jax.pmap(do_mcmc)(rng_keys)
    return {k: np.concatenate(v) for k, v in traces.items()}
Then in aies or ess, replace the inlined block with a call to the helper:
if chain_method == 'parallel':  # or: if n_parallel:
    aies_kernel = AIES(**aies_kwargs)
    trace = run_parallel_mcmc(
        kernel=aies_kernel,
        warmup=warmup,
        steps=steps,
        chains=chains,
        init=init,
        seed=self._helper.seed['mcmc'],
        n_parallel=n_parallel,
        progress_bar=False,
    )
    sampler = MCMC(aies_kernel, num_warmup=warmup, num_samples=steps)
    sampler._states = {sampler._sample_field: trace}
else:
    ...  # existing non-parallel branch
This refactoring isolates the complexity, improves readability, and preserves all functionality.
src/elisa/infer/fit.py (outdated)

  if moves is None:
-     aies_kwargs['moves'] = {
-         AIES.DEMove(): 0.5,
-         AIES.StretchMove(): 0.5,
-     }
+     aies_kwargs['moves'] = {AIES.StretchMove(): 1.0}
  else:
      aies_kwargs['moves'] = moves
suggestion (code-quality): Replace if statement with if expression (assign-if-exp)
Suggested change:

- if moves is None:
-     aies_kwargs['moves'] = {AIES.StretchMove(): 1.0}
- else:
-     aies_kwargs['moves'] = moves
+ aies_kwargs['moves'] = {AIES.StretchMove(): 1.0} if moves is None else moves
Summary by Sourcery
This pull request introduces support for power spectrum data analysis using the Whittle likelihood. It also enhances the AIES sampler and includes a new BetterExponential distribution.

New Features:
Enhancements:
- Refactored the AIES sampler, removing the chain_method argument and simplifying the parallel execution logic, using jax.pmap for parallel chains.
- Added the BetterExponential distribution to numpyro to improve the goodness of fit.
Tests: