Skip to content

Commit

Permalink
experimental: support for power spectrum data
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Feb 21, 2025
1 parent f54ed4c commit 7e4181c
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 7 deletions.
9 changes: 9 additions & 0 deletions src/elisa/infer/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,15 @@ def get_stat(d: FixedData) -> Statistic:
def check_stat(d: FixedData, s: Statistic):
"""Check if data type and likelihood are matched."""
name = d.name

if s == 'whittle':
if d.spec_poisson:
raise ValueError(
f'{name} data has Poisson uncertainties, '
'and using Whittle likelihood (whittle) is invalid'
)
return

if not d.spec_poisson and s != 'chi2':
raise ValueError(
f'{name} data has Gaussian uncertainties, '
Expand Down
20 changes: 16 additions & 4 deletions src/elisa/infer/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
cstat,
pgstat,
pstat,
whittle,
wstat,
)
from elisa.util.config import get_parallel_number
Expand Down Expand Up @@ -212,7 +213,7 @@ def get_counts_data(counts: dict[str, JAXArray]) -> dict[str, JAXArray]:
obs_counts = {
f'{k}_Non': (
v.net_counts
if stat[k] in _STATISTIC_SPEC_NORMAL
if stat[k] in _STATISTIC_SPEC_NORMAL or stat[k] == 'whittle'
else v.spec_counts
)
for k, v in data.items()
Expand All @@ -225,7 +226,9 @@ def get_counts_data(counts: dict[str, JAXArray]) -> dict[str, JAXArray]:
obs_data = get_counts_data(obs_counts)

# ======================== count data simulator ===========================
def simulator_factory(data_dist: Literal['norm', 'poisson'], *dist_args):
def simulator_factory(
data_dist: Literal['norm', 'poisson', 'exp'], *dist_args
):
"""Factory to create data simulator."""

def simulator(
Expand All @@ -244,18 +247,26 @@ def simulator(
return rng.normal(model_values, *dist_args, shape)
elif data_dist == 'poisson':
return rng.poisson(model_values, shape)
elif data_dist == 'exp':
return rng.exponential(model_values, shape)
else:
raise NotImplementedError(f'{data_dist = }')

return simulator

simulators = {}
sampling_dist: dict[str, tuple[Literal['norm', 'poisson'], tuple]] = {}
sampling_dist: dict[
str,
tuple[Literal['norm', 'poisson', 'exp'], tuple],
] = {}
for k, s in stat.items():
d = data[k]

name = f'{k}_Non'
if s in _STATISTIC_SPEC_NORMAL:
if s == 'whittle':
simulators[name] = simulator_factory('exp')
sampling_dist[name] = ('exp', ())
elif s in _STATISTIC_SPEC_NORMAL:
simulators[name] = simulator_factory('norm', d.spec_errors)
sampling_dist[name] = ('norm', (d.spec_errors,))
else:
Expand Down Expand Up @@ -323,6 +334,7 @@ def simulate(
'pstat': pstat,
'wstat': wstat,
'pgstat': pgstat,
'whittle': whittle,
}
likelihood: dict[str, Callable[[JAXArray], None]] = {
k: likelihood_wrapper[stat[k]](v, model[k].eval)
Expand Down
49 changes: 47 additions & 2 deletions src/elisa/infer/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax import lax
from jax.experimental.sparse import BCSR
from jax.scipy.special import xlogy
from numpyro.distributions import Normal, Poisson
from numpyro.distributions import Exponential, Normal, Poisson
from numpyro.distributions.util import validate_sample

if TYPE_CHECKING:
Expand All @@ -30,7 +30,7 @@
# for source estimation, which is probably due to the choice of conjugate
# prior of Poisson background data.
# 'lstat' will be included here with a proper prior at some point.
Statistic = Literal['chi2', 'cstat', 'pstat', 'pgstat', 'wstat']
Statistic = Literal['chi2', 'cstat', 'pstat', 'pgstat', 'wstat', 'whittle']

_STATISTIC_OPTIONS: frozenset[str] = frozenset(get_args(Statistic))
_STATISTIC_SPEC_NORMAL: frozenset[str] = frozenset({'chi2'})
Expand Down Expand Up @@ -174,6 +174,13 @@ def log_prob(self, value):
return jnp.clip(logp - gof, a_max=0.0)


class BetterExponential(Exponential):
    """Exponential distribution whose log-density is offset by a fit term.

    ``log_prob`` returns the ordinary exponential log-density,
    ``log(rate) - rate * value``, minus ``-log(value) - 1``.  The subtracted
    term equals the exponential log-density evaluated with
    ``rate = 1 / value``, so the offset depends on the data only, not on the
    rate being fitted.
    """

    @validate_sample
    def log_prob(self, value):
        """Return the offset log-density of ``value`` under this distribution."""
        rate = self.rate
        # Exponential log-density at rate = 1 / value: log(1/value) - 1.
        saturated = -jnp.log(value) - 1.0
        plain = jnp.log(rate) - rate * value
        return plain - saturated


def _get_resp_matrix(data: FixedData) -> JAXArray | BCSR:
if data.response_sparse:
return BCSR.from_scipy_sparse(data.sparse_matrix.T)
Expand Down Expand Up @@ -448,3 +455,41 @@ def likelihood(params: ParamNameValMapping, predictive: bool = False):
)

return likelihood


def whittle(
    data: FixedData,
    model: ModelCompiledFn,
) -> Callable[[ParamNameValMapping, bool], None]:
    """Build the Whittle likelihood of a power spectrum (periodogram).

    Parameters
    ----------
    data : FixedData
        Observation whose ``net_counts`` are used as the observed values and
        whose ``photon_egrid`` is passed to the model as the evaluation grid.
        NOTE(review): presumably these hold periodogram powers and frequency
        bin edges for this kind of data -- confirm against the data loader.
    model : ModelCompiledFn
        Compiled model, called as ``model(grid, params)``.

    Returns
    -------
    Callable
        A function ``likelihood(params, predictive=False)`` that registers
        the numpyro sites of this dataset and returns ``None``.
    """
    name = str(data.name)
    observed = jnp.array(data.net_counts, float)
    grid = jnp.array(data.photon_egrid, float)
    widths = jnp.diff(grid)
    n_obs = len(observed)

    def likelihood(
        params: ParamNameValMapping,
        predictive: bool = False,
    ) -> None:
        """Register the model, data and sampling sites via numpyro primitives.

        When ``predictive`` is true the sample site is left unobserved and no
        log-likelihood sites are recorded.
        """
        expected = model(grid, params)
        # Model values divided by the grid spacing, stored under the bare
        # dataset name; the undivided values go to the ``_Non_model`` site.
        numpyro.deterministic(name, expected / widths)
        numpyro.deterministic(f'{name}_Non_model', expected)
        obs_site = numpyro.primitives.mutable(f'{name}_Non_data', observed)

        if predictive:
            conditioned = None
        else:
            conditioned = obs_site

        with numpyro.plate(f'{name}_plate', n_obs):
            # Exponential with mean equal to the model value (rate = 1/model).
            exp_dist = BetterExponential(1.0 / expected)
            numpyro.sample(
                name=f'{name}_Non',
                fn=exp_dist,
                obs=conditioned,
            )

        if predictive:
            return

        # record log likelihood into chains to avoid re-computation
        pointwise = numpyro.deterministic(
            name=f'{name}_Non_loglike', value=exp_dist.log_prob(obs_site)
        )
        numpyro.deterministic(name=f'{name}_loglike', value=pointwise)

    return likelihood
11 changes: 10 additions & 1 deletion src/elisa/plot/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ def pit(self) -> tuple[Array, Array]:
on_data = self.spec_counts
on_model = self.model('on', 'mle')

if stat == 'whittle':
pit = stats.expon.cdf(on_data, scale=on_model)
return pit, pit

if stat in _STATISTIC_SPEC_NORMAL: # chi2
pit = stats.norm.cdf((on_data - on_model) / self.net_errors)
return pit, pit
Expand Down Expand Up @@ -575,7 +579,7 @@ def _pearson_residuals(
stat = self.statistic

if rtype == 'mle':
if stat in _STATISTIC_SPEC_NORMAL:
if stat in _STATISTIC_SPEC_NORMAL or stat == 'whittle':
on_data = self.net_counts
else:
on_data = self.spec_counts
Expand All @@ -584,6 +588,8 @@ def _pearson_residuals(

if stat in _STATISTIC_SPEC_NORMAL:
std = self.net_errors
elif stat == 'whittle':
std = np.sqrt(self.model('on', rtype))
else:
std = None

Expand Down Expand Up @@ -654,6 +660,7 @@ def quantile_residuals_mle(
lower = np.full(r.shape, False)
lower[lower_mask] = True

assert np.isfinite(r).all()
return r, lower, upper


Expand Down Expand Up @@ -967,6 +974,8 @@ def _pearson_residuals(

if stat in _STATISTIC_SPEC_NORMAL:
std = self.net_errors
elif stat == 'whittle':
std = np.sqrt(on_model)
else:
std = None

Expand Down

0 comments on commit 7e4181c

Please sign in to comment.