Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement more additive models #161

Merged
merged 7 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 201 additions & 19 deletions src/elisa/models/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,36 @@
__all__ = [
'Band',
'BandEp',
'BentPL',
'Blackbody',
'BlackbodyRad',
'BrokenPL',
'Compt',
'CutoffPL',
'Gauss',
'Lorentz',
'OTTB',
'OTTS',
'PLEnFlux',
'PLPhFlux',
'PowerLaw',
'SmoothlyBrokenPL',
]


def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider refactoring the integration logic in _powerlaw_integral and BrokenPL.integral into smaller, purpose-specific helper functions to improve readability and maintainability.

Consider extracting parts of the integration logic into small, purpose‐specific helper functions to simplify the conditional handling and array slicing. For example, in _powerlaw_integral you can separate the power‐law and logarithmic parts:

def _integral_non_unity(egrid: JAXArray, alpha: JAXArray) -> JAXArray:
    one_minus_alpha = 1.0 - alpha
    f = jnp.power(egrid, one_minus_alpha) / one_minus_alpha
    return f[1:] - f[:-1]

def _integral_unity(egrid: JAXArray) -> JAXArray:
    f = jnp.log(egrid)
    return f[1:] - f[:-1]

def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray:
    # Use jnp.not_equal to decide which integration rule to apply
    condition = jnp.not_equal(alpha, 1.0)
    f1 = _integral_non_unity(egrid, alpha)
    f2 = _integral_unity(egrid)
    return jnp.where(condition[:-1], f1, f2)

Similarly, in the integral method for BrokenPL, isolate the break-region adjustment:

def _adjust_break_integral(egrid_: JAXArray, Eb: float, alpha1: float, alpha2: float, idx: int) -> float:
    pb1 = _powerlaw_integral(jnp.hstack([egrid_[idx], Eb]), alpha1)
    pb2 = _powerlaw_integral(jnp.hstack([Eb, egrid_[idx + 1]]), alpha2)
    return pb1[0] + pb2[0]

# In the BrokenPL.integral method:
mask = egrid[:-1] <= Eb
# ... existing setup ...
idx = jnp.flatnonzero(mask, size=egrid.size - 1)[-1]
f = f.at[idx].set(_adjust_break_integral(egrid_, Eb, alpha1, alpha2, idx))

These small refactorings keep functionality intact while isolating complex operations, making the code easier to read and maintain without reverting any changes.

cond = jnp.full(len(egrid), jnp.not_equal(alpha, 1.0))

one_minus_alpha = jnp.where(cond, 1.0 - alpha, 1.0)
f1 = jnp.power(egrid, one_minus_alpha) / one_minus_alpha
f1 = f1[1:] - f1[:-1]

f2 = jnp.log(egrid)
f2 = f2[1:] - f2[:-1]

return jnp.where(cond[:-1], f1, f2)


class Band(NumIntAdditive):
r"""Gamma-ray burst continuum developed by Band et al. (1993) [1]_.

Expand Down Expand Up @@ -183,6 +200,42 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
return K * jnp.exp(log)


class BentPL(NumIntAdditive):
r"""Bent power law.

.. math::
N(E) = 2K \left[
1 + \left(\frac{E}{E_\mathrm{b}}\right)^\alpha
\right]^{-1}

Parameters
----------
alpha : Parameter, optional
The power law photon index :math:`\alpha`, dimensionless.
Eb : Parameter, optional
The break energy :math:`E_\mathrm{b}`, in units of keV.
K : Parameter, optional
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
method : {'trapz', 'simpson'}, optional
Numerical integration method. Defaults to 'trapz'.
"""

_config = (
ParamConfig('alpha', 'alpha', '', 1.5, -3.0, 10.0),
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6),
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10),
)

@staticmethod
def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
alpha = params['alpha']
Eb = params['Eb']
K = params['K']
return K * 2.0 / (1.0 + jnp.power(egrid / Eb, alpha))


class Blackbody(NumIntAdditive):
r"""Blackbody function.

Expand Down Expand Up @@ -287,14 +340,60 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:


class BrokenPL(AnaIntAdditive):
pass
r"""Broken power law.

.. math::
N(E) = K
\begin{cases}
\left(\frac{E}{E_\mathrm{b}}\right)^{-\alpha_1},
&\text{if } E \le E_\mathrm{b},
\\\\
\left(\frac{E}{E_\mathrm{b}}\right)^{-\alpha_2},
&\text{otherwise}.
\end{cases}

class DoubleBrokenPL(AnaIntAdditive):
pass
Parameters
----------
alpha1 : Parameter, optional
The low-energy power law photon index :math:`\alpha_1`, dimensionless.
alpha2 : Parameter, optional
The high-energy power law photon index :math:`\alpha_2`, dimensionless.
Eb : Parameter, optional
The break energy :math:`E_\mathrm{b}`, in units of keV.
K : Parameter, optional
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
method : {'trapz', 'simpson'}, optional
Numerical integration method. Defaults to 'trapz'.
"""

