Skip to content

Commit

Permalink
Implement more additive models (#161)
Browse files Browse the repository at this point in the history
* Implement more additive models

* Add simple tests for add and mul type models
  • Loading branch information
wcxve authored Feb 21, 2025
1 parent 6fa26ed commit 86827d2
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 20 deletions.
220 changes: 201 additions & 19 deletions src/elisa/models/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,36 @@
__all__ = [
'Band',
'BandEp',
'BentPL',
'Blackbody',
'BlackbodyRad',
'BrokenPL',
'Compt',
'CutoffPL',
'Gauss',
'Lorentz',
'OTTB',
'OTTS',
'PLEnFlux',
'PLPhFlux',
'PowerLaw',
'SmoothlyBrokenPL',
]


def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray:
cond = jnp.full(len(egrid), jnp.not_equal(alpha, 1.0))

one_minus_alpha = jnp.where(cond, 1.0 - alpha, 1.0)
f1 = jnp.power(egrid, one_minus_alpha) / one_minus_alpha
f1 = f1[1:] - f1[:-1]

f2 = jnp.log(egrid)
f2 = f2[1:] - f2[:-1]

return jnp.where(cond[:-1], f1, f2)


class Band(NumIntAdditive):
r"""Gamma-ray burst continuum developed by Band et al. (1993) [1]_.
Expand Down Expand Up @@ -183,6 +200,42 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
return K * jnp.exp(log)


class BentPL(NumIntAdditive):
r"""Bent power law.
.. math::
N(E) = 2K \left[
1 + \left(\frac{E}{E_\mathrm{b}}\right)^\alpha
\right]^{-1}
Parameters
----------
alpha : Parameter, optional
The power law photon index :math:`\alpha`, dimensionless.
Eb : Parameter, optional
The break energy :math:`E_\mathrm{b}`, in units of keV.
K : Parameter, optional
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
method : {'trapz', 'simpson'}, optional
Numerical integration method. Defaults to 'trapz'.
"""

_config = (
ParamConfig('alpha', 'alpha', '', 1.5, -3.0, 10.0),
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6),
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10),
)

@staticmethod
def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
alpha = params['alpha']
Eb = params['Eb']
K = params['K']
return K * 2.0 / (1.0 + jnp.power(egrid / Eb, alpha))


class Blackbody(NumIntAdditive):
r"""Blackbody function.
Expand Down Expand Up @@ -287,14 +340,60 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:


class BrokenPL(AnaIntAdditive):
pass
r"""Broken power law.
.. math::
N(E) = K
\begin{cases}
\left(\frac{E}{E_\mathrm{b}}\right)^{-\alpha_1},
&\text{if } E \le E_\mathrm{b},
\\\\
\left(\frac{E}{E_\mathrm{b}}\right)^{-\alpha_2},
&\text{otherwise}.
\end{cases}
class DoubleBrokenPL(AnaIntAdditive):
pass
Parameters
----------
alpha1 : Parameter, optional
The low-energy power law photon index :math:`\alpha_1`, dimensionless.
alpha2 : Parameter, optional
The high-energy power law photon index :math:`\alpha_2`, dimensionless.
Eb : Parameter, optional
The break energy :math:`E_\mathrm{b}`, in units of keV.
K : Parameter, optional
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
method : {'trapz', 'simpson'}, optional
Numerical integration method. Defaults to 'trapz'.
"""

_config = (
ParamConfig('alpha1', 'alpha_1', '', 1.0, -3.0, 10.0),
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6),
ParamConfig('alpha2', 'alpha_2', '', 2.0, -3.0, 10.0),
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10),
)

class SmoothlyBrokenPL(NumIntAdditive):
@staticmethod
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
alpha1 = params['alpha1']
Eb = params['Eb']
alpha2 = params['alpha2']
K = params['K']
mask = egrid[:-1] <= Eb
egrid_ = egrid / Eb
p1 = _powerlaw_integral(egrid_, alpha1)
p2 = _powerlaw_integral(egrid_, alpha2)
f = jnp.where(mask, p1, p2)
idx = jnp.flatnonzero(mask, size=egrid.size - 1)[-1]
pb1 = _powerlaw_integral(jnp.hstack([egrid_[idx], Eb]), alpha1)
pb2 = _powerlaw_integral(jnp.hstack([Eb, egrid_[idx + 1]]), alpha2)
f.at[idx].set(pb1[0] + pb2[0])
return K * f


class DoubleBrokenPL(AnaIntAdditive):
pass


Expand Down Expand Up @@ -421,8 +520,45 @@ class LogParabola(NumIntAdditive):
pass


class Lorentz(NumIntAdditive):
pass
class Lorentz(AnaIntAdditive):
r"""Lorentzian line profile.
.. math::
N(E) = \frac{K/\mathcal{F}}{(E - E_\mathrm{l})^2 + (\sigma/2)^2},
where
.. math::
\mathcal{F}=\int_0^{+\infty}
\frac{1}{(E - E_\mathrm{l})^2 + (\sigma/2)^2} \ \mathrm{d}E.
Parameters
----------
El : Parameter, optional
The line energy :math:`E_\mathrm{l}`, in units of keV.
sigma : Parameter, optional
The FWHM line width :math:`\sigma`, in units of keV.
K : Parameter, optional
The total photon flux :math:`K` of the line, in units of ph cm⁻² s⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
"""

