
experimental: support for power spectrum data #162

Closed · wants to merge 2 commits from the power-spectra branch

Conversation

wcxve (Owner) commented on Feb 21, 2025

Summary by Sourcery

This pull request introduces support for power spectrum data analysis using the Whittle likelihood. It also enhances the AIES sampler and includes a new BetterExponential distribution.

New Features:

  • Adds support for the Whittle likelihood, enabling power spectrum (periodogram) analysis.

Enhancements:

  • Improves the AIES sampler by removing the chain_method argument and simplifying the parallel execution logic, using jax.pmap for parallel chains.
  • Adds a new BetterExponential distribution, built on numpyro's Exponential, to improve the goodness of fit.

Tests:

  • Adds tests for the new Whittle likelihood.

sourcery-ai bot commented Feb 21, 2025

Reviewer's Guide by Sourcery

This pull request introduces support for power spectrum data analysis using the Whittle likelihood. It also refactors the AIES sampling method for improved parallelization. The Whittle likelihood is implemented with a custom exponential distribution and integrated into the existing data handling and simulation framework. The AIES sampler now uses jax.pmap for parallel execution and removes the chain_method parameter.
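
For reference, the Whittle likelihood treats each periodogram ordinate as (asymptotically) exponentially distributed with mean equal to the model power spectrum. In its standard form (the PR's exact per-bin normalization may differ),

    \ln \mathcal{L} = -\sum_j \left[ \ln S(f_j) + \frac{P(f_j)}{S(f_j)} \right],

where P(f_j) is the periodogram and S(f_j) is the model power spectrum at frequency f_j.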

Sequence diagram for Whittle likelihood calculation

sequenceDiagram
    participant Data as FixedData
    participant Model as ModelCompiledFn
    participant Likelihood as whittle
    participant numpyro
    Data->>Likelihood: whittle(data, model)
    Likelihood->>Likelihood: power = data.net_counts
    Likelihood->>Likelihood: freq_bins = data.photon_egrid
    Likelihood->>Likelihood: df = jnp.diff(freq_bins)
    Likelihood->>Model: pmodel = model(freq_bins, params)
    Model-->>Likelihood: Returns model prediction
    Likelihood->>numpyro: numpyro.deterministic(name, pmodel / df)
    Likelihood->>numpyro: numpyro.deterministic(f'{name}_Non_model', pmodel)
    Likelihood->>numpyro: pdata = numpyro.primitives.mutable(f'{name}_Non_data', power)
    loop For each frequency bin
        Likelihood->>numpyro: numpyro.sample(name=f'{name}_Non', fn=BetterExponential(1.0 / pmodel), obs=pdata)
    end
    Likelihood->>numpyro: loglike_on = numpyro.deterministic(name=f'{name}_Non_loglike', value=dist_on.log_prob(pdata))
    Likelihood->>numpyro: numpyro.deterministic(name=f'{name}_loglike', value=loglike_on)
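The following is a minimal, hypothetical sketch of the flow the diagram describes, written against numpyro's public API. The function name whittle_sketch and its arguments are illustrative only, and the stock dist.Exponential stands in for the PR's BetterExponential, whose definition is not shown in this thread.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def whittle_sketch(name, power, freq_bins, model_fn, params):
    # Bin widths and integrated model power per bin, as in the diagram.
    df = jnp.diff(freq_bins)
    pmodel = model_fn(freq_bins, params)
    # Record the model power density for plotting/diagnostics.
    numpyro.deterministic(name, pmodel / df)
    # Whittle: each periodogram ordinate is exponential with mean pmodel,
    # i.e. rate = 1 / pmodel. The PR uses a custom BetterExponential here.
    d = dist.Exponential(rate=1.0 / pmodel)
    numpyro.sample(f'{name}_Non', d, obs=power)
    # Per-bin log-likelihood, mirroring the *_loglike deterministic sites.
    numpyro.deterministic(f'{name}_loglike', d.log_prob(power))

This is a model component, so it must be called inside a numpyro model that an inference routine traces.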

Sequence diagram for AIES sampling with parallel execution

sequenceDiagram
    participant User
    participant AIES as aies
    participant MCMC
    participant jax.pmap
    participant AIES_kernel
    User->>AIES: aies(warmup, steps, chains, n_parallel, init, moves, **aies_kwargs)
    AIES->>AIES: init = self._helper.free_default['constr_dic']
    AIES->>AIES: aies_kwargs['model'] = self._helper.numpyro_model
    AIES->>AIES: aies_kwargs['moves'] = {AIES.StretchMove(): 1.0}
    alt n_parallel > 0
        AIES->>AIES_kernel: aies_kernel = AIES(**aies_kwargs)
        AIES->>jax.pmap: traces = jax.pmap(do_mcmc)(rng_keys)
        jax.pmap-->>AIES: Returns traces
        AIES->>MCMC: sampler = MCMC(aies_kernel, num_warmup, num_samples)
        AIES->>MCMC: sampler._states = {sampler._sample_field: traces}
    else n_parallel == 0
        AIES->>MCMC: sampler = MCMC(AIES_kernel, num_warmup, num_samples, num_chains, chain_method='vectorized')
        MCMC->>MCMC: sampler.run(rng_key, init_params=init)
    end
    MCMC-->>AIES: Returns sampler
    AIES-->>User: Returns sampler
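One detail worth noting about the parallel branch: jax.pmap maps its function over the leading axis of the inputs, one replica per local device, so the number of PRNG keys passed in must not exceed jax.local_device_count(). A self-contained toy sketch of that mechanic (the do_mcmc body here is a placeholder, not elisa's implementation):

import jax


def do_mcmc(rng_key):
    # Placeholder for a per-device MCMC run; it just draws toy samples
    # so the pmap mechanics below are runnable on their own.
    return jax.random.normal(rng_key, (100,))


n_parallel = jax.local_device_count()            # at most one replica per device
rng_keys = jax.random.split(jax.random.PRNGKey(0), n_parallel)
traces = jax.pmap(do_mcmc)(rng_keys)             # shape (n_parallel, 100)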

Updated class diagram for likelihood statistics

classDiagram
    class FixedData {
        +name: str
        +net_counts: JAXArray
        +photon_egrid: JAXArray
        +spec_poisson: bool
        +spec_counts: JAXArray
        +spec_errors: JAXArray
        +response_sparse: bool
        +sparse_matrix: scipy.sparse.spmatrix
    }
    class BetterExponential {
        +rate: JAXArray
        +log_prob(value: JAXArray): JAXArray
    }
    class Exponential {
      <<Abstract>>
    }
    BetterExponential --|> Exponential : inherits from

    note for FixedData "Added photon_egrid attribute to store frequency bins for power spectrum data."
    note for BetterExponential "Custom exponential distribution for Whittle likelihood, includes gof term."
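The class diagram only exposes BetterExponential's interface (a rate parameter and an overridden log_prob); the modification that adds the goodness-of-fit term is not visible in this thread. As a purely hypothetical skeleton of a numpyro Exponential subclass with that shape:

import jax.numpy as jnp
from numpyro.distributions import Exponential


class BetterExponentialSketch(Exponential):
    """Hypothetical skeleton matching the class diagram above. It only
    reproduces the stock exponential log-density; the PR's actual
    goodness-of-fit adjustment is not shown in this thread."""

    def log_prob(self, value):
        # Standard exponential log-density: log(rate) - rate * value.
        return jnp.log(self.rate) - self.rate * value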

File-Level Changes

Change: Introduced the 'whittle' statistic for power spectrum data analysis, including a new likelihood function and modifications to data handling and simulation.
Details:
  • Added 'whittle' to the Statistic literal type.
  • Implemented the whittle likelihood function for power spectrum data.
  • Added a BetterExponential distribution for use with the Whittle likelihood.
  • Modified get_counts_data to handle the 'whittle' statistic.
  • Added 'exp' as a possible distribution for data simulation (see the sketch below this list).
  • Updated simulate to include the whittle likelihood.
  • Modified _pearson_residuals and quantile_residuals_mle to handle the 'whittle' statistic.
  • Updated pit to handle the 'whittle' statistic.
Files: src/elisa/infer/likelihood.py, src/elisa/infer/helper.py, src/elisa/plot/data.py

Change: Refactored the AIES sampling method to improve parallelization and remove the chain_method parameter.
Details:
  • Removed the chain_method parameter from the aies function.
  • Modified the AIES sampling method to use jax.pmap for parallel execution.
  • Updated the AIES moves to use AIES.StretchMove().
  • Ensured that n_parallel is a non-negative integer.
Files: src/elisa/infer/fit.py
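
As referenced in the 'exp' simulation bullet above, and consistent with the Whittle model (this is not elisa's simulate API), periodogram data can be simulated by drawing each ordinate from an exponential distribution whose mean is the model power:

import jax
import jax.numpy as jnp


def simulate_periodogram(rng_key, model_power):
    # Exponential(rate=1) samples scaled by the mean give draws with
    # mean equal to model_power, which is the Whittle/'exp' assumption.
    return model_power * jax.random.exponential(rng_key, model_power.shape)


key = jax.random.PRNGKey(42)
fake_power = simulate_periodogram(key, jnp.ones(128))  # flat model spectrum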

sourcery-ai bot left a comment

Hey @wcxve - I've reviewed your changes - here's some feedback:

Overall Comments:

  • Consider adding a short example demonstrating how to use the new whittle likelihood.
  • The BetterExponential distribution seems to be working around a gradient issue; can you add a comment explaining why it's needed?

Here's what I looked at during the review:
  • 🟢 General issues: all looks good
  • 🟢 Security: all looks good
  • 🟢 Testing: all looks good
  • 🟡 Complexity: 1 issue found
  • 🟢 Documentation: all looks good


@@ -964,12 +972,11 @@ def aies(
Affine-invariant ensemble sampling [1]_ is a gradient-free method

issue (complexity): Consider extracting the parallel sampling logic into a helper function to reduce nested branching and improve code clarity.

Consider extracting the parallel sampling logic (including the nested do_mcmc function) into one or more helper functions. This reduces the deeply nested branching and clarifies the flow.

For example, you might extract the parallel MCMC run into a helper function:

def run_parallel_mcmc(kernel, warmup, steps, chains, init, seed, n_parallel, progress_bar):
    def do_mcmc(rng_key):
        mcmc = MCMC(
            kernel,
            num_warmup=warmup,
            num_samples=steps,
            num_chains=chains,
            chain_method='vectorized',
            progress_bar=progress_bar,
        )
        mcmc.run(rng_key, init_params=init)
        return mcmc.get_samples(group_by_chain=True)

    # One PRNG key per parallel worker; pmap runs do_mcmc on each device.
    rng_keys = jax.random.split(
        jax.random.PRNGKey(seed), get_parallel_number(n_parallel)
    )
    traces = jax.pmap(do_mcmc)(rng_keys)
    return {k: np.concatenate(v) for k, v in traces.items()}

Then in aies or ess replace the inlined block with a call to the helper:

if n_parallel:  # the chain_method parameter was removed in this PR
    aies_kernel = AIES(**aies_kwargs)
    trace = run_parallel_mcmc(
        kernel=aies_kernel,
        warmup=warmup,
        steps=steps,
        chains=chains,
        init=init,
        seed=self._helper.seed['mcmc'],
        n_parallel=n_parallel,
        progress_bar=False,
    )
    sampler = MCMC(aies_kernel, num_warmup=warmup, num_samples=steps)
    sampler._states = {sampler._sample_field: trace}
else:
    ...  # existing non-parallel branch
This refactoring isolates the complexity, improves readability, and preserves all functionality.

Comment on lines 1038 to 1042
         if moves is None:
-            aies_kwargs['moves'] = {
-                AIES.DEMove(): 0.5,
-                AIES.StretchMove(): 0.5,
-            }
+            aies_kwargs['moves'] = {AIES.StretchMove(): 1.0}
         else:
             aies_kwargs['moves'] = moves

suggestion (code-quality): Replace if statement with if expression (assign-if-exp)

Suggested change
-        if moves is None:
-            aies_kwargs['moves'] = {
-                AIES.DEMove(): 0.5,
-                AIES.StretchMove(): 0.5,
-            }
-            aies_kwargs['moves'] = {AIES.StretchMove(): 1.0}
-        else:
-            aies_kwargs['moves'] = moves
+        aies_kwargs['moves'] = {AIES.StretchMove(): 1.0} if moves is None else moves

wcxve closed this on Feb 21, 2025
wcxve deleted the power-spectra branch on Feb 21, 2025 at 18:00