_config = (
ParamConfig('alpha1', 'alpha_1', '', 1.0, -3.0, 10.0),
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6),
ParamConfig('alpha2', 'alpha_2', '', 2.0, -3.0, 10.0),
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10),
)

class SmoothlyBrokenPL(NumIntAdditive):
@staticmethod
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider refactoring the integral calculation in BrokenPL to use vectorized operations and piecewise integration, avoiding manual index finding and in-place updates.

Consider refactoring the integral so that you avoid manual index‐finding and in-place updates. For example, instead of building a mask, extracting an index with `jnp.flatnonzero`, and then patching that bin with a custom “cross‐bin” calculation, you can compute the definite integrals in a vectorized, piecewise manner. One approach is:

```python
@staticmethod
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
    alpha1 = params['alpha1']
    Eb = params['Eb']
    alpha2 = params['alpha2']
    K = params['K']

    # Define the antiderivative in terms of the normalized energy E/Eb.
    def antideriv(E, alpha):
        return _powerlaw_integral(E / Eb, alpha)

    # Bin edges
    E_low = egrid[:-1]
    E_high = egrid[1:]

    # Determine which bins are fully below, fully above, or cross the break.
    lower_bins = E_high <= Eb
    upper_bins = E_low >= Eb
    cross_bins = ~(lower_bins | upper_bins)

    # Compute integrals for bins entirely in one regime:
    int_lower = antideriv(E_high[lower_bins], alpha1) - antideriv(E_low[lower_bins], alpha1)
    int_upper = antideriv(E_high[upper_bins], alpha2) - antideriv(E_low[upper_bins], alpha2)
    # For bins that cross the break, split the integration at Eb.
    int_cross = (antideriv(Eb, alpha1) - antideriv(E_low[cross_bins], alpha1) +
                 antideriv(E_high[cross_bins], alpha2) - antideriv(Eb, alpha2))

    # Allocate and fill the result.
    result = jnp.empty(egrid.size - 1)
    result = result.at[lower_bins].set(int_lower)
    result = result.at[upper_bins].set(int_upper)
    result = result.at[cross_bins].set(int_cross)

    return K * result

This approach removes the need for manual index extraction and in-place patching while preserving all functionality. Adjust the antiderivative logic if needed based on your implementation of _powerlaw_integral.

alpha1 = params['alpha1']
Eb = params['Eb']
alpha2 = params['alpha2']
K = params['K']
mask = egrid[:-1] <= Eb
egrid_ = egrid / Eb
p1 = _powerlaw_integral(egrid_, alpha1)
p2 = _powerlaw_integral(egrid_, alpha2)
f = jnp.where(mask, p1, p2)
idx = jnp.flatnonzero(mask, size=egrid.size - 1)[-1]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Potential normalization mismatch in BrokenPL.integral endpoints.

In BrokenPL.integral, the energy grid is normalized (egrid_ = egrid / Eb) so that the break occurs at 1, yet when computing pb1 and pb2 the code uses Eb directly. Verify that the endpoints for _powerlaw_integral are correctly scaled (possibly using 1.0 rather than Eb) to ensure consistency.

pb1 = _powerlaw_integral(jnp.hstack([egrid_[idx], Eb]), alpha1)
pb2 = _powerlaw_integral(jnp.hstack([Eb, egrid_[idx + 1]]), alpha2)
f.at[idx].set(pb1[0] + pb2[0])
return K * f


class DoubleBrokenPL(AnaIntAdditive):
pass


Expand Down Expand Up @@ -421,8 +520,45 @@ class LogParabola(NumIntAdditive):
pass


class Lorentz(NumIntAdditive):
pass
class Lorentz(AnaIntAdditive):
r"""Lorentzian line profile.

.. math::
N(E) = \frac{K/\mathcal{F}}{(E - E_\mathrm{l})^2 + (\sigma/2)^2},

where

.. math::
\mathcal{F}=\int_0^{+\infty}
\frac{1}{(E - E_\mathrm{l})^2 + (\sigma/2)^2} \ \mathrm{d}E.

Parameters
----------
El : Parameter, optional
The line energy :math:`E_\mathrm{l}`, in units of keV.
sigma : Parameter, optional
The FWHM line width :math:`\sigma`, in units of keV.
K : Parameter, optional
The total photon flux :math:`K` of the line, in units of ph cm⁻² s⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
"""

_config = (
ParamConfig('El', r'E_\mathrm{l}', 'keV', 6.5, 1e-6, 1e6),
ParamConfig('sigma', r'\sigma', 'keV', 0.1, 1e-3, 20),
ParamConfig('K', 'K', 'ph cm^-2 s^-1', 1.0, 1e-10, 1e10),
)

@staticmethod
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
El = params['El']
sigma = params['sigma']
K = params['K']
x = 2.0 * (egrid - El) / sigma
integral = jnp.arctan(x)
norm = 2.0 * K / (jnp.pi + 2 * jnp.arctan(2 * El / sigma))
return norm * (integral[1:] - integral[:-1])


class OTTB(NumIntAdditive):
Expand Down Expand Up @@ -493,19 +629,6 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
return K * jnp.exp(-jnp.power(egrid / Ec, 1.0 / 3.0))


def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray:
cond = jnp.full(len(egrid), jnp.not_equal(alpha, 1.0))

one_minus_alpha = jnp.where(cond, 1.0 - alpha, 1.0)
f1 = jnp.power(egrid, one_minus_alpha) / one_minus_alpha
f1 = f1[1:] - f1[:-1]

f2 = jnp.log(egrid)
f2 = f2[1:] - f2[:-1]

return jnp.where(cond[:-1], f1, f2)


class PowerLaw(AnaIntAdditive):
r"""Power law function.

Expand Down Expand Up @@ -677,3 +800,62 @@ class PLEnFlux(PLFluxNorm):
log=True,
),
)


class SmoothlyBrokenPL(NumIntAdditive):
r"""Smoothly broken power law.

.. math::
N(E) = K
\left( \frac{E}{E_0} \right) ^ {-\alpha_1}
\left\{
2\left[
1 + \left(\frac{E}{E_\mathrm{b}}
\right)^{\left( \alpha_2 - \alpha_1 \right) / \rho}
\right]
\right\}^{-\rho},

where :math:`E_0` is the pivot energy fixed at 1 keV.

Parameters
----------
alpha1 : Parameter, optional
The low-energy power law index :math:`\alpha_1`, dimensionless.
Eb : Parameter, optional
The break energy :math:`E_\mathrm{b}`, in units of keV.
alpha2 : Parameter, optional
The high-energy power law index :math:`\alpha_2`, dimensionless.
rho : Parameter, optional
The smoothness parameter :math:`\rho`, dimensionless.
K : Parameter, optional
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
method : {'trapz', 'simpson'}, optional
Numerical integration method. Defaults to 'trapz'.
"""

_config = (
ParamConfig('alpha1', 'alpha_1', '', 1.0, -3.0, 10.0),
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6),
ParamConfig('alpha2', 'alpha_2', '', 2.0, -3.0, 10.0),
ParamConfig('rho', r'\rho', '', 1.0, 1e-2, 1e2, fixed=True),
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10),
)

@staticmethod
def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
alpha1 = params['alpha1']
Eb = params['Eb']
alpha2 = params['alpha2']
rho = params['rho']
K = params['K']

e = egrid / Eb
x = (alpha2 - alpha1) / rho * jnp.log(e)

threshold = 30
alpha = jnp.where(x > threshold, alpha2, alpha1)
r = jnp.where(jnp.abs(x) > threshold, 0.5, 0.5 * (1.0 + jnp.exp(x)))

return K * jnp.power(e, -alpha) * jnp.power(r, -rho)
23 changes: 22 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import jax
import numpy as np
import pytest
from astropy.cosmology import Planck18
from astropy.units import Unit

from elisa import ConstantValue, ParamConfig, PyAnaInt, PyNumInt
from elisa import ConstantValue, ParamConfig, PyAnaInt, PyNumInt, models
from elisa.models import PhAbs, PLPhFlux, PowerLaw, ZAShift


Expand Down Expand Up @@ -172,3 +173,23 @@ def powerlaw(alpha, K, egrid):

assert np.all(data.spec_counts == total_counts)
assert np.all(data.back_counts == back_counts)


@pytest.mark.parametrize(
'model, kwargs',
[
pytest.param(
getattr(models, name),
{}
if name not in ['PLPhFlux', 'PLEnFlux']
else {'emin': 1e-3, 'emax': 1e3},
id=name,
)
for name in models.add.__all__ + models.mul.__all__
],
)
def test_model_eval(model, kwargs):
egrid = np.geomspace(1e-4, 1e9, 1301)
values = model(**kwargs).compile().eval(egrid)
assert not np.any(np.isnan(values))
assert not np.any(np.isinf(values))