_config = (
ParamConfig('El', r'E_\mathrm{l}', 'keV', 6.5, 1e-6, 1e6),
ParamConfig('sigma', r'\sigma', 'keV', 0.1, 1e-3, 20),
ParamConfig('K', 'K', 'ph cm^-2 s^-1', 1.0, 1e-10, 1e10),
)

@staticmethod
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
El = params['El']
sigma = params['sigma']
K = params['K']
x = 2.0 * (egrid - El) / sigma
integral = jnp.arctan(x)
norm = 2.0 * K / (jnp.pi + 2 * jnp.arctan(2 * El / sigma))
return norm * (integral[1:] - integral[:-1])


class OTTB(NumIntAdditive):
Expand Down Expand Up @@ -493,19 +629,6 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
return K * jnp.exp(-jnp.power(egrid / Ec, 1.0 / 3.0))


def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray:
cond = jnp.full(len(egrid), jnp.not_equal(alpha, 1.0))

one_minus_alpha = jnp.where(cond, 1.0 - alpha, 1.0)
f1 = jnp.power(egrid, one_minus_alpha) / one_minus_alpha
f1 = f1[1:] - f1[:-1]

f2 = jnp.log(egrid)
f2 = f2[1:] - f2[:-1]

return jnp.where(cond[:-1], f1, f2)


class PowerLaw(AnaIntAdditive):
r"""Power law function.
Expand Down Expand Up @@ -677,3 +800,62 @@ class PLEnFlux(PLFluxNorm):
log=True,
),
)


class SmoothlyBrokenPL(NumIntAdditive):
r"""Smoothly broken power law.
.. math::
N(E) = K
\left( \frac{E}{E_0} \right) ^ {-\alpha_1}
\left\{
2\left[
1 + \left(\frac{E}{E_\mathrm{b}}
\right)^{\left( \alpha_2 - \alpha_1 \right) / \rho}
\right]
\right\}^{-\rho},
where :math:`E_0` is the pivot energy fixed at 1 keV.
Parameters
----------
alpha1 : Parameter, optional
The low-energy power law index :math:`\alpha_1`, dimensionless.
Eb : Parameter, optional
The break energy :math:`E_\mathrm{b}`, in units of keV.
alpha2 : Parameter, optional
The high-energy power law index :math:`\alpha_2`, dimensionless.
rho : Parameter, optional
The smoothness parameter :math:`\rho`, dimensionless.
K : Parameter, optional
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹.
latex : str, optional
:math:`\LaTeX` format of the component. Defaults to class name.
method : {'trapz', 'simpson'}, optional
Numerical integration method. Defaults to 'trapz'.
"""

_config = (
ParamConfig('alpha1', 'alpha_1', '', 1.0, -3.0, 10.0),
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6),
ParamConfig('alpha2', 'alpha_2', '', 2.0, -3.0, 10.0),
ParamConfig('rho', r'\rho', '', 1.0, 1e-2, 1e2, fixed=True),
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10),
)

@staticmethod
def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
alpha1 = params['alpha1']
Eb = params['Eb']
alpha2 = params['alpha2']
rho = params['rho']
K = params['K']

e = egrid / Eb
x = (alpha2 - alpha1) / rho * jnp.log(e)

threshold = 30
alpha = jnp.where(x > threshold, alpha2, alpha1)
r = jnp.where(jnp.abs(x) > threshold, 0.5, 0.5 * (1.0 + jnp.exp(x)))

return K * jnp.power(e, -alpha) * jnp.power(r, -rho)
23 changes: 22 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import jax
import numpy as np
import pytest
from astropy.cosmology import Planck18
from astropy.units import Unit

from elisa import ConstantValue, ParamConfig, PyAnaInt, PyNumInt
from elisa import ConstantValue, ParamConfig, PyAnaInt, PyNumInt, models
from elisa.models import PhAbs, PLPhFlux, PowerLaw, ZAShift


Expand Down Expand Up @@ -172,3 +173,23 @@ def powerlaw(alpha, K, egrid):

assert np.all(data.spec_counts == total_counts)
assert np.all(data.back_counts == back_counts)


@pytest.mark.parametrize(
'model, kwargs',
[
pytest.param(
getattr(models, name),
{}
if name not in ['PLPhFlux', 'PLEnFlux']
else {'emin': 1e-3, 'emax': 1e3},
id=name,
)
for name in models.add.__all__ + models.mul.__all__
],
)
def test_model_eval(model, kwargs):
egrid = np.geomspace(1e-4, 1e9, 1301)
values = model(**kwargs).compile().eval(egrid)
assert not np.any(np.isnan(values))
assert not np.any(np.isinf(values))

0 comments on commit 86827d2

Please sign in to comment.