diff --git a/src/elisa/models/add.py b/src/elisa/models/add.py index 6a3ae7d..87e77c5 100644 --- a/src/elisa/models/add.py +++ b/src/elisa/models/add.py @@ -20,19 +20,36 @@ __all__ = [ 'Band', 'BandEp', + 'BentPL', 'Blackbody', 'BlackbodyRad', + 'BrokenPL', 'Compt', 'CutoffPL', 'Gauss', + 'Lorentz', 'OTTB', 'OTTS', 'PLEnFlux', 'PLPhFlux', 'PowerLaw', + 'SmoothlyBrokenPL', ] +def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray: + cond = jnp.full(len(egrid), jnp.not_equal(alpha, 1.0)) + + one_minus_alpha = jnp.where(cond, 1.0 - alpha, 1.0) + f1 = jnp.power(egrid, one_minus_alpha) / one_minus_alpha + f1 = f1[1:] - f1[:-1] + + f2 = jnp.log(egrid) + f2 = f2[1:] - f2[:-1] + + return jnp.where(cond[:-1], f1, f2) + + class Band(NumIntAdditive): r"""Gamma-ray burst continuum developed by Band et al. (1993) [1]_. @@ -183,6 +200,42 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: return K * jnp.exp(log) +class BentPL(NumIntAdditive): + r"""Bent power law. + + .. math:: + N(E) = 2K \left[ + 1 + \left(\frac{E}{E_\mathrm{b}}\right)^\alpha + \right]^{-1} + + Parameters + ---------- + alpha : Parameter, optional + The power law photon index :math:`\alpha`, dimensionless. + Eb : Parameter, optional + The break energy :math:`E_\mathrm{b}`, in units of keV. + K : Parameter, optional + The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹. + latex : str, optional + :math:`\LaTeX` format of the component. Defaults to class name. + method : {'trapz', 'simpson'}, optional + Numerical integration method. Defaults to 'trapz'. + """ + + _config = ( + ParamConfig('alpha', 'alpha', '', 1.5, -3.0, 10.0), + ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6), + ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10), + ) + + @staticmethod + def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: + alpha = params['alpha'] + Eb = params['Eb'] + K = params['K'] + return K * 2.0 / (1.0 + jnp.power(egrid / Eb, alpha)) + + class Blackbody(NumIntAdditive): r"""Blackbody function. @@ -287,14 +340,60 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: class BrokenPL(AnaIntAdditive): - pass + r"""Broken power law. + .. math:: + N(E) = K + \begin{cases} + \left(\frac{E}{E_\mathrm{b}}\right)^{-\alpha_1}, + &\text{if } E \le E_\mathrm{b}, + \\\\ + \left(\frac{E}{E_\mathrm{b}}\right)^{-\alpha_2}, + &\text{otherwise}. + \end{cases} -class DoubleBrokenPL(AnaIntAdditive): - pass + Parameters + ---------- + alpha1 : Parameter, optional + The low-energy power law photon index :math:`\alpha_1`, dimensionless. + alpha2 : Parameter, optional + The high-energy power law photon index :math:`\alpha_2`, dimensionless. + Eb : Parameter, optional + The break energy :math:`E_\mathrm{b}`, in units of keV. + K : Parameter, optional + The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹. + latex : str, optional + :math:`\LaTeX` format of the component. Defaults to class name. + method : {'trapz', 'simpson'}, optional + Numerical integration method. Defaults to 'trapz'. + """ + _config = ( + ParamConfig('alpha1', 'alpha_1', '', 1.0, -3.0, 10.0), + ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6), + ParamConfig('alpha2', 'alpha_2', '', 2.0, -3.0, 10.0), + ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10), + ) -class SmoothlyBrokenPL(NumIntAdditive): + @staticmethod + def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray: + alpha1 = params['alpha1'] + Eb = params['Eb'] + alpha2 = params['alpha2'] + K = params['K'] + mask = egrid[:-1] <= Eb + egrid_ = egrid / Eb + p1 = _powerlaw_integral(egrid_, alpha1) + p2 = _powerlaw_integral(egrid_, alpha2) + f = jnp.where(mask, p1, p2) + idx = jnp.flatnonzero(mask, size=egrid.size - 1)[-1] + pb1 = _powerlaw_integral(jnp.hstack([egrid_[idx], Eb]), alpha1) + pb2 = _powerlaw_integral(jnp.hstack([Eb, egrid_[idx + 1]]), alpha2) + f.at[idx].set(pb1[0] + pb2[0]) + return K * f + + +class DoubleBrokenPL(AnaIntAdditive): pass @@ -421,8 +520,45 @@ class LogParabola(NumIntAdditive): pass -class Lorentz(NumIntAdditive): - pass +class Lorentz(AnaIntAdditive): + r"""Lorentzian line profile. + + .. math:: + N(E) = \frac{K/\mathcal{F}}{(E - E_\mathrm{l})^2 + (\sigma/2)^2}, + + where + + .. math:: + \mathcal{F}=\int_0^{+\infty} + \frac{1}{(E - E_\mathrm{l})^2 + (\sigma/2)^2} \ \mathrm{d}E. + + Parameters + ---------- + El : Parameter, optional + The line energy :math:`E_\mathrm{l}`, in units of keV. + sigma : Parameter, optional + The FWHM line width :math:`\sigma`, in units of keV. + K : Parameter, optional + The total photon flux :math:`K` of the line, in units of ph cm⁻² s⁻¹. + latex : str, optional + :math:`\LaTeX` format of the component. Defaults to class name. + """ + + _config = ( + ParamConfig('El', r'E_\mathrm{l}', 'keV', 6.5, 1e-6, 1e6), + ParamConfig('sigma', r'\sigma', 'keV', 0.1, 1e-3, 20), + ParamConfig('K', 'K', 'ph cm^-2 s^-1', 1.0, 1e-10, 1e10), + ) + + @staticmethod + def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray: + El = params['El'] + sigma = params['sigma'] + K = params['K'] + x = 2.0 * (egrid - El) / sigma + integral = jnp.arctan(x) + norm = 2.0 * K / (jnp.pi + 2 * jnp.arctan(2 * El / sigma)) + return norm * (integral[1:] - integral[:-1]) class OTTB(NumIntAdditive): @@ -493,19 +629,6 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: return K * jnp.exp(-jnp.power(egrid / Ec, 1.0 / 3.0)) -def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray: - cond = jnp.full(len(egrid), jnp.not_equal(alpha, 1.0)) - - one_minus_alpha = jnp.where(cond, 1.0 - alpha, 1.0) - f1 = jnp.power(egrid, one_minus_alpha) / one_minus_alpha - f1 = f1[1:] - f1[:-1] - - f2 = jnp.log(egrid) - f2 = f2[1:] - f2[:-1] - - return jnp.where(cond[:-1], f1, f2) - - class PowerLaw(AnaIntAdditive): r"""Power law function. @@ -677,3 +800,62 @@ class PLEnFlux(PLFluxNorm): log=True, ), ) + + +class SmoothlyBrokenPL(NumIntAdditive): + r"""Smoothly broken power law. + + .. math:: + N(E) = K + \left( \frac{E}{E_0} \right) ^ {-\alpha_1} + \left\{ + 2\left[ + 1 + \left(\frac{E}{E_\mathrm{b}} + \right)^{\left( \alpha_2 - \alpha_1 \right) / \rho} + \right] + \right\}^{-\rho}, + + where :math:`E_0` is the pivot energy fixed at 1 keV. + + Parameters + ---------- + alpha1 : Parameter, optional + The low-energy power law index :math:`\alpha_1`, dimensionless. + Eb : Parameter, optional + The break energy :math:`E_\mathrm{b}`, in units of keV. + alpha2 : Parameter, optional + The high-energy power law index :math:`\alpha_2`, dimensionless. + rho : Parameter, optional + The smoothness parameter :math:`\rho`, dimensionless. + K : Parameter, optional + The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹. + latex : str, optional + :math:`\LaTeX` format of the component. Defaults to class name. + method : {'trapz', 'simpson'}, optional + Numerical integration method. Defaults to 'trapz'. + """ + + _config = ( + ParamConfig('alpha1', 'alpha_1', '', 1.0, -3.0, 10.0), + ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6), + ParamConfig('alpha2', 'alpha_2', '', 2.0, -3.0, 10.0), + ParamConfig('rho', r'\rho', '', 1.0, 1e-2, 1e2, fixed=True), + ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10), + ) + + @staticmethod + def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: + alpha1 = params['alpha1'] + Eb = params['Eb'] + alpha2 = params['alpha2'] + rho = params['rho'] + K = params['K'] + + e = egrid / Eb + x = (alpha2 - alpha1) / rho * jnp.log(e) + + threshold = 30 + alpha = jnp.where(x > threshold, alpha2, alpha1) + r = jnp.where(jnp.abs(x) > threshold, 0.5, 0.5 * (1.0 + jnp.exp(x))) + + return K * jnp.power(e, -alpha) * jnp.power(r, -rho) diff --git a/tests/test_model.py b/tests/test_model.py index 7be56ce..23d2752 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,9 +1,10 @@ import jax import numpy as np +import pytest from astropy.cosmology import Planck18 from astropy.units import Unit -from elisa import ConstantValue, ParamConfig, PyAnaInt, PyNumInt +from elisa import ConstantValue, ParamConfig, PyAnaInt, PyNumInt, models from elisa.models import PhAbs, PLPhFlux, PowerLaw, ZAShift @@ -172,3 +173,23 @@ def powerlaw(alpha, K, egrid): assert np.all(data.spec_counts == total_counts) assert np.all(data.back_counts == back_counts) + + +@pytest.mark.parametrize( + 'model, kwargs', + [ + pytest.param( + getattr(models, name), + {} + if name not in ['PLPhFlux', 'PLEnFlux'] + else {'emin': 1e-3, 'emax': 1e3}, + id=name, + ) + for name in models.add.__all__ + models.mul.__all__ + ], +) +def test_model_eval(model, kwargs): + egrid = np.geomspace(1e-4, 1e9, 1301) + values = model(**kwargs).compile().eval(egrid) + assert not np.any(np.isnan(values)) + assert not np.any(np.isinf(